Source code for ezflow.engine.retrieve

"""
Adapted from Detectron2 (https://github.com/facebookresearch/detectron2)
"""

from ..config import get_cfg


class _TrainerConfigs:
    """
    Container class for training configurations.
    """

    NAME_TO_TRAINER_CONFIG = {
        "BASE": "base.yaml",
        "RAFT": "raft_default.yaml",
        "DICL": "dicl_default.yaml",
    }

    @staticmethod
    def query(trainer_name):

        trainer_name = trainer_name.upper()

        if trainer_name in _TrainerConfigs.NAME_TO_TRAINER_CONFIG:

            cfg_file = _TrainerConfigs.NAME_TO_TRAINER_CONFIG[trainer_name]
            return cfg_file

        raise ValueError(
            f"Trainer name '{trainer_name}' not found in the training configs"
        )


[docs]def get_training_cfg(cfg_path=None, cfg_name=None, custom=True): """ Parameters ---------- cfg_path : str Path to the config file. cfg_name : str Name of the config file. custom : bool If True, the config file is assumed to be a custom config file. If False, the config file is assumed to be a standard config file present in ezflow/configs/trainers. Returns ------- cfg : CfgNode The config object """ assert ( cfg_path is not None or cfg_name is not None ), "Either cfg_path or cfg_name must be provided" if cfg_path is None: cfg_path = _TrainerConfigs.query(cfg_name) return get_cfg(cfg_path, custom=False, grp="trainers") return get_cfg(cfg_path, custom=custom, grp="trainers")