import os.path as osp
from glob import glob
import numpy as np
from ...config import configurable
from ...functional import FlowAugmentor
from ..build import DATASET_REGISTRY
from .base_dataset import BaseDataset
@DATASET_REGISTRY.register()
class FlyingChairs(BaseDataset):
    """
    Dataset Class for preparing the Flying Chair Synthetic dataset for training and validation.

    The constructor globs ``*.ppm`` images and ``*.flo`` flow files directly inside
    ``root_dir`` and reads ``FlyingChairs_train_val.txt`` from the same directory to
    decide which samples belong to the requested split (entries equal to 1 are used
    for training, entries equal to 2 for validation).

    Parameters
    ----------
    root_dir : str
        path of the root directory for the flying chairs dataset
    split : str, default : "training"
        specify the training or validation split (matched case-insensitively)
    is_prediction : bool, default : False
        If True, only image data are loaded for prediction otherwise both images and flow data are loaded
    init_seed : bool, default : False
        If True, sets random seed to worker
    append_valid_mask : bool, default : False
        If True, appends the valid flow mask to the original flow mask at dim=0
    crop : bool, default : False
        Whether to perform cropping
    crop_size : :obj:`tuple` of :obj:`int`, default : (256, 256)
        The size of the image crop
    crop_type : :obj:`str`, default : 'center'
        The type of cropping to be performed, one of "center", "random"
    augment : bool, default : True
        If True, applies data augmentation
    aug_params : :obj:`dict`, optional
        The parameters for data augmentation
    norm_params : :obj:`dict`, optional
        The parameters for normalization
    flow_offset_params : :obj:`dict`, optional
        The parameters for adding bilinear interpolated weights surrounding each ground truth flow values.

    Raises
    ------
    AssertionError
        If ``split`` is neither "training" nor "validation", or if the number of
        ``*.ppm`` images is not twice the number of ``*.flo`` files.
    FileNotFoundError
        If ``FlyingChairs_train_val.txt`` cannot be read from ``root_dir``.
    """

    # NOTE(review): the dict defaults below are shared between calls. This
    # constructor never mutates them; confirm BaseDataset does not either.
    @configurable
    def __init__(
        self,
        root_dir,
        split="training",
        is_prediction=False,
        init_seed=False,
        append_valid_mask=False,
        crop=False,
        crop_size=(256, 256),
        crop_type="center",
        augment=True,
        aug_params={
            "eraser_aug_params": {"enabled": False},
            "noise_aug_params": {"enabled": False},
            "flip_aug_params": {"enabled": False},
            "color_aug_params": {"enabled": False},
            "spatial_aug_params": {"enabled": False},
            "advanced_spatial_aug_params": {"enabled": False},
        },
        norm_params={"use": False},
        flow_offset_params={"use": False},
    ):
        super(FlyingChairs, self).__init__(
            init_seed=init_seed,
            is_prediction=is_prediction,
            append_valid_mask=append_valid_mask,
            crop=crop,
            crop_size=crop_size,
            crop_type=crop_type,
            augment=augment,
            aug_params=aug_params,
            sparse_transform=False,
            norm_params=norm_params,
            flow_offset_params=flow_offset_params,
        )

        # Normalize once so validation and the filtering below agree; previously
        # e.g. "Training" passed validation but matched no samples.
        split = split.lower()
        assert (
            split == "training" or split == "validation"
        ), "Incorrect split values. Accepted split values: training, validation"

        self.is_prediction = is_prediction
        self.append_valid_mask = append_valid_mask

        if augment:
            self.augmentor = FlowAugmentor(crop_size=crop_size, **aug_params)

        # Sorted so that images[2 * i] and images[2 * i + 1] form the frame pair
        # that is paired with flows[i] below.
        images = sorted(glob(osp.join(root_dir, "*.ppm")))
        flows = sorted(glob(osp.join(root_dir, "*.flo")))
        assert len(images) // 2 == len(flows), (
            "Expected two .ppm images per .flo file in "
            + root_dir
            + ", found "
            + str(len(images))
            + " images and "
            + str(len(flows))
            + " flows"
        )

        try:
            split_list = np.loadtxt(
                osp.join(root_dir, "FlyingChairs_train_val.txt"), dtype=np.int32
            )
        except OSError as err:
            # Raise instead of print + exit(): exit() would terminate the whole
            # interpreter with status 0 from inside a constructor.
            raise FileNotFoundError(
                "FlyingChairs_train_val.txt was not found in " + root_dir
            ) from err

        # loadtxt yields a 0-d array for a single-entry file; keep it indexable.
        split_list = np.atleast_1d(split_list)

        # Split-file entries: 1 marks a training sample, 2 a validation sample.
        wanted_id = 1 if split == "training" else 2

        for i, flow_path in enumerate(flows):
            if split_list[i] == wanted_id:
                self.flow_list.append(flow_path)
                self.image_list.append([images[2 * i], images[2 * i + 1]])

    @classmethod
    def from_config(cls, cfg):
        """
        Translate a config node into the keyword arguments of ``__init__``.

        Parameters
        ----------
        cfg : object
            Config node exposing ``ROOT_DIR``, ``SPLIT``, ``IS_PREDICTION``,
            ``INIT_SEED``, ``APPEND_VALID_MASK``, ``CROP.USE``, ``CROP.SIZE``,
            ``CROP.TYPE``, ``AUGMENTATION.USE``, ``AUGMENTATION.PARAMS``,
            ``NORM_PARAMS`` and ``FLOW_OFFSET_PARAMS`` attributes.

        Returns
        -------
        dict
            Mapping from ``__init__`` parameter names to the values read from ``cfg``.
        """
        return {
            "root_dir": cfg.ROOT_DIR,
            "split": cfg.SPLIT,
            "is_prediction": cfg.IS_PREDICTION,
            "init_seed": cfg.INIT_SEED,
            "append_valid_mask": cfg.APPEND_VALID_MASK,
            "crop": cfg.CROP.USE,
            "crop_size": cfg.CROP.SIZE,
            "crop_type": cfg.CROP.TYPE,
            "augment": cfg.AUGMENTATION.USE,
            "aug_params": cfg.AUGMENTATION.PARAMS,
            "norm_params": cfg.NORM_PARAMS,
            "flow_offset_params": cfg.FLOW_OFFSET_PARAMS,
        }