Source code for ezflow.models.dicl

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..decoder import FlowEntropy, build_decoder
from ..encoder import build_encoder
from ..modules import BaseModule, ConvNormRelu, build_module
from ..similarity import build_similarity
from ..utils import warp
from .build import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class DICL(BaseModule):
    """
    Implementation of the paper
    `Displacement-Invariant Matching Cost Learning for Accurate Optical Flow Estimation <https://arxiv.org/abs/2010.14851>`_

    Parameters
    ----------
    cfg : :class:`CfgNode`
        Configuration for the model
    """

    def __init__(self, cfg):
        super(DICL, self).__init__()

        self.cfg = cfg
        self.context_net = cfg.CONTEXT_NET
        self.scale_factors = cfg.SCALE_FACTORS
        self.scale_contexts = cfg.SCALE_CONTEXTS

        self.feature_net = build_encoder(cfg.ENCODER)
        self.entropy_fn = FlowEntropy()

        matching_net = build_similarity(cfg.SIMILARITY.MATCHING_NET)
        search_range = cfg.SEARCH_RANGE

        # One matching-cost module per pyramid level (2 is finest, 6 is
        # coarsest), each with its own displacement search range.
        self.cost_fn2 = build_similarity(
            cfg.SIMILARITY,
            max_u=search_range[0],
            max_v=search_range[0],
            matching_net=matching_net,
        )
        self.cost_fn3 = build_similarity(
            cfg.SIMILARITY,
            max_u=search_range[1],
            max_v=search_range[1],
            matching_net=matching_net,
        )
        self.cost_fn4 = build_similarity(
            cfg.SIMILARITY,
            max_u=search_range[2],
            max_v=search_range[2],
            matching_net=matching_net,
        )
        self.cost_fn5 = build_similarity(
            cfg.SIMILARITY,
            max_u=search_range[3],
            max_v=search_range[3],
            matching_net=matching_net,
        )
        self.cost_fn6 = build_similarity(
            cfg.SIMILARITY,
            max_u=search_range[4],
            max_v=search_range[4],
            matching_net=matching_net,
        )

        self.flow_decoder2 = build_decoder(
            cfg.DECODER, max_u=search_range[0], max_v=search_range[0]
        )
        self.flow_decoder3 = build_decoder(
            cfg.DECODER, max_u=search_range[1], max_v=search_range[1]
        )
        self.flow_decoder4 = build_decoder(
            cfg.DECODER, max_u=search_range[2], max_v=search_range[2]
        )
        self.flow_decoder5 = build_decoder(
            cfg.DECODER, max_u=search_range[3], max_v=search_range[3]
        )
        self.flow_decoder6 = build_decoder(
            cfg.DECODER, max_u=search_range[4], max_v=search_range[4]
        )

        if self.context_net:
            # Dilated context networks for flow refinement, one per level;
            # shallower at coarser levels.
            self.context_net2 = nn.Sequential(
                ConvNormRelu(38, 64, kernel_size=3, padding=1, dilation=1),
                ConvNormRelu(64, 128, kernel_size=3, padding=2, dilation=2),
                ConvNormRelu(128, 128, kernel_size=3, padding=4, dilation=4),
                ConvNormRelu(128, 96, kernel_size=3, padding=8, dilation=8),
                ConvNormRelu(96, 64, kernel_size=3, padding=16, dilation=16),
                ConvNormRelu(64, 32, kernel_size=3, padding=1, dilation=1),
                nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1, bias=True),
            )
            self.context_net3 = nn.Sequential(
                ConvNormRelu(38, 64, kernel_size=3, padding=1, dilation=1),
                ConvNormRelu(64, 128, kernel_size=3, padding=2, dilation=2),
                ConvNormRelu(128, 128, kernel_size=3, padding=4, dilation=4),
                ConvNormRelu(128, 96, kernel_size=3, padding=8, dilation=8),
                ConvNormRelu(96, 64, kernel_size=3, padding=16, dilation=16),
                ConvNormRelu(64, 32, kernel_size=3, padding=1, dilation=1),
                nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1, bias=True),
            )
            self.context_net4 = nn.Sequential(
                ConvNormRelu(38, 64, kernel_size=3, padding=1, dilation=1),
                ConvNormRelu(64, 128, kernel_size=3, padding=2, dilation=2),
                ConvNormRelu(128, 128, kernel_size=3, padding=4, dilation=4),
                ConvNormRelu(128, 64, kernel_size=3, padding=8, dilation=8),
                ConvNormRelu(64, 32, kernel_size=3, padding=1, dilation=1),
                nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1, bias=True),
            )
            self.context_net5 = nn.Sequential(
                ConvNormRelu(38, 64, kernel_size=3, padding=1, dilation=1),
                ConvNormRelu(64, 128, kernel_size=3, padding=2, dilation=2),
                ConvNormRelu(128, 64, kernel_size=3, padding=4, dilation=4),
                ConvNormRelu(64, 32, kernel_size=3, padding=1, dilation=1),
                nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1, bias=True),
            )
            self.context_net6 = nn.Sequential(
                ConvNormRelu(38, 64, kernel_size=3, padding=1, dilation=1),
                ConvNormRelu(64, 64, kernel_size=3, padding=2, dilation=2),
                ConvNormRelu(64, 32, kernel_size=3, padding=1, dilation=1),
                nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1, bias=True),
            )

        self._init_weights()

        if cfg.DAP.USE_DAP:
            name = "DisplacementAwareProjection"
            self.dap_layer2 = build_module(
                cfg.DAP, name=name, max_displacement=search_range[0]
            )
            self.dap_layer3 = build_module(
                cfg.DAP, name=name, max_displacement=search_range[1]
            )
            self.dap_layer4 = build_module(
                cfg.DAP, name=name, max_displacement=search_range[2]
            )
            self.dap_layer5 = build_module(
                cfg.DAP, name=name, max_displacement=search_range[3]
            )
            self.dap_layer6 = build_module(
                cfg.DAP, name=name, max_displacement=search_range[4]
            )

            if cfg.DAP.INIT_ID:
                # Initialize the 1x1 DAP convolutions to the identity mapping.
                nn.init.eye_(
                    self.dap_layer2.dap_layer.conv.weight.squeeze(-1).squeeze(-1)
                )
                nn.init.eye_(
                    self.dap_layer3.dap_layer.conv.weight.squeeze(-1).squeeze(-1)
                )
                nn.init.eye_(
                    self.dap_layer4.dap_layer.conv.weight.squeeze(-1).squeeze(-1)
                )
                nn.init.eye_(
                    self.dap_layer5.dap_layer.conv.weight.squeeze(-1).squeeze(-1)
                )
                nn.init.eye_(
                    self.dap_layer6.dap_layer.conv.weight.squeeze(-1).squeeze(-1)
                )

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _process_level(
        self,
        x,
        y,
        orig_img,
        level,
        prev_upflow,
        scale_factor,
        upflow_size,
        scale_context=None,
        warp_flow=True,
    ):
        level = str(level)
        cost_fn = getattr(self, "cost_fn" + level)
        flow_decoder = getattr(self, "flow_decoder" + level)
        # dap_layer is None when DAP is disabled and the attribute was never
        # created.
        dap_layer = getattr(self, "dap_layer" + level, None)
        if self.context_net:
            context_net = getattr(self, "context_net" + level)

        if warp_flow:
            # Warp the second image's features with the upsampled coarser flow
            # before computing the matching cost.
            warped_flow = warp(y, prev_upflow)
            cost = cost_fn(x, warped_flow)
        else:
            cost = cost_fn(x, y)

        g = F.interpolate(
            orig_img,
            scale_factor=scale_factor,
            mode="bilinear",
            align_corners=True,
            recompute_scale_factor=True,
        )

        if self.cfg.DAP.USE_DAP:
            cost = dap_layer(cost)

        if warp_flow:
            flow = flow_decoder(cost) + prev_upflow
        else:
            flow = flow_decoder(cost)

        # Stays None unless raw-flow supervision is enabled below; defined
        # unconditionally so the return statement is always valid.
        raw_flow = None

        if self.context_net:
            if self.cfg.SUP_RAW_FLOW:
                raw_flow = flow

            entropy = self.entropy_fn(cost)
            features = torch.cat((flow.detach(), entropy.detach(), x, g), dim=1)
            flow = flow + context_net(features) * scale_context

        upflow = 2.0 * F.interpolate(
            flow, upflow_size, mode="bilinear", align_corners=True
        )
        upflow = upflow.detach()

        return upflow, flow, raw_flow
    def forward(self, img1, img2):
        """
        Performs forward pass of the network

        Parameters
        ----------
        img1 : torch.Tensor
            Image to predict flow from
        img2 : torch.Tensor
            Image to predict flow to

        Returns
        -------
        :class:`dict`
            <flow_preds> torch.Tensor : intermediate flow predictions from img1 to img2
            <flow_upsampled> torch.Tensor : upsampled flow prediction, returned only in eval mode
        """

        _, x2, x3, x4, x5, x6 = self.feature_net(img1)
        _, y2, y3, y4, y5, y6 = self.feature_net(img2)

        # Coarse-to-fine estimation: start at the coarsest level (6) without
        # warping, then refine at each finer level using the upsampled flow
        # from the level above.
        upflow6, flow6, raw_flow6 = self._process_level(
            x6,
            y6,
            img1,
            6,
            None,
            self.scale_factors[4],
            (x5.shape[2], x5.shape[3]),
            self.scale_contexts[4],
            warp_flow=False,
        )
        upflow5, flow5, raw_flow5 = self._process_level(
            x5,
            y5,
            img1,
            5,
            upflow6,
            self.scale_factors[3],
            (x4.shape[2], x4.shape[3]),
            self.scale_contexts[3],
        )
        upflow4, flow4, raw_flow4 = self._process_level(
            x4,
            y4,
            img1,
            4,
            upflow5,
            self.scale_factors[2],
            (x3.shape[2], x3.shape[3]),
            self.scale_contexts[2],
        )
        upflow3, flow3, raw_flow3 = self._process_level(
            x3,
            y3,
            img1,
            3,
            upflow4,
            self.scale_factors[1],
            (x2.shape[2], x2.shape[3]),
            self.scale_contexts[1],
        )
        _, flow2, raw_flow2 = self._process_level(
            x2,
            y2,
            img1,
            2,
            upflow3,
            self.scale_factors[0],
            (x2.shape[2], x2.shape[3]),
            self.scale_contexts[0],
        )

        output = {"flow_preds": [flow2, flow3, flow4, flow5, flow6]}

        if self.training:
            if self.cfg.SUP_RAW_FLOW:
                output["flow_preds"] = [
                    flow2,
                    raw_flow2,
                    flow3,
                    raw_flow3,
                    flow4,
                    raw_flow4,
                    flow5,
                    raw_flow5,
                    flow6,
                    raw_flow6,
                ]
            return output

        _, _, height, width = img1.size()
        flow_up = F.interpolate(
            flow2, (height, width), mode="bilinear", align_corners=True
        )
        output["flow_upsampled"] = flow_up

        return output
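
# ---------------------------------------------------------------------------
# Usage sketch (not part of the library source): a minimal, hedged example of
# running DICL, assuming `cfg` is a fully populated CfgNode carrying the
# fields read above (CONTEXT_NET, SCALE_FACTORS, SCALE_CONTEXTS, ENCODER,
# SIMILARITY, SEARCH_RANGE, DECODER, DAP, SUP_RAW_FLOW). How such a config is
# obtained depends on ezflow's config utilities and is not shown here.
#
#     model = DICL(cfg)
#     model.eval()
#
#     # Image pairs are (N, 3, H, W); H and W must be divisible by the
#     # encoder's pyramid strides for the level shapes to line up.
#     img1 = torch.randn(1, 3, 256, 448)
#     img2 = torch.randn(1, 3, 256, 448)
#
#     with torch.no_grad():
#         out = model(img1, img2)
#
#     out["flow_preds"]      # flows at levels 2-6, finest first
#     out["flow_upsampled"]  # (1, 2, 256, 448) flow from img1 to img2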