Source code for ezflow.decoder.noniterative.soft_regression

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...config import configurable
from ..build import DECODER_REGISTRY


@DECODER_REGISTRY.register()
class SoftArg2DFlowRegression(nn.Module):
    """
    Applies 2D soft argmin/argmax operation to regress flow.
    Used in **DICL** (https://arxiv.org/abs/2010.14851)

    Parameters
    ----------
    max_u : int, default : 3
        Maximum displacement in the horizontal direction
    max_v : int, default : 3
        Maximum displacement in the vertical direction
    operation : str, default : argmax
        The argmax/argmin operation for flow regression
    """

    @configurable
    def __init__(self, max_u=3, max_v=3, operation="argmax"):
        super(SoftArg2DFlowRegression, self).__init__()

        assert (
            operation.lower() == "argmax" or operation.lower() == "argmin"
        ), "Invalid operation. Supported operations: argmax and argmin"

        self.max_u = max_u
        self.max_v = max_v
        self.operation = operation.lower()

    @classmethod
    def from_config(cls, cfg):
        return {
            "max_u": cfg.MAX_U,
            "max_v": cfg.MAX_V,
            "operation": cfg.OPERATION,
        }

    def forward(self, x):
        """
        Performs forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map

        Returns
        -------
        torch.Tensor
            A tensor of shape N x 2 x H x W representing the flow
        """

        sizeU = 2 * self.max_u + 1
        sizeV = 2 * self.max_v + 1
        x = x.squeeze(1)

        B, _, _, H, W = x.shape

        disp_u = torch.reshape(
            torch.arange(
                -self.max_u,
                self.max_u + 1,
                dtype=torch.float32,
            ),
            [1, sizeU, 1, 1, 1],
        ).to(x.device)
        disp_u = disp_u.expand(B, -1, sizeV, H, W).contiguous()
        disp_u = disp_u.view(B, sizeU * sizeV, H, W)

        disp_v = torch.reshape(
            torch.arange(
                -self.max_v,
                self.max_v + 1,
                dtype=torch.float32,
            ),
            [1, 1, sizeV, 1, 1],
        ).to(x.device)
        disp_v = disp_v.expand(B, sizeU, -1, H, W).contiguous()
        disp_v = disp_v.view(B, sizeU * sizeV, H, W)

        x = x.view(B, sizeU * sizeV, H, W)

        if self.operation == "argmin":
            x = F.softmin(x, dim=1)
        else:
            x = F.softmax(x, dim=1)

        flow_u = (x * disp_u).sum(dim=1)
        flow_v = (x * disp_v).sum(dim=1)
        flow = torch.cat((flow_u.unsqueeze(1), flow_v.unsqueeze(1)), dim=1)

        return flow
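
A minimal usage sketch for SoftArg2DFlowRegression (illustrative only and not part of the original module; in ezflow the decoder is normally built from a config through DECODER_REGISTRY). With max_u = max_v = 3 the cost volume has U = V = 7 displacement bins, and the regressed flow has shape B x 2 x H x W:

    # Hypothetical example, not from the original source
    decoder = SoftArg2DFlowRegression(max_u=3, max_v=3, operation="argmax")
    cost = torch.randn(2, 7, 7, 32, 32)  # B x U x V x H x W (a leading singleton dim is also squeezed)
    flow = decoder(cost)  # shape: 2 x 2 x 32 x 32
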
@DECODER_REGISTRY.register()
class Soft4DFlowRegression(nn.Module):
    """
    Applies 4D soft argmax operation to regress flow.

    Parameters
    ----------
    size : List[int]
        List containing values of B, H, W
    max_disp : int, default : 4
        Maximum displacement
    entropy : bool, default : False
        If True, computes local and global entropy from matching cost
    factorization : int, default : 1
        Max displacement factorization value
    """

    @configurable
    def __init__(self, size, max_disp=4, entropy=False, factorization=1):
        super(Soft4DFlowRegression, self).__init__()

        B, H, W = size
        self.entropy = entropy
        self.md = max_disp
        self.factorization = factorization
        self.truncated = True
        self.w_size = 3

        flowrange_y = range(-max_disp, max_disp + 1)
        flowrange_x = range(
            -int(max_disp // self.factorization),
            int(max_disp // self.factorization) + 1,
        )
        meshgrid = np.meshgrid(flowrange_x, flowrange_y)
        flow_y = np.tile(
            np.reshape(
                meshgrid[0],
                [
                    1,
                    2 * max_disp + 1,
                    2 * int(max_disp // self.factorization) + 1,
                    1,
                    1,
                ],
            ),
            (B, 1, 1, H, W),
        )
        flow_x = np.tile(
            np.reshape(
                meshgrid[1],
                [
                    1,
                    2 * max_disp + 1,
                    2 * int(max_disp // self.factorization) + 1,
                    1,
                    1,
                ],
            ),
            (B, 1, 1, H, W),
        )

        self.register_buffer("flow_x", torch.Tensor(flow_x))
        self.register_buffer("flow_y", torch.Tensor(flow_y))

        self.pool3d = nn.MaxPool3d(
            (self.w_size * 2 + 1, self.w_size * 2 + 1, 1),
            stride=1,
            padding=(self.w_size, self.w_size, 0),
        )

    @classmethod
    def from_config(cls, cfg):
        return {
            "size": cfg.SIZE,
            "max_disp": cfg.MAX_DISP,
            "entropy": cfg.ENTROPY,
            "factorization": cfg.FACTORIZATION,
        }

    def forward(self, x):
        """
        Performs forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input cost feature map of shape B x U x V x H x W

        Returns
        -------
        torch.Tensor
            A tensor of shape B x C x H x W representing the flow
        torch.Tensor
            A tensor representing the local and global entropy cost
        """

        B, U, V, H, W = x.shape
        orig_x = x

        if self.truncated:
            # Truncated softmax: keep only a local window around the hard argmax
            # of the cost volume and set every other entry to -inf before the softmax
            x = x.view(B, U * V, H, W)

            idx = x.argmax(1)[:, np.newaxis]
            mask = torch.FloatTensor(B, U * V, H, W).fill_(0).to(x.device)
            mask.scatter_(1, idx, 1)
            mask = mask.view(B, 1, U, V, -1)
            mask = self.pool3d(mask)[:, 0].view(B, U, V, H, W)

            n_inf = x.clone().fill_(-np.inf).view(B, U, V, H, W)
            x = torch.where(mask.byte(), orig_x, n_inf)
        else:
            self.w_size = (np.sqrt(U * V) - 1) / 2

        B, U, V, H, W = x.shape
        x = F.softmax(x.view(B, -1, H, W), 1).view(B, U, V, H, W)

        # Soft argmax: expected displacement under the (truncated) matching distribution
        out_x = torch.sum(torch.sum(x * self.flow_x, 1), 1, keepdim=True)
        out_y = torch.sum(torch.sum(x * self.flow_y, 1), 1, keepdim=True)

        if self.entropy:
            # Local entropy of the truncated matching distribution
            local_entropy = (
                (-x * torch.clamp(x, 1e-9, 1 - 1e-9).log()).sum(1).sum(1)[:, np.newaxis]
            )
            if self.w_size == 0:
                local_entropy[:] = 1.0
            else:
                local_entropy /= np.log((self.w_size * 2 + 1) ** 2)

            # Global entropy over the full (untruncated) cost volume
            x = F.softmax(orig_x.view(B, -1, H, W), 1).view(B, U, V, H, W)
            global_entropy = (
                (-x * torch.clamp(x, 1e-9, 1 - 1e-9).log()).sum(1).sum(1)[:, np.newaxis]
            )
            global_entropy /= np.log(x.shape[1] * x.shape[2])

            return torch.cat([out_x, out_y], 1), torch.cat(
                [local_entropy, global_entropy], 1
            )
        else:
            return torch.cat([out_x, out_y], 1), None
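
A similar illustrative sketch for Soft4DFlowRegression (the shapes and values below are assumptions for demonstration, not taken from the original documentation). With max_disp = 4 and factorization = 1 the cost volume has U = V = 9 bins; enabling entropy returns a second B x 2 x H x W tensor holding the local and global entropy maps:

    # Hypothetical example, not from the original source
    decoder = Soft4DFlowRegression(size=[2, 32, 32], max_disp=4, entropy=True)
    cost = torch.randn(2, 9, 9, 32, 32)  # B x U x V x H x W
    flow, entropy_cost = decoder(cost)  # flow: 2 x 2 x 32 x 32, entropy_cost: 2 x 2 x 32 x 32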