Source code for ezflow.functional.criterion.multiscale

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...config import configurable
from ..registry import FUNCTIONAL_REGISTRY

@FUNCTIONAL_REGISTRY.register()
class MultiScaleLoss(nn.Module):
    """
    Multi-scale loss for optical flow estimation.
    Used in **DICL** (Displacement-Invariant Matching Cost Learning for
    Accurate Optical Flow Estimation, NeurIPS 2020).

    Each prediction of a (possibly multi-level) output is compared with the
    ground-truth flow, either after upsampling the prediction to the
    ground-truth resolution or after downsampling the ground truth to the
    prediction resolution. The per-level losses are weighted, summed and
    divided by the number of levels. Ground-truth pixels with a NaN in any
    flow component are treated as invalid and do not contribute.

    Parameters
    ----------
    norm : str, default: "l1"
        The norm to use for the loss. Can be either "l2", "l1" or "robust"
    q : float, default: 0.4
        Exponent of the robust loss ``(||pred - gt||_1 + eps) ** q``.
        q < 1 gives less penalty to outliers
    eps : float, default: 0.01
        Small constant added inside the robust loss to stabilize fine tuning.
    weights : sequence of float, default: (1, 0.5, 0.25)
        The weight of each scale; ``weights[i]`` multiplies the loss of the
        i-th prediction, so it needs at least one entry per prediction.
    average : str, default: "mean"
        The mode to set the average of the EPE map. If "mean", the mean of
        the EPE map is returned. If "sum", the EPE map is summed and divided
        by the batch size.
    resize_flow : str, default: "upsample"
        The mode to resize flow. If "upsample", predicted flow will be
        upsampled to the size of the ground truth (and its magnitudes
        rescaled accordingly). If "downsample", ground truth flow will be
        average-pooled to the size of the predicted flow.
    extra_mask : torch.Tensor, optional
        A mask to apply to the loss; only positions where it is > 0 are kept
        (boolean indexing, so "mean" averages over the kept positions only).
        Must match the shape of the per-pixel loss map.
        Useful for removing the loss on the background
    use_valid_range : bool, default: True
        Whether to restrict the loss to ground-truth flow inside
        ``valid_range``
    valid_range : list, optional
        One pair per scale. ``valid_range[i][0]`` bounds ``|gt[:, 1]|``
        (vertical component) and ``valid_range[i][1]`` bounds ``|gt[:, 0]|``
        (horizontal component); pixels outside the bounds are masked out.
    """

    @configurable
    def __init__(
        self,
        norm="l1",
        q=0.4,
        eps=1e-2,
        weights=(1, 0.5, 0.25),
        average="mean",
        resize_flow="upsample",
        extra_mask=None,
        use_valid_range=True,
        valid_range=None,
        **kwargs
    ):
        super().__init__()

        assert norm.lower() in (
            "l1",
            "l2",
            "robust",
        ), "Norm must be one of L1, L2, Robust"
        assert resize_flow.lower() in (
            "upsample",
            "downsample",
        ), "Resize flow must be one of upsample or downsample"
        assert average.lower() in ("mean", "sum"), "Average must be one of mean or sum"

        self.norm = norm.lower()
        self.q = q
        self.eps = eps
        self.weights = weights
        self.extra_mask = extra_mask
        self.use_valid_range = use_valid_range
        self.valid_range = valid_range
        self.average = average.lower()
        self.resize_flow = resize_flow.lower()

    @classmethod
    def from_config(cls, cfg):
        # Q and EPS are optional so existing configs keep working; getattr
        # falls back to the constructor defaults when a key is absent.
        # NOTE(review): assumes cfg raises AttributeError for missing keys
        # (true for yacs/fvcore-style CfgNode) — confirm for ezflow's config.
        return {
            "norm": cfg.NORM,
            "q": getattr(cfg, "Q", 0.4),
            "eps": getattr(cfg, "EPS", 1e-2),
            "weights": cfg.WEIGHTS,
            "average": cfg.AVERAGE,
            "resize_flow": cfg.RESIZE_FLOW,
            "extra_mask": cfg.EXTRA_MASK,
            "use_valid_range": cfg.USE_VALID_RANGE,
            "valid_range": cfg.VALID_RANGE,
        }

    def forward(self, flow_preds, flow_gt, **kwargs):
        """
        Compute the multi-scale loss.

        Parameters
        ----------
        flow_preds : torch.Tensor or list/tuple/set of torch.Tensor
            Predicted flow(s), each of shape (B, 2, h_i, w_i). A single
            tensor is treated as a one-level prediction. The i-th
            prediction is paired with ``weights[i]`` (and ``valid_range[i]``).
        flow_gt : torch.Tensor
            Ground-truth flow of shape (B, 2, H, W); channel 0 is the
            horizontal and channel 1 the vertical component. NaN entries
            mark invalid pixels. The tensor is not modified.

        Returns
        -------
        torch.Tensor
            Scalar loss: the weighted per-level losses averaged over the
            number of prediction levels.

        Raises
        ------
        ValueError
            If ``flow_preds`` is an empty sequence.
        """
        if not isinstance(flow_preds, (tuple, list, set)):
            flow_preds = [flow_preds]
        # Materialize so the level count and per-level indexing are fixed.
        flow_preds = list(flow_preds)
        if not flow_preds:
            raise ValueError("flow_preds must contain at least one prediction")

        b, _, h, w = flow_gt.size()

        # detect NaN: a pixel is valid only if every flow component is finite.
        # masked_fill returns a new tensor, so the caller's flow_gt is untouched.
        nan_mask = torch.isnan(flow_gt)
        gt_valid = (~nan_mask.any(dim=1, keepdim=True)).float()  # (B, 1, H, W)
        flow_gt = flow_gt.masked_fill(nan_mask, 0.0)

        loss = 0
        for i, level_pred in enumerate(flow_preds):
            pred_h, pred_w = level_pred.shape[2], level_pred.shape[3]

            if self.resize_flow == "upsample":
                real_flow = F.interpolate(
                    level_pred, (h, w), mode="bilinear", align_corners=True
                )
                # Flow is measured in pixels, so magnitudes scale with size.
                real_flow[:, 0, :, :] = real_flow[:, 0, :, :] * (w / pred_w)
                real_flow[:, 1, :, :] = real_flow[:, 1, :, :] * (h / pred_h)
                target = flow_gt
                valid = gt_valid[:, 0]
            else:
                # "downsample": pool the ground truth to the prediction's own
                # resolution, following the IRR solution (no magnitude
                # rescaling). Average only over valid pixels: dividing by the
                # pooled validity fraction undoes the zeros written into NaN
                # positions, and is exact (division by 1) when none exist.
                real_flow = level_pred
                valid_frac = F.adaptive_avg_pool2d(gt_valid, [pred_h, pred_w])
                target = F.adaptive_avg_pool2d(
                    flow_gt, [pred_h, pred_w]
                ) / valid_frac.clamp(min=1e-6)
                valid = (valid_frac[:, 0] > 0).float()

            diff = real_flow - target
            if self.norm == "l2":
                loss_value = torch.norm(diff, p=2, dim=1)
            elif self.norm == "robust":
                loss_value = (torch.norm(diff, p=1, dim=1) + self.eps) ** self.q
            else:  # "l1"
                loss_value = torch.norm(diff, p=1, dim=1)

            with torch.no_grad():
                mask = valid
                if self.use_valid_range and self.valid_range is not None:
                    in_range = (target[:, 0, :, :].abs() <= self.valid_range[i][1]) & (
                        target[:, 1, :, :].abs() <= self.valid_range[i][0]
                    )
                    mask = mask * in_range.float()

            loss_value = loss_value * mask

            if self.extra_mask is not None:
                loss_value = loss_value[self.extra_mask > 0]

            if self.average == "mean":
                level_loss = loss_value.mean() * self.weights[i]
            else:  # "sum"
                level_loss = loss_value.sum() / b * self.weights[i]

            loss += level_loss

        return loss / len(flow_preds)