Source code for ezflow.data.dataloader.dataloader_creator

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ..dataset import *


class DataloaderCreator:
    """
    A class to configure a data loader for optical flow datasets.
    Multiple datasets can be added to configure a data loader for
    training and validation.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch to load
    pin_memory : bool, default : False
        If True, the data loader will copy Tensors into CUDA pinned memory before returning them
    shuffle : bool, default : True
        If True, data is reshuffled at every epoch
    num_workers : int, default : 2
        Number of subprocesses to use for data loading
    drop_last : bool, default : True
        If True, the last incomplete batch is dropped
    init_seed : bool, default : False
        If True, sets a random seed for each worker
    append_valid_mask : bool, default : False
        If True, appends the valid flow mask to the original flow mask at dim=0
    is_prediction : bool, default : False
        If True, only image data are loaded for prediction. Otherwise, both images and flow data are loaded
    distributed : bool, default : False
        If True, initializes DistributedSampler for Distributed Training
    world_size : int, default : 1
        The total number of GPU devices per node for Distributed Training
    """

    def __init__(
        self,
        batch_size,
        pin_memory=False,
        shuffle=True,
        num_workers=2,
        drop_last=True,
        init_seed=False,
        append_valid_mask=False,
        is_prediction=False,
        distributed=False,
        world_size=1,
    ):
        self.dataset_list = []
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.drop_last = drop_last
        self.init_seed = init_seed
        self.append_valid_mask = append_valid_mask
        self.is_prediction = is_prediction
        self.distributed = False
        self.world_size = 1

        if distributed:
            assert (
                world_size > 1
            ), "world_size must be greater than 1 to perform distributed training"

            self.distributed = distributed
            self.world_size = world_size
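    # Illustrative construction sketch: for distributed training, world_size
    # must be greater than 1 and each process passes its own rank to
    # get_dataloader(), e.g.
    #
    #     creator = DataloaderCreator(batch_size=16, distributed=True, world_size=4)
    #     loader = creator.get_dataloader(rank=local_rank)  # local_rank: hypothetical per-process id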
    def add_FlyingChairs(self, root_dir, split="training", augment=False, **kwargs):
        """
        Adds the Flying Chairs dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the Flying Chairs dataset
        split : str, default : "training"
            specify the training or validation split
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            FlyingChairs(
                root_dir,
                split=split,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_FlyingThings3D(
        self,
        root_dir,
        split="training",
        dstype="frames_cleanpass",
        augment=False,
        **kwargs,
    ):
        """
        Adds the Flying Things 3D dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the Flying Things 3D dataset in SceneFlow
        split : str, default : "training"
            specify the training or validation split
        dstype : str, default : "frames_cleanpass"
            specify dataset type
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            FlyingThings3D(
                root_dir,
                split=split,
                dstype=dstype,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_FlyingThings3DSubset(
        self, root_dir, split="training", augment=False, **kwargs
    ):
        """
        Adds the Flying Things 3D Subset dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the Flying Things 3D Subset dataset in SceneFlow
        split : str, default : "training"
            specify the training or validation split
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            FlyingThings3DSubset(
                root_dir,
                split=split,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_Monkaa(self, root_dir, augment=False, **kwargs):
        """
        Adds the Monkaa dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the Monkaa dataset in SceneFlow
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            Monkaa(
                root_dir,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_Driving(self, root_dir, augment=False, **kwargs):
        """
        Adds the Driving dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the Driving dataset in SceneFlow
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            Driving(
                root_dir,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_SceneFlow(self, root_dir, augment=False, **kwargs):
        """
        Adds the FlyingThings3D, Driving and Monkaa datasets to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the SceneFlow dataset
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.add_FlyingThings3D(
            root_dir=root_dir + "/FlyingThings3D", augment=augment, **kwargs
        )
        self.add_Monkaa(root_dir=root_dir + "/Monkaa", augment=augment, **kwargs)
        self.add_Driving(root_dir=root_dir + "/Driving", augment=augment, **kwargs)
    def add_MPISintel(
        self, root_dir, split="training", dstype="clean", augment=False, **kwargs
    ):
        """
        Adds the MPI Sintel dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the MPI Sintel dataset
        split : str, default : "training"
            specify the training or validation split
        dstype : str, default : "clean"
            specify dataset type
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            MPISintel(
                root_dir,
                split=split,
                dstype=dstype,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_Kitti(self, root_dir, split="training", augment=False, **kwargs):
        """
        Adds the KITTI dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the KITTI dataset
        split : str, default : "training"
            specify the training or validation split
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            Kitti(
                root_dir,
                split=split,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                # KITTI ground-truth flow is sparse, so the valid mask is always appended
                append_valid_mask=True,
                augment=augment,
                **kwargs,
            )
        )
    def add_HD1K(self, root_dir, augment=False, **kwargs):
        """
        Adds the HD1K dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the HD1K dataset
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            HD1K(
                root_dir,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_AutoFlow(self, root_dir, augment=False, **kwargs):
        """
        Adds the AutoFlow dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the AutoFlow dataset
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            AutoFlow(
                root_dir,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_Kubric(self, root_dir, split="training", augment=False, **kwargs):
        """
        Adds the Kubric dataset to the DataloaderCreator object.

        Parameters
        ----------
        root_dir : str
            path of the root directory for the Kubric dataset
        split : str, default : "training"
            specify the training or validation split
        augment : bool, default : False
            If True, applies data augmentation
        **kwargs
            Arbitrary keyword arguments for augmentation specifying crop_size
            and the probability of color, eraser and spatial transformation
        """
        self.dataset_list.append(
            Kubric(
                root_dir,
                split=split,
                init_seed=self.init_seed,
                is_prediction=self.is_prediction,
                append_valid_mask=self.append_valid_mask,
                augment=augment,
                **kwargs,
            )
        )
    def add_dataset(self, dataset):
        """
        Adds an optical flow dataset to the DataloaderCreator object.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            the optical flow dataset
        """
        assert dataset is not None and isinstance(
            dataset, BaseDataset
        ), "Invalid dataset type."

        self.dataset_list.append(dataset)
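    # Illustrative only: any dataset instance deriving from BaseDataset can be
    # added directly, e.g. (the path below is a hypothetical placeholder)
    #
    #     creator.add_dataset(FlyingChairs("./FlyingChairs_release/data"))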
    def get_dataloader(self, rank=0):
        """
        Gets the DataLoader for the added datasets.

        Parameters
        ----------
        rank : int, default : 0
            The process id within a group for Distributed Training

        Returns
        -------
        torch.utils.data.DataLoader
            PyTorch DataLoader object
        """
        assert len(self.dataset_list) != 0, "No datasets were added"

        dataset = self.dataset_list[0]

        if len(self.dataset_list) > 1:
            for i in range(len(self.dataset_list) - 1):
                dataset += self.dataset_list[i + 1]

        if self.distributed:
            sampler = DistributedSampler(
                dataset,
                rank=rank,
                num_replicas=self.world_size,
                shuffle=self.shuffle,
                drop_last=self.drop_last,
            )
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size // self.world_size,
                pin_memory=self.pin_memory,
                num_workers=self.num_workers,
                sampler=sampler,
            )
        else:
            data_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                pin_memory=self.pin_memory,
                shuffle=self.shuffle,
                num_workers=self.num_workers,
                drop_last=self.drop_last,
            )

        total_samples = len(data_loader) * (self.batch_size // self.world_size)
        print(
            f"Total image pairs loaded: {total_samples}/{len(dataset)} in device: {rank}"
        )

        return data_loader
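
# A minimal end-to-end usage sketch. The dataset path is a hypothetical
# placeholder, and the structure of each batch depends on the underlying
# dataset implementation.
if __name__ == "__main__":
    loader_creator = DataloaderCreator(batch_size=8, num_workers=2)
    loader_creator.add_FlyingChairs(
        root_dir="./datasets/FlyingChairs_release/data",  # hypothetical path
        split="training",
    )
    train_loader = loader_creator.get_dataloader()

    for batch in train_loader:
        # inspect the first batch, then stop
        print(type(batch))
        break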