Source code for ezflow.engine.trainer

import os
import random
import time
from copy import deepcopy
from datetime import timedelta

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from ezflow.data import DataloaderCreator

from ..functional import FUNCTIONAL_REGISTRY
from ..utils import AverageMeter, endpointerror, find_free_port, is_port_available
from .registry import loss_functions, optimizers, schedulers


def seed(seed):
    """
    Seed every random number generator used during training and make cuDNN
    deterministic.

    Parameters
    ----------
    seed : int
        Seed applied to Python's ``random`` module, NumPy's global RNG and
        PyTorch's CPU and CUDA generators (current device and all devices).
    """
    # Python's built-in RNG is seeded too, so code relying on ``random``
    # (not only NumPy / torch) is reproducible as well.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True


class BaseTrainer:
    """
    Shared training, validation and checkpointing loop for ezflow trainers.

    Subclasses implement ``_setup_device``, ``_setup_model`` and
    ``_is_main_process`` and populate ``cfg``, ``model``, ``model_name``,
    ``train_loader`` and ``val_loader`` before training starts.
    """

    def __init__(self):
        self.cfg = None

        self.model = None

        self.loss_fn = None
        self.optimizer = None
        self.scheduler = None
        self.scaler = None

        self.train_loader = None
        self.val_loader = None

        self.device = None
        # Bound to either _epoch_trainer or _step_trainer by _setup_training().
        self._trainer = None

        # True when self.model is wrapped in DistributedDataParallel.
        self.model_parallel = False

        self.writer = None
        # Wall-clock duration of each training step (main process only).
        self.times = []

    def _setup_device(self):
        raise NotImplementedError

    def _setup_model(self):
        raise NotImplementedError

    def _is_main_process(self):
        raise NotImplementedError

    def _unwrapped_model(self):
        """Return the underlying model, stripping the DDP wrapper if present."""
        return self.model.module if self.model_parallel else self.model

    def _setup_training(self, rank=0, loss_fn=None, optimizer=None, scheduler=None):
        """
        Resolve the loss function, optimizer, scheduler, grad scaler and
        Tensorboard writer used for training.

        Components already set on the trainer take precedence over the
        arguments; anything still missing is built from ``self.cfg``.

        Parameters
        ----------
        rank : int
            Process rank. The Tensorboard writer is created only for rank 0.
        loss_fn, optimizer, scheduler : optional
            User-supplied components; ``None`` means "build from the config".
        """
        self._trainer = self._epoch_trainer
        if self.cfg.NUM_STEPS is not None:
            self._trainer = self._step_trainer
            max_iter = self.cfg.NUM_STEPS
        else:
            max_iter = self.cfg.EPOCHS * len(self.train_loader)

        if loss_fn is None and self.loss_fn is None:

            if self.cfg.CRITERION.CUSTOM:
                loss = FUNCTIONAL_REGISTRY.get(self.cfg.CRITERION.NAME)
            else:
                loss = loss_functions.get(self.cfg.CRITERION.NAME)

            if self.cfg.CRITERION.PARAMS is not None:
                loss_params = self.cfg.CRITERION.PARAMS.to_dict()
                loss_params["max_iter"] = max_iter
                loss_fn = loss(**loss_params)
            else:
                loss_fn = loss()

            print(f"Loss function: {self.cfg.CRITERION.NAME} is initialized!")

        if optimizer is None and self.optimizer is None:

            opt = optimizers.get(self.cfg.OPTIMIZER.NAME)

            if self.cfg.OPTIMIZER.PARAMS is not None:
                optimizer_params = self.cfg.OPTIMIZER.PARAMS.to_dict()
                optimizer = opt(
                    self.model.parameters(),
                    lr=self.cfg.OPTIMIZER.LR,
                    **optimizer_params,
                )
            else:
                optimizer = opt(self.model.parameters(), lr=self.cfg.OPTIMIZER.LR)

            print(f"Optimizer: {self.cfg.OPTIMIZER.NAME} is initialized!")

        if self.loss_fn is None:
            self.loss_fn = loss_fn

        if self.optimizer is None:
            self.optimizer = optimizer

        if scheduler is None and self.scheduler is None and self.cfg.SCHEDULER.USE:
            sched = schedulers.get(self.cfg.SCHEDULER.NAME)

            # Attach the scheduler to the optimizer that will actually be
            # stepped (self.optimizer), not a possibly discarded/None local.
            if self.cfg.SCHEDULER.PARAMS is not None:
                scheduler_params = self.cfg.SCHEDULER.PARAMS.to_dict()
                if "epochs" in scheduler_params:
                    scheduler_params["steps_per_epoch"] = len(self.train_loader)

                scheduler = sched(self.optimizer, **scheduler_params)
            else:
                scheduler = sched(self.optimizer)

            print(f"Scheduler: {self.cfg.SCHEDULER.NAME} is initialized!")

        if self.scheduler is None:
            self.scheduler = scheduler

        if rank == 0:
            # Initialize the Tensorboard SummaryWriter only for the main process.
            self.writer = SummaryWriter(log_dir=self.cfg.LOG_DIR)

        self.scaler = GradScaler(enabled=self.cfg.MIXED_PRECISION)

        self.min_avg_val_loss = float("inf")
        self.min_avg_val_metric = float("inf")

    def _freeze_bn(self):
        """Freeze batch-norm layers when ``cfg.FREEZE_BATCH_NORM`` is set."""
        if self.cfg.FREEZE_BATCH_NORM:
            self._unwrapped_model().freeze_batch_norm()

    def _epoch_trainer(self, n_epochs=None, start_epoch=None):
        """
        Train for ``n_epochs`` full passes over the training loader.

        Parameters
        ----------
        n_epochs : int, optional
            Number of epochs to run. Defaults to ``cfg.EPOCHS``.
        start_epoch : int, optional
            0-based index of the first epoch to run (used when resuming).
            Defaults to 0.
        """
        self.model.train()
        self._freeze_bn()

        loss_meter = AverageMeter()

        if n_epochs is None:
            n_epochs = self.cfg.EPOCHS

        if start_epoch is not None:
            print(f"Resuming training from epoch {start_epoch+1}\n")
        else:
            start_epoch = 0

        for epoch in range(start_epoch, start_epoch + n_epochs):

            print(f"\nEpoch {epoch+1} of {start_epoch+n_epochs}")
            print("-" * 80)

            if self.model_parallel:
                self.train_loader.sampler.set_epoch(epoch)

            loss_meter.reset()
            for iteration, (inp, target) in enumerate(self.train_loader):
                total_iters = iteration + (epoch * len(self.train_loader))

                loss = self._run_step(inp, target, current_iter=total_iters)

                loss_meter.update(loss.item())
                self._log_step(iteration, total_iters, loss_meter)

            print(f"\nEpoch {epoch+1}: Average Training loss = {loss_meter.avg}")

            if self._is_main_process():
                self.writer.add_scalar(
                    "avg_epochs_training_loss", loss_meter.avg, epoch + 1
                )

            validate_now = (epoch + 1) % self.cfg.VALIDATE_INTERVAL == 0
            ckpt_now = (epoch + 1) % self.cfg.CKPT_INTERVAL == 0

            if validate_now and self._is_main_process():
                self._validate_model(
                    iter_type="Epoch", iterations=epoch + 1, current_iter=0, logger=None
                )

            if ckpt_now and self._is_main_process():
                self._save_checkpoints(ckpt_type="epoch", ckpt_number=epoch + 1)

            # Synchronize all processes in multi gpu after validation and checkpoint
            if (validate_now or ckpt_now) and self.model_parallel:
                dist.barrier()

        if self._is_main_process():
            self.writer.close()

    def _step_trainer(self, n_steps=None, start_step=None):
        """
        Train for exactly ``n_steps`` optimizer steps, cycling through the
        training loader as many times as needed.

        Steps are numbered from 1.

        Parameters
        ----------
        n_steps : int, optional
            Number of steps to run. Defaults to ``cfg.NUM_STEPS``.
        start_step : int, optional
            Number of the first step to run (used when resuming). Defaults to 1.
        """
        self.model.train()
        self._freeze_bn()

        loss_meter = AverageMeter()

        if n_steps is None:
            n_steps = self.cfg.NUM_STEPS

        if start_step is not None:
            print(f"Resuming training from step {start_step}\n")
        else:
            start_step = 1

        # Exclusive upper bound: both fresh and resumed runs execute n_steps steps.
        end_step = start_step + n_steps

        if self.model_parallel:
            epoch = 0
            self.train_loader.sampler.set_epoch(epoch)

        train_iter = iter(self.train_loader)

        print(f"\nStarting step {start_step} of {end_step - 1}")
        print("-" * 80)
        for step in range(start_step, end_step):
            try:
                inp, target = next(train_iter)
            except StopIteration:
                # The loader is exhausted: start a new pass over the data and
                # advance the distributed sampler's epoch for the new pass.
                if self.model_parallel:
                    epoch += 1
                    self.train_loader.sampler.set_epoch(epoch)

                train_iter = iter(self.train_loader)
                inp, target = next(train_iter)

            loss = self._run_step(inp, target, current_iter=step)
            loss_meter.update(loss.item())

            self._log_step(step, step, loss_meter)

            validate_now = step % self.cfg.VALIDATE_INTERVAL == 0
            ckpt_now = step % self.cfg.CKPT_INTERVAL == 0

            if validate_now and self._is_main_process():
                self._validate_model(iter_type="Iteration", iterations=step)
                print("-" * 80)

            if ckpt_now and self._is_main_process():
                self._save_checkpoints(ckpt_type="step", ckpt_number=step)

            # Synchronize all processes in multi gpu after validation and checkpoint
            if (validate_now or ckpt_now) and self.model_parallel:
                dist.barrier()

        if self._is_main_process():
            self.writer.close()

    def _run_step(self, inp, target, **kwargs):
        """
        Run one forward/backward/optimizer step and return the (scaled-graph)
        loss tensor. Extra ``kwargs`` are forwarded to the loss function.
        """
        inp, target = self._to_device(inp, target)
        img1, img2 = inp

        if self._is_main_process():
            start_time = time.time()

        with autocast(enabled=self.cfg.MIXED_PRECISION):
            output = self.model(img1, img2)
            loss = self.loss_fn(**output, **target, **kwargs)

            del output

        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true gradients.
        self.scaler.unscale_(self.optimizer)

        if self.cfg.GRAD_CLIP.USE is True:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.GRAD_CLIP.VALUE)

        self.scaler.step(self.optimizer)

        if self.scheduler is not None:
            self.scheduler.step()

        self.scaler.update()

        if self._is_main_process():
            self.times.append(time.time() - start_time)

        return loss

    def _to_device(self, inp, target):
        """
        Move the image pair and every target tensor to ``self.device``.

        ``target["flow_gt"]`` is divided by ``cfg.TARGET_SCALE_FACTOR``; the
        scaling is undone for both prediction and ground truth in
        ``_calculate_metric``. ``target`` is modified in place and returned.
        """
        img1, img2 = inp
        inp = (img1.to(self.device), img2.to(self.device))

        for key, val in target.items():
            target[key] = val.to(self.device)

        target["flow_gt"] = target["flow_gt"] / self.cfg.TARGET_SCALE_FACTOR
        return inp, target

    def _log_step(self, iteration, total_iters, loss_meter):
        """Print (and on the main process, log) the batch loss every LOG_ITERATIONS_INTERVAL."""
        if iteration % self.cfg.LOG_ITERATIONS_INTERVAL == 0:
            print(
                f"[{iteration} / {total_iters}] iterations, batch training loss: {loss_meter.val}"
            )
            if self._is_main_process():
                self.writer.add_scalar(
                    "batch_training_loss",
                    loss_meter.val,
                    total_iters,
                )

    def _validate_model(self, iter_type, iterations, **kwargs):
        """
        Evaluate on ``self.val_loader``, log the average loss and metric, save
        the best model if it improved, then return the model to training mode.
        Extra ``kwargs`` are forwarded to the loss function.
        """
        self.model.eval()
        metric_meter = AverageMeter()
        loss_meter = AverageMeter()

        with torch.no_grad():
            for inp, target in self.val_loader:
                inp, target = self._to_device(inp, target)
                img1, img2 = inp

                # Bypass the DDP wrapper: validation runs on the main process only.
                output = self._unwrapped_model()(img1, img2)

                loss = self.loss_fn(**output, **target, **kwargs)

                loss_meter.update(loss.item())

                metric = self._calculate_metric(output, target)
                metric_meter.update(metric)

                del output

        new_avg_val_loss, new_avg_val_metric = loss_meter.avg, metric_meter.avg

        print("\n", "-" * 80)
        self.writer.add_scalar("avg_validation_loss", new_avg_val_loss, iterations)
        self.writer.add_scalar("avg_validation_metric", new_avg_val_metric, iterations)

        print(
            f"\n{iter_type} {iterations}: Average validation loss = {new_avg_val_loss}"
        )

        print(
            f"{iter_type} {iterations}: Average validation metric = {new_avg_val_metric}\n"
        )
        print("-" * 80, "\n")

        self._save_best_model(new_avg_val_loss, new_avg_val_metric)

        self.model.train()
        self._freeze_bn()

    def _calculate_metric(self, pred, target):
        """
        End-point error in the original flow units.

        ``_to_device`` divides the ground truth by ``TARGET_SCALE_FACTOR``, so
        both the predicted upsampled flow and the ground truth are multiplied
        back by it before comparison.
        """
        scale = self.cfg.TARGET_SCALE_FACTOR
        flow_pred = pred["flow_upsampled"] * scale
        flow_gt = target["flow_gt"] * scale
        return endpointerror(flow_pred, flow_gt)

    def _save_checkpoints(self, ckpt_type, ckpt_number):
        """
        Save model, optimizer (and scheduler) state plus ``{ckpt_type: ckpt_number}``
        to ``<CKPT_DIR>/<model_name>_<ckpt_type><ckpt_number>.pth``.
        """
        consolidated_save_dict = {
            "model_state_dict": self._unwrapped_model().state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            ckpt_type: ckpt_number,
        }
        if self.scheduler is not None:
            consolidated_save_dict["scheduler_state_dict"] = self.scheduler.state_dict()

        torch.save(
            consolidated_save_dict,
            os.path.join(
                self.cfg.CKPT_DIR,
                self.model_name + "_" + ckpt_type + str(ckpt_number) + ".pth",
            ),
        )

    def _save_best_state_dict(self):
        """Write the current model weights to ``<CKPT_DIR>/<model_name>_best.pth``."""
        # torch.save serializes the state_dict tensors immediately, so there is
        # no need to deepcopy the whole model first.
        torch.save(
            self._unwrapped_model().state_dict(),
            os.path.join(self.cfg.CKPT_DIR, self.model_name + "_best.pth"),
        )
        print("Saved new best model!\n")

    def _save_best_model(self, new_avg_val_loss, new_avg_val_metric):
        """
        Track the lowest validation loss / metric seen so far and save the model
        when the quantity selected by ``cfg.VALIDATE_ON`` ("loss" or "metric")
        reaches a new minimum.
        """
        if new_avg_val_loss < self.min_avg_val_loss:

            self.min_avg_val_loss = new_avg_val_loss
            print("\nNew minimum average validation loss!")

            if self.cfg.VALIDATE_ON.lower() == "loss":
                self._save_best_state_dict()

        if new_avg_val_metric < self.min_avg_val_metric:

            self.min_avg_val_metric = new_avg_val_metric
            print("\nNew minimum average validation metric!")

            if self.cfg.VALIDATE_ON.lower() == "metric":
                self._save_best_state_dict()

    def _reload_trainer_states(
        self,
        consolidated_ckpt=None,
        model_ckpt=None,
        optimizer_ckpt=None,
        total_iterations=None,
        start_iteration=None,
        scheduler_ckpt=None,
        use_cfg=False,
    ):
        """
        Set up device, model and training components and load saved states.

        Returns
        -------
        tuple
            ``(total_iterations, start_iteration)`` to pass to ``self._trainer``.
            A consolidated checkpoint's saved epoch/step overrides the
            ``start_iteration`` argument.
        """
        self._setup_device()

        if use_cfg is True:
            consolidated_ckpt = self.cfg.RESUME_TRAINING.CONSOLIDATED_CKPT

        # Remains None when no scheduler state was saved or provided.
        scheduler_state_dict = None

        if consolidated_ckpt is not None:

            ckpt = torch.load(consolidated_ckpt, map_location=self.device)

            model_state_dict = ckpt["model_state_dict"]
            optimizer_state_dict = ckpt["optimizer_state_dict"]
            scheduler_state_dict = ckpt.get("scheduler_state_dict")

            # _save_checkpoints stores the counter under "epoch" or "step".
            if "epoch" in ckpt:
                # Saved value is the number of completed epochs, which is the
                # 0-based index of the next epoch for _epoch_trainer.
                start_iteration = ckpt["epoch"]

            if "step" in ckpt:
                # Saved value is the last completed 1-based step.
                start_iteration = ckpt["step"] + 1

        else:

            assert (
                model_ckpt is not None and optimizer_ckpt is not None
            ), "Must provide a consolidated ckpt or model and optimizer ckpts separately"

            model_state_dict = torch.load(model_ckpt, map_location=self.device)
            optimizer_state_dict = torch.load(optimizer_ckpt, map_location=self.device)

            if scheduler_ckpt is not None:
                scheduler_state_dict = torch.load(
                    scheduler_ckpt, map_location=self.device
                )

        self._setup_model()
        self.model.load_state_dict(model_state_dict)
        print("Model state loaded!!")

        self._setup_training()

        self.optimizer.load_state_dict(optimizer_state_dict)
        print("Optimizer state loaded!!")

        if self.scheduler is not None:
            if scheduler_state_dict is not None:
                self.scheduler.load_state_dict(scheduler_state_dict)
                print("Scheduler state loaded!!")
            else:
                print("No scheduler state found; scheduler starts from its initial state.")

        if total_iterations is None and use_cfg:
            total_iterations = (
                self.cfg.RESUME_TRAINING.NUM_STEPS
                if self.cfg.RESUME_TRAINING.NUM_STEPS is not None
                else self.cfg.RESUME_TRAINING.EPOCHS
            )

        if start_iteration is None and use_cfg:
            start_iteration = (
                self.cfg.RESUME_TRAINING.START_STEP
                if self.cfg.RESUME_TRAINING.START_STEP is not None
                else self.cfg.RESUME_TRAINING.START_EPOCH
            )

        return (total_iterations, start_iteration)

    def resume_training(
        self,
        consolidated_ckpt=None,
        model_ckpt=None,
        optimizer_ckpt=None,
        total_iterations=None,
        start_iteration=None,
        scheduler_ckpt=None,
        use_cfg=False,
    ):

        """
        Method to resume training of a model

        Parameters
        ----------
        consolidated_ckpt : str, optional
            Path to a checkpoint written by ``_save_checkpoints`` (model,
            optimizer, optional scheduler state and the epoch/step counter).
            Replaced by ``cfg.RESUME_TRAINING.CONSOLIDATED_CKPT`` when
            ``use_cfg`` is True.
        model_ckpt : str, optional
            Path to a model state-dict file. Required when no consolidated
            checkpoint is used.
        optimizer_ckpt : str, optional
            Path to an optimizer state-dict file. Required when no consolidated
            checkpoint is used.
        total_iterations : int, optional
            The number of epochs or steps to train for. When None and
            ``use_cfg`` is True, ``cfg.RESUME_TRAINING.NUM_STEPS`` (or
            ``EPOCHS``) is used; otherwise the trainer's default
            (``cfg.NUM_STEPS`` / ``cfg.EPOCHS``) applies.
        start_iteration : int, optional
            The epoch index or step number to resume from. Overridden by the
            counter stored in a consolidated checkpoint; when still None and
            ``use_cfg`` is True, ``cfg.RESUME_TRAINING.START_STEP`` (or
            ``START_EPOCH``) is used.
        scheduler_ckpt : str, optional
            Path to a scheduler state-dict file, used only when no consolidated
            checkpoint is used.
        use_cfg : bool, optional
            Whether to read resume settings from ``cfg.RESUME_TRAINING``.
            Defaults to False.
        """

        total_iterations, start_iteration = self._reload_trainer_states(
            consolidated_ckpt=consolidated_ckpt,
            model_ckpt=model_ckpt,
            optimizer_ckpt=optimizer_ckpt,
            total_iterations=total_iterations,
            start_iteration=start_iteration,
            scheduler_ckpt=scheduler_ckpt,
            use_cfg=use_cfg,
        )

        os.makedirs(self.cfg.CKPT_DIR, exist_ok=True)
        os.makedirs(self.cfg.LOG_DIR, exist_ok=True)

        print("Training config:\n")
        print(self.cfg)
        print("-" * 80)

        self._trainer(total_iterations, start_iteration)

        print("Training complete!")
        print(f"Total training time: {str(timedelta(seconds=sum(self.times)))}")


class Trainer(BaseTrainer):
    """
    Trainer class for training and evaluating models on a single device CPU/GPU.

    Parameters
    ----------
    cfg : CfgNode
        Configuration object for training
    model : torch.nn.Module
        Model to be trained
    train_loader_creator : ezflow.data.DataloaderCreator
        DataloaderCreator instance for training
    val_loader_creator : ezflow.data.DataloaderCreator
        DataloaderCreator instance for validation
    """

    def __init__(
        self,
        cfg,
        model,
        train_loader_creator: DataloaderCreator,
        val_loader_creator: DataloaderCreator,
    ):
        super(Trainer, self).__init__()
        self.cfg = cfg
        self.model = model
        self.model_name = type(model).__name__.lower()

        # Single-device training: neither loader uses a distributed sampler.
        for creator in (train_loader_creator, val_loader_creator):
            creator.distributed = False

        self.train_loader = train_loader_creator.get_dataloader()
        self.val_loader = val_loader_creator.get_dataloader()

    def _setup_device(self):
        """Pick CPU or the configured CUDA device, then seed all RNGs."""
        device = self.cfg.DEVICE
        wants_cpu = (isinstance(device, str) and device.lower() == "cpu") or int(
            device
        ) == -1

        if wants_cpu or not torch.cuda.is_available():
            # Mixed precision is disabled whenever training falls back to CPU.
            self.device = torch.device("cpu")
            self.cfg.MIXED_PRECISION = False
            if wants_cpu:
                print("Running on CPU\n")
            else:
                print("CUDA device(s) not available. Running on CPU\n")
        else:
            self.device = torch.device(int(device))

        torch.cuda.empty_cache()
        seed(0)

    def _setup_model(self):
        """Move the model onto the selected device."""
        self.model = self.model.to(self.device)

    def _is_main_process(self):
        # A single-device run has exactly one (main) process.
        return True

    def train(
        self,
        loss_fn=None,
        optimizer=None,
        scheduler=None,
        total_epochs=None,
        start_epoch=None,
    ):
        """
        Method to train the model using a single cpu/gpu device.

        Parameters
        ----------
        loss_fn : torch.nn.modules.loss, optional
            The loss function to be used. Defaults to None (which uses the loss function specified in the config file).
        optimizer : torch.optim.Optimizer, optional
            The optimizer to be used. Defaults to None (which uses the optimizer specified in the config file).
        scheduler : torch.optim.lr_scheduler, optional
            The learning rate scheduler to be used. Defaults to None (which uses the scheduler specified in the config file).
        total_epochs : int, optional
            The number of epochs (or steps, when cfg.NUM_STEPS is set) to train for. Defaults to None (which uses the value specified in the config file)
        start_epoch : int, optional
            The epoch (or step) to resume training from. Defaults to None (which starts from the beginning).
        """
        self._setup_device()
        self._setup_model()
        self._setup_training(
            rank=0, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler
        )

        for directory in (self.cfg.CKPT_DIR, self.cfg.LOG_DIR):
            os.makedirs(directory, exist_ok=True)

        print("Training config:\n")
        print(self.cfg)
        print("-" * 80)

        self._trainer(total_epochs, start_epoch)

        elapsed = timedelta(seconds=sum(self.times))
        print("Training complete!")
        print(f"Total training time: {str(elapsed)}")
class DistributedTrainer(BaseTrainer):
    """
    Trainer class for distributed training and evaluating models on a single node multi-gpu environment.

    Parameters
    ----------
    cfg : CfgNode
        Configuration object for training
    model : torch.nn.Module
        Model to be trained
    train_loader_creator : ezflow.data.DataloaderCreator
        DataloaderCreator instance for training
    val_loader_creator : ezflow.data.DataloaderCreator
        DataloaderCreator instance for validation
    """

    def __init__(
        self,
        cfg,
        model,
        train_loader_creator: DataloaderCreator,
        val_loader_creator: DataloaderCreator,
    ):
        super(DistributedTrainer, self).__init__()
        self.model_parallel = True

        self.cfg = cfg
        self.model_name = model.__class__.__name__.lower()
        self.model = model

        self.local_rank = None
        self.device_ids = None

        self.train_loader = None
        self.val_loader = None

        # The training loader is built per rank inside _main_worker.
        self.train_loader_creator = train_loader_creator

        # Validate model only on the main process.
        val_loader_creator.distributed = False
        self.val_loader = val_loader_creator.get_dataloader()

        self._validate_ddp_config()

    def _validate_ddp_config(self):
        """
        Restrict visible GPUs per ``cfg.DEVICE``, check ``WORLD_SIZE`` against
        the available devices, and replace ``MASTER_PORT`` if it is taken.
        """
        if self.cfg.DEVICE != "all":
            # Set CUDA_VISIBLE_DEVICES before performing any torch.cuda operations.
            device = str(self.cfg.DEVICE)
            os.environ["CUDA_VISIBLE_DEVICES"] = device

            device_ids = [int(device_id) for device_id in device.split(",")]

            assert (
                len(device_ids) <= torch.cuda.device_count()
            ), "Total devices cannot be greater than available CUDA devices."
            self.device_ids = device_ids
            print(f"\nRunning on devices: {self.device_ids}\n")

        assert self.cfg.DISTRIBUTED.WORLD_SIZE <= torch.cuda.device_count(), (
            "WORLD_SIZE cannot be greater than available CUDA devices. "
            f"Given WORLD_SIZE: {self.cfg.DISTRIBUTED.WORLD_SIZE} "
            f"but total CUDA devices available: {torch.cuda.device_count()}"
        )

        if not is_port_available(int(self.cfg.DISTRIBUTED.MASTER_PORT)):
            print(
                f"\nPort: {self.cfg.DISTRIBUTED.MASTER_PORT} is not available to use!"
            )
            free_port = find_free_port()
            print(f"Assigning free port: {free_port}\n")
            self.cfg.DISTRIBUTED.MASTER_PORT = free_port

    def _setup_device(self, rank):
        """Bind this process to CUDA device ``rank``."""
        assert (
            torch.cuda.is_available()
        ), "CUDA devices are not available. Use ezflow.Trainer for single device training."

        self.device = torch.device(rank)
        self.local_rank = rank
        torch.cuda.empty_cache()
        torch.cuda.set_device(rank)

    def _setup_ddp(self, rank):
        """Initialize the default process group for this rank (env:// rendezvous)."""
        # os.environ only accepts str values; the configured address/port (or
        # the port assigned in _validate_ddp_config) is not guaranteed to be one.
        os.environ["MASTER_ADDR"] = str(self.cfg.DISTRIBUTED.MASTER_ADDR)
        os.environ["MASTER_PORT"] = str(self.cfg.DISTRIBUTED.MASTER_PORT)

        seed(0)

        dist.init_process_group(
            backend=self.cfg.DISTRIBUTED.BACKEND,
            init_method="env://",
            world_size=self.cfg.DISTRIBUTED.WORLD_SIZE,
            rank=rank,
        )
        print(f"{rank + 1}/{self.cfg.DISTRIBUTED.WORLD_SIZE} process initialized.")

        # synchronizes all the threads to reach this point before moving on
        dist.barrier()

    def _is_main_process(self):
        return self.local_rank == 0

    def _setup_model(self, rank):
        """Optionally convert to SyncBatchNorm, then wrap the model in DDP on ``rank``."""
        if self.cfg.DISTRIBUTED.SYNC_BATCH_NORM:
            self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)

        self.model = DDP(
            self.model.cuda(rank),
            device_ids=[rank],
        )
        self.model = self.model.to(self.device)

    def _cleanup(self):
        dist.destroy_process_group()

    def _main_worker(
        self,
        rank,
        loss_fn=None,
        optimizer=None,
        scheduler=None,
        total_epochs=None,
        start_epoch=None,
    ):
        """Per-process entry point spawned by ``train``."""
        self._setup_device(rank)
        self._setup_ddp(rank)

        # Destroy the process group even if setup or training raises.
        try:
            self._setup_model(rank)

            self.train_loader = self.train_loader_creator.get_dataloader(rank=rank)

            self._setup_training(
                rank=rank, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler
            )

            os.makedirs(self.cfg.CKPT_DIR, exist_ok=True)
            os.makedirs(self.cfg.LOG_DIR, exist_ok=True)

            # synchronizes all the threads to reach this point before moving on
            dist.barrier()

            self._trainer(total_epochs, start_epoch)

            if self._is_main_process():
                print("\nTraining complete!")
                print(
                    f"Total training time: {str(timedelta(seconds=sum(self.times)))}"
                )
        finally:
            self._cleanup()

    def train(
        self,
        loss_fn=None,
        optimizer=None,
        scheduler=None,
        total_epochs=None,
        start_epoch=None,
    ):
        """
        Method to train the model in a distributed fashion using DDP

        Parameters
        ----------
        loss_fn : torch.nn.modules.loss, optional
            The loss function to be used. Defaults to None (which uses the loss function specified in the config file).
        optimizer : torch.optim.Optimizer, optional
            The optimizer to be used. Defaults to None (which uses the optimizer specified in the config file).
        scheduler : torch.optim.lr_scheduler, optional
            The learning rate scheduler to be used. Defaults to None (which uses the scheduler specified in the config file).
        total_epochs : int, optional
            The number of epochs (or steps, when cfg.NUM_STEPS is set) to train for. Defaults to None (which uses the value specified in the config file)
        start_epoch : int, optional
            The epoch (or step) number to resume training from. Defaults to None (which starts from the beginning).
        """
        print("Training config:\n")
        print(self.cfg)
        print("-" * 80)
        print("\nPerforming distributed training\n")
        print("-" * 80)

        mp.spawn(
            self._main_worker,
            args=(loss_fn, optimizer, scheduler, total_epochs, start_epoch),
            nprocs=self.cfg.DISTRIBUTED.WORLD_SIZE,
            join=True,
        )