Source code for ezflow.models.flownet_c

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, kaiming_normal_

from ..decoder import build_decoder
from ..encoder import BasicConvEncoder, build_encoder
from ..modules import BaseModule, conv
from ..similarity import IterSpatialCorrelationSampler as SpatialCorrelationSampler
from .build import MODEL_REGISTRY
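
The correlation layer imported above returns a 5-D cost volume of shape (batch, patch_h, patch_w, H, W), which forward() below flattens into a single channel dimension. A minimal shape sketch, assuming 256-channel feature maps of size 48x64 and a patch size of 21 (hypothetical values for illustration; the keyword arguments mirror the call in __init__ below)::

    import torch
    from ezflow.similarity import IterSpatialCorrelationSampler

    sampler = IterSpatialCorrelationSampler(
        kernel_size=1, patch_size=21, padding=0, dilation_patch=2
    )

    # Hypothetical feature maps for a pair of images
    feat1 = torch.randn(1, 256, 48, 64)
    feat2 = torch.randn(1, 256, 48, 64)

    corr = sampler(feat1, feat2)
    print(corr.shape)  # torch.Size([1, 21, 21, 48, 64])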


@MODEL_REGISTRY.register()
class FlowNetC(BaseModule):
    """
    Implementation of **FlowNetCorrelation** from the paper
    `FlowNet: Learning Optical Flow with Convolutional Networks <https://arxiv.org/abs/1504.06852>`_

    Parameters
    ----------
    cfg : :class:`CfgNode`
        Configuration for the model
    """

    def __init__(self, cfg):
        super(FlowNetC, self).__init__()

        self.cfg = cfg

        # Split the encoder config: the first three stages extract per-image
        # features, the remaining stages encode the correlation volume.
        channels = cfg.ENCODER.CONFIG
        cfg.ENCODER.CONFIG = channels[:3]
        self.feature_encoder = build_encoder(cfg.ENCODER)

        self.correlation_layer = SpatialCorrelationSampler(
            kernel_size=1,
            patch_size=2 * cfg.SIMILARITY.MAX_DISPLACEMENT + 1,
            padding=cfg.SIMILARITY.PAD_SIZE,
            dilation_patch=2,
        )
        self.corr_activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        self.conv_redirect = conv(
            in_channels=cfg.ENCODER.CONFIG[-1], out_channels=32, norm=cfg.ENCODER.NORM
        )

        self.corr_encoder = BasicConvEncoder(
            in_channels=473, config=channels[3:], norm=cfg.ENCODER.NORM
        )

        self.decoder = build_decoder(cfg.DECODER)

        # Kaiming-normal initialization for conv layers, constant for batch norm
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)
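
The hard-coded in_channels=473 of the correlation encoder is the correlation layer's flattened output plus the 32 redirected feature channels. A sketch of that arithmetic, assuming MAX_DISPLACEMENT is 10 (an illustrative value; the actual one comes from the model's config)::

    max_displacement = 10                    # assumed cfg.SIMILARITY.MAX_DISPLACEMENT
    patch_size = 2 * max_displacement + 1    # 21 candidate displacements per axis
    corr_channels = patch_size ** 2          # 441, one channel per (dy, dx) offset
    redirect_channels = 32                   # output of self.conv_redirect

    assert corr_channels + redirect_channels == 473  # matches the corr_encoder input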
    def forward(self, img1, img2):
        """
        Performs forward pass of the network

        Parameters
        ----------
        img1 : torch.Tensor
            Image to predict flow from
        img2 : torch.Tensor
            Image to predict flow to

        Returns
        -------
        :class:`dict`
            <flow_preds> torch.Tensor : intermediate flow predictions from img1 to img2
            <flow_upsampled> torch.Tensor : if the model is in eval mode, the final flow upsampled to the input resolution
        """

        H, W = img1.shape[-2:]

        conv_outputs1 = self.feature_encoder(img1)
        conv_outputs2 = self.feature_encoder(img2)

        corr_output = self.correlation_layer(conv_outputs1[-1], conv_outputs2[-1])

        # Collapse the (patch_h, patch_w) displacement dimensions into channels
        corr_output = corr_output.view(
            corr_output.shape[0], -1, corr_output.shape[3], corr_output.shape[4]
        )
        corr_output = self.corr_activation(corr_output)

        # Redirect the final feature output of img1
        conv_redirect_output = self.conv_redirect(conv_outputs1[-1])

        x = torch.cat([conv_redirect_output, corr_output], dim=1)
        conv_outputs = self.corr_encoder(x)

        # Prepend the first two convolution outputs from img1 as skip connections
        conv_outputs = [conv_outputs1[0], conv_outputs1[1]] + conv_outputs

        flow_preds = self.decoder(conv_outputs)

        output = {"flow_preds": flow_preds}

        if self.training:
            return output

        flow_up = flow_preds[-1]
        flow_up = F.interpolate(
            flow_up, size=(H, W), mode="bilinear", align_corners=False
        )
        output["flow_upsampled"] = flow_up

        return output
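
For reference, a usage sketch under stated assumptions: it relies on ezflow's build_model helper with its default FlowNetC config (part of the library's public API, not defined in this file), and picks input sizes divisible by 64 so every encoder stride divides evenly::

    import torch
    from ezflow.models import build_model  # assumed helper from the ezflow model zoo

    model = build_model("FlowNetC", default=True)  # default=True loads the shipped config
    model.eval()

    # Dummy image pair; H and W chosen divisible by 64
    img1 = torch.randn(1, 3, 384, 512)
    img2 = torch.randn(1, 3, 384, 512)

    with torch.no_grad():
        output = model(img1, img2)

    print(output["flow_upsampled"].shape)  # expected: torch.Size([1, 2, 384, 512])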