Source code for ezflow.models.vcn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..decoder import Butterfly4D, SeparableConv4D, Soft4DFlowRegression
from ..encoder import build_encoder
from ..modules import BaseModule, conv
from ..utils import warp
from .build import MODEL_REGISTRY


def _gen_hypotheses_fusion_block(in_channels, out_channels):

    return nn.Sequential(
        *[
            conv(in_channels, 128, kernel_size=3, stride=1, padding=1, dilation=1),
            conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2),
            conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4),
            conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8),
            conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16),
            conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.Conv2d(32, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
        ]
    )
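
# A minimal, hedged sketch of this block's shape contract (illustrative numbers
# only; ``f_dim_b1 = 16`` is an assumed value, and ``128 + 4 * f_dim_b1`` mirrors
# the ``in_channels`` arithmetic used for the coarser levels in ``VCN.__init__``
# below).  ``forward`` feeds the block the concatenated
# [entropy | flow hypotheses | image features] tensor and reshapes its output to
# (B, num_hypotheses, 2, H, W) logits that are soft-maxed over the hypotheses:
#
#     f_dim_b1 = 16
#     block = _gen_hypotheses_fusion_block(128 + 4 * f_dim_b1, 2 * f_dim_b1)
#     x = torch.randn(2, 128 + 4 * f_dim_b1, 8, 8)
#     assert block(x).shape == (2, 2 * f_dim_b1, 8, 8)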


@MODEL_REGISTRY.register()
class VCN(BaseModule):
    """
    Implementation of the paper
    `Volumetric Correspondence Networks for Optical Flow <https://papers.nips.cc/paper/2019/hash/bbf94b34eb32268ada57a3be5062fe7d-Abstract.html>`_

    Parameters
    ----------
    cfg : :class:`CfgNode`
        Configuration for the model
    """

    def __init__(self, cfg):
        super(VCN, self).__init__()

        self.cfg = cfg

        self.encoder = build_encoder(cfg.ENCODER)

        f_dim_a1 = cfg.DECODER.F_DIM_A1
        f_dim_a2 = cfg.DECODER.F_DIM_A2
        f_dim_b1 = cfg.DECODER.F_DIM_B1
        f_dim_b2 = cfg.DECODER.F_DIM_B2

        self.max_disps = cfg.MAX_DISPLACEMENTS
        self.factorization = cfg.FACTORIZATION

        # 4D (butterfly + separable) convolutions that filter the cost volume
        # at each pyramid level.
        self.butterfly_filters = nn.ModuleList()
        self.sep_conv_4d_filters = nn.ModuleList()

        for _ in range(3):
            self.butterfly_filters.append(
                Butterfly4D(
                    f_dim_a1,
                    f_dim_b1,
                    norm=cfg.DECODER.NORM,
                    full=False,
                )
            )
            self.sep_conv_4d_filters.append(
                SeparableConv4D(f_dim_b1, f_dim_b1, norm=False, full=False)
            )

        self.butterfly_filters.append(
            Butterfly4D(
                f_dim_a2,
                f_dim_b1,
                norm=cfg.DECODER.NORM,
                full=False,
            )
        )
        self.sep_conv_4d_filters.append(
            SeparableConv4D(f_dim_b1, f_dim_b1, norm=False, full=False)
        )

        self.butterfly_filters.append(
            Butterfly4D(
                f_dim_a2,
                f_dim_b2,
                norm=cfg.DECODER.NORM,
                full=True,
            )
        )
        self.sep_conv_4d_filters.append(
            SeparableConv4D(f_dim_b2, f_dim_b2, norm=False, full=True)
        )

        # Soft 4D flow regressors, one per pyramid level, coarsest (1/64) first.
        self.flow_regressors = nn.ModuleList()
        size = cfg.SIZE

        self.flow_regressors.append(
            Soft4DFlowRegression(
                [f_dim_b1 * size[0], size[1] // 64, size[2] // 64],
                max_disp=self.max_disps[0],
                entropy=cfg.DECODER.ENTROPY,
                factorization=self.factorization,
            )
        )

        scale = 32
        for i in range(1, 4):
            self.flow_regressors.append(
                Soft4DFlowRegression(
                    [
                        f_dim_b1 * size[0],
                        size[1] // scale,
                        size[2] // scale,
                    ],
                    max_disp=self.max_disps[i],
                    entropy=cfg.DECODER.ENTROPY,
                )
            )
            scale = scale // 2

        self.flow_regressors.append(
            Soft4DFlowRegression(
                [f_dim_b2 * size[0], size[1] // 4, size[2] // 4],
                max_disp=self.max_disps[0],
                entropy=cfg.DECODER.ENTROPY,
                factorization=self.factorization,
            )
        )

        # Fusion blocks that weight the flow hypotheses at each level.
        self.hypotheses_fusion_blocks = nn.ModuleList()

        for i in range(1, 5):
            if i == 4:
                in_channels = 64 + (4 * i * f_dim_b1)
            else:
                in_channels = 128 + (4 * i * f_dim_b1)

            out_channels = 2 * i * f_dim_b1

            self.hypotheses_fusion_blocks.append(
                _gen_hypotheses_fusion_block(in_channels, out_channels)
            )

        self.hypotheses_fusion_blocks.append(
            _gen_hypotheses_fusion_block(
                64 + (16 * f_dim_b1) + (4 * f_dim_b2),
                (8 * f_dim_b1) + (2 * f_dim_b2),
            )
        )

        self._init_weights()

    def _init_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_in")
                if m.bias is not None:
                    m.bias.data.zero_()

    def _corr_fn(self, features1, features2, max_disp, factorization=1):
        """
        Build the 4D product-correlation (cost) volume of shape
        (B, C, 2 * max_disp + 1, 2 * (max_disp // factorization) + 1, H, W).
        """
        b, c, height, width = features1.shape

        if features1.is_cuda:
            cost = torch.cuda.FloatTensor(
                b,
                c,
                2 * max_disp + 1,
                2 * int(max_disp // factorization) + 1,
                height,
                width,
            ).fill_(0.0)
        else:
            cost = torch.FloatTensor(
                b,
                c,
                2 * max_disp + 1,
                2 * int(max_disp // factorization) + 1,
                height,
                width,
            ).fill_(0.0)

        for i in range(2 * max_disp + 1):
            ind = i - max_disp
            for j in range(2 * int(max_disp // factorization) + 1):
                indd = j - int(max_disp // factorization)

                # Elementwise product of the overlapping crops for the
                # displacement (ind, indd).
                feata = features1[
                    :, :, max(0, -indd) : height - indd, max(0, -ind) : width - ind
                ]
                featb = features2[
                    :, :, max(0, +indd) : height + indd, max(0, ind) : width + ind
                ]

                diff = feata * featb

                cost[
                    :,
                    :,
                    i,
                    j,
                    max(0, -indd) : height - indd,
                    max(0, -ind) : width - ind,
                ] = diff

        cost = F.leaky_relu(cost, 0.1, inplace=True)

        return cost
    def forward(self, img1, img2):

        batch_size = img1.shape[0]

        assert (
            batch_size == self.cfg.SIZE[0]
        ), f"Batch size in model configuration must be equal to the training batch size. Model config batch size: {self.cfg.SIZE[0]}, Training batch size: {batch_size}"

        # if self.cfg.SIZE[0] != img1.shape[0]:
        #     self.cfg.SIZE[0] = img1.shape[0]

        feature_pyramid1 = self.encoder(img1)
        feature_pyramid2 = self.encoder(img2)

        # L2-normalize the features at every pyramid level.
        for i in range(len(feature_pyramid1)):
            feature_pyramid1[i] = feature_pyramid1[i] / (
                torch.norm(feature_pyramid1[i], p=2, dim=1, keepdim=True) + 1e-9
            )
            feature_pyramid2[i] = feature_pyramid2[i] / (
                torch.norm(feature_pyramid2[i], p=2, dim=1, keepdim=True) + 1e-9
            )

        flow_preds = []
        flow_intermediates = []
        ent_intermediates = []

        scale = 32
        for i in range(len(self.butterfly_filters)):

            # Coarse-to-fine: upsample the previous prediction and warp the
            # second feature map towards the first before matching.
            if i != 0:
                up_flow = (
                    F.interpolate(
                        flow_preds[-1],
                        [img1.shape[2] // scale, img1.shape[3] // scale],
                        mode="bilinear",
                        align_corners=True,
                    )
                    * 2
                )
                scale = scale // 2
                features2 = warp(feature_pyramid2[i], up_flow)
            else:
                features2 = feature_pyramid2[i]

            # 4D cost volume, filtered by the butterfly and separable 4D convs.
            cost = self._corr_fn(
                feature_pyramid1[i],
                features2,
                self.max_disps[i],
                factorization=self.cfg.FACTORIZATION,
            )

            cost = self.butterfly_filters[i](cost)
            cost = self.sep_conv_4d_filters[i](cost)

            B, C, U, V, H, W = cost.shape
            cost = cost.view(-1, U, V, H, W)

            # Regress flow hypotheses and their entropies from the cost volume.
            flow, ent = self.flow_regressors[i](cost)

            if i != 0:
                flow = flow.view(B, C, 2, H, W) + up_flow[:, np.newaxis]

            flow = flow.view(batch_size, -1, H, W)
            ent = ent.view(batch_size, -1, H, W)

            # Carry the hypotheses from the coarser levels along.
            if i != 0:
                flow = torch.cat(
                    (
                        flow,
                        F.interpolate(
                            flow_intermediates[-1].detach() * 2,
                            [flow.shape[2], flow.shape[3]],
                            mode="bilinear",
                            align_corners=True,
                        ),
                    ),
                    dim=1,
                )
                ent = torch.cat(
                    (
                        ent,
                        F.interpolate(
                            ent_intermediates[-1],
                            [flow.shape[2], flow.shape[3]],
                            mode="bilinear",
                            align_corners=False,
                        ),
                    ),
                    dim=1,
                )

            flow_intermediates.append(flow)
            ent_intermediates.append(ent)

            # Fuse the hypotheses into a single flow map with predicted
            # soft weights.
            x = torch.cat([ent.detach(), flow.detach(), feature_pyramid1[i]], dim=1)
            x = self.hypotheses_fusion_blocks[i](x)

            x = x.view(B, -1, 2, H, W)
            flow = (flow.view(B, -1, 2, H, W) * F.softmax(x, dim=1)).sum(dim=1)

            flow_preds.append(flow)

        # Finest prediction first, upsampled and rescaled to input resolution.
        flow_preds.reverse()

        scale = 4
        for i in range(len(flow_preds)):
            flow_preds[i] = F.interpolate(
                flow_preds[i],
                [img1.shape[2], img1.shape[3]],
                mode="bilinear",
                align_corners=False,
            )
            flow_preds[i] = flow_preds[i] * scale
            scale *= 2

        output = {"flow_preds": flow_preds}

        if self.training:
            return output

        output["flow_upsampled"] = flow_preds[0]

        return output
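

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the library): a self-contained toy that
# mirrors one level of the pipeline above -- a product-correlation volume
# built with the same displacement slicing as ``_corr_fn`` (factorization=1,
# channels collapsed to a single matching score), followed by a plain
# soft-argmax over displacements instead of ``Soft4DFlowRegression``.  All
# sizes below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    b, c, h, w, max_disp = 1, 8, 16, 16, 2
    feat1 = F.normalize(torch.randn(b, c, h, w), dim=1)
    feat2 = F.normalize(torch.randn(b, c, h, w), dim=1)

    # Correlation over a (2 * max_disp + 1)^2 displacement window.
    cost = torch.zeros(b, 2 * max_disp + 1, 2 * max_disp + 1, h, w)
    for i in range(2 * max_disp + 1):
        u = i - max_disp  # horizontal displacement, as ``ind`` in _corr_fn
        for j in range(2 * max_disp + 1):
            v = j - max_disp  # vertical displacement, as ``indd`` in _corr_fn
            fa = feat1[:, :, max(0, -v) : h - v, max(0, -u) : w - u]
            fb = feat2[:, :, max(0, v) : h + v, max(0, u) : w + u]
            cost[:, i, j, max(0, -v) : h - v, max(0, -u) : w - u] = (fa * fb).sum(1)

    # Soft-argmax over the displacement grid -> dense 2-channel flow (u, v).
    prob = F.softmax(cost.view(b, -1, h, w), dim=1).view_as(cost)
    disps = torch.arange(-max_disp, max_disp + 1, dtype=torch.float32)
    flow_u = (prob.sum(dim=2) * disps.view(1, -1, 1, 1)).sum(dim=1)
    flow_v = (prob.sum(dim=1) * disps.view(1, -1, 1, 1)).sum(dim=1)
    flow = torch.stack([flow_u, flow_v], dim=1)
    print(flow.shape)  # torch.Size([1, 2, 16, 16])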