import torch
import torch.nn as nn
import torch.nn.functional as F

from ..decoder import ContextNetwork, build_decoder
from ..encoder import build_encoder
from ..modules import BaseModule
from .build import MODEL_REGISTRY

@MODEL_REGISTRY.register()
class PWCNet(BaseModule):
    """
    Implementation of the paper
    `PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
    <https://arxiv.org/abs/1709.02371>`_

    The model is assembled from an encoder and a decoder built from the
    configuration, plus a :class:`ContextNetwork` whose output is added to
    the last flow prediction produced by the decoder.

    Parameters
    ----------
    cfg : :class:`CfgNode`
        Configuration for the model
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.encoder = build_encoder(cfg.ENCODER)
        self.decoder = build_decoder(cfg.DECODER)

        max_displacement = cfg.DECODER.SIMILARITY.MAX_DISPLACEMENT
        decoder_config = cfg.DECODER.CONFIG

        # Number of entries in a (2d + 1) x (2d + 1) displacement window.
        search_range = (2 * max_displacement + 1) ** 2

        self.context_net = ContextNetwork(
            in_channels=search_range
            + max_displacement
            + decoder_config[-1]
            + sum(decoder_config),
            config=decoder_config,
        )

        self._init_weights()

    def _init_weights(self):
        """
        Initializes the weights of every ``nn.Conv2d`` and
        ``nn.ConvTranspose2d`` submodule with Kaiming-normal (``fan_in``)
        values and resets their biases.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_in")
                if m.bias is not None:
                    # NOTE(review): the statement guarded by this condition
                    # was missing from the source; zero-initialization is
                    # assumed here - confirm against the upstream code.
                    nn.init.zeros_(m.bias)

    def forward(self, img1, img2):
        """
        Performs forward pass of the network

        Parameters
        ----------
        img1 : torch.Tensor
            Image to predict flow from
        img2 : torch.Tensor
            Image to predict flow to

        Returns
        -------
        :class:`dict`
            <flow_preds> torch.Tensor : intermediate flow predictions from img1 to img2
            <flow_upsampled> torch.Tensor : if model is in eval state, return upsampled flow
        """
        # Spatial size of the input, used to resize the final flow in eval mode.
        H, W = img1.shape[-2:]

        feature_pyramid1 = self.encoder(img1)
        feature_pyramid2 = self.encoder(img2)

        # The decoder consumes the pyramid levels in the opposite order to
        # the one the encoder returns them in (presumably coarsest level
        # first - confirm against the encoder implementation).
        feature_pyramid1.reverse()
        feature_pyramid2.reverse()

        flow_preds, features = self.decoder(feature_pyramid1, feature_pyramid2)

        # Refine the last decoder prediction with the context network output.
        flow_preds[-1] += self.context_net(features)

        output = {"flow_preds": flow_preds}

        # During training only the intermediate predictions are returned.
        if self.training:
            return output

        # In eval mode, additionally resize the last prediction to the
        # input resolution.
        flow_up = flow_preds[-1]
        flow_up = F.interpolate(
            flow_up, size=(H, W), mode="bilinear", align_corners=False
        )
        output["flow_upsampled"] = flow_up

        return output