# Source code for ezflow.decoder.conv_decoder

import torch
import torch.nn as nn

from ..config import configurable
from .build import DECODER_REGISTRY


def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1):
    """
    Block for a 2D Convolutional layer with Leaky ReLU activation

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    kernel_size : int, default : 3
        Size of the kernel
    stride : int, default : 1
        Stride of the convolution
    padding : int, default : 1
        Zero-padding added to both sides of the input
    dilation : int, default : 1
        Spacing between kernel elements

    Returns
    -------
    torch.nn.Sequential
        block containing nn.Conv2d layer and leaky relu
    """
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.LeakyReLU(0.1),
    )
def deconv(in_channels, out_channels):
    """
    Block for a 2D Transpose Convolutional layer with Leaky ReLU activation

    Upsamples the input by a factor of 2 (kernel 4, stride 2, padding 1).

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels

    Returns
    -------
    torch.nn.Sequential
        block containing nn.ConvTranspose2d layer and leaky relu
    """
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=False,
    )
    activation = nn.LeakyReLU(0.1, inplace=True)
    return nn.Sequential(upsample, activation)
@DECODER_REGISTRY.register()
class ConvDecoder(nn.Module):
    """
    Applies a 2D Convolutional decoder to the input feature map.
    Used in **PWCNet** (https://arxiv.org/abs/1709.02371)

    Parameters
    ----------
    config : List[int], default : [128, 128, 96, 64, 32]
        List containing all output channels of the decoder
    concat_channels : int, optional
        Additional input channels to be concatenated for convolution layers;
        when set, every layer's output is concatenated onto its input
        (dense connections)
    to_flow : bool, default : True
        If True, convolves decoder output to optical flow of shape N x 2 x H x W
    block : object, default : None
        the conv block to be used to build the decoder layers
    """

    @configurable
    def __init__(
        self,
        config=None,
        concat_channels=None,
        to_flow=True,
        block=None,
    ):
        super().__init__()

        # None sentinel avoids the shared mutable default argument pitfall.
        if config is None:
            config = [128, 128, 96, 64, 32]

        self.concat_channels = concat_channels

        if block is None:
            block = conv

        self.decoder = nn.ModuleList()

        # Running totals of output channels as plain ints (the previous
        # torch.cumsum produced 0-d tensors, which then leaked into
        # nn.Conv2d channel arguments). With dense connections, layer i
        # sees the concatenation of all previous outputs plus the input.
        config_cumsum = []
        total = 0
        for channels in config:
            total += channels
            config_cumsum.append(total)

        if concat_channels is not None:
            self.decoder.append(
                block(concat_channels, config[0], kernel_size=3, stride=1)
            )

        for i in range(len(config) - 1):
            if concat_channels is not None:
                in_channels = config_cumsum[i] + concat_channels
            else:
                in_channels = config[i]

            self.decoder.append(
                block(in_channels, config[i + 1], kernel_size=3, stride=1)
            )

        # Identity unless a final flow-regression conv is requested.
        self.to_flow = nn.Identity()

        if to_flow:
            if concat_channels is not None:
                in_channels = config_cumsum[-1] + concat_channels
            else:
                in_channels = config[-1]

            self.to_flow = nn.Conv2d(
                in_channels, 2, kernel_size=3, stride=1, padding=1, bias=True
            )

    @classmethod
    def from_config(cls, cfg):
        """Map a project config node to constructor kwargs."""
        return {"config": cfg.CONFIG}

    def forward(self, x):
        """
        Performs forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map

        Returns
        -------
        torch.Tensor
            A tensor of shape N x 2 x H x W representing the flow
        torch.Tensor
            Tensor of shape N x output_channel x H x W
        """
        for layer in self.decoder:
            y = layer(x)

            if self.concat_channels is not None:
                # Dense connection: carry all previous features forward.
                x = torch.cat((x, y), dim=1)
            else:
                x = y

        return self.to_flow(x), x
@DECODER_REGISTRY.register()
class FlowNetConvDecoder(nn.Module):
    """
    Applies a 2D Convolutional decoder to regress the optical flow
    from the intermediate outputs convolutions of the encoder.
    Used in **FlowNetSimple** (https://arxiv.org/abs/1504.06852)

    Parameters
    ----------
    in_channels : int, default: 1024
        Number of input channels of the decoder. This value should be
        equal to the final output channels of the encoder
    config : List[int], default : [512, 256, 128, 64]
        List containing all output channels of the decoder
    """

    @configurable
    def __init__(self, in_channels=1024, config=None):
        super().__init__()

        # None sentinel avoids the shared mutable default argument pitfall.
        if config is None:
            config = [512, 256, 128, 64]
        elif isinstance(config, tuple):
            config = list(config)

        out_channels = [in_channels] + config

        # Input channels for each decoder stage. Every stage after the
        # first consumes an encoder skip connection, the previous deconv
        # output, and the 2-channel upsampled flow prediction ("+ 2").
        in_channels = []
        prev_out_channels = 0
        for i in range(len(out_channels)):
            if i > 0:
                inp = out_channels[i] + prev_out_channels + 2
                prev_out_channels = out_channels[i]
            else:
                inp = out_channels[i]
                prev_out_channels = out_channels[i + 1]

            in_channels.append(inp)

        self.predict_flow = nn.ModuleList()
        self.upsample_flow = nn.ModuleList()
        self.deconv = nn.ModuleList()

        for i in range(len(in_channels) - 1):
            # 2-channel flow prediction at this resolution.
            self.predict_flow.append(
                nn.Conv2d(
                    in_channels[i], 2, kernel_size=3, stride=1, padding=1, bias=False
                ),
            )
            # Upsample the predicted flow to the next (finer) resolution.
            self.upsample_flow.append(
                nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, padding=1, bias=False)
            )
            self.deconv.append(deconv(in_channels[i], out_channels[i + 1]))

        # Final flow head at the finest decoder resolution.
        self.to_flow = nn.Conv2d(
            in_channels[-1], 2, kernel_size=3, stride=1, padding=1, bias=False
        )

    @classmethod
    def from_config(cls, cfg):
        """Map a project config node to constructor kwargs."""
        return {"in_channels": cfg.IN_CHANNELS, "config": cfg.CONFIG}

    def forward(self, x):
        """
        Performs forward pass.

        Parameters
        ----------
        x : List[torch.Tensor]
            List of all the outputs from each convolution layer of the encoder

        Returns
        -------
        List[torch.Tensor]
            List of all the flow predictions from each decoder layer
        """
        flow_preds = []

        # Coarsest stage works on the final encoder output alone.
        conv_out = x[-1]
        flow = self.predict_flow[0](conv_out)
        flow_up = self.upsample_flow[0](flow)
        deconv_out = self.deconv[0](conv_out)
        flow_preds.append(flow)

        layer_index = 1

        # Walk the encoder outputs from coarse to fine. The slice stops
        # before index 1: x[0] is never used and x[1] is consumed by the
        # final flow head after the loop.
        start = len(x) - 2
        end = 1
        for conv_out in x[start:end:-1]:
            # Skip connection, deconv output and upsampled flow must
            # share spatial dimensions before concatenation.
            assert conv_out.shape[2] == deconv_out.shape[2] == flow_up.shape[2]
            assert conv_out.shape[3] == deconv_out.shape[3] == flow_up.shape[3]

            concat_out = torch.cat((conv_out, deconv_out, flow_up), dim=1)

            flow = self.predict_flow[layer_index](concat_out)
            flow_up = self.upsample_flow[layer_index](flow)
            deconv_out = self.deconv[layer_index](concat_out)
            flow_preds.append(flow)

            layer_index += 1

        concat_out = torch.cat((x[1], deconv_out, flow_up), dim=1)
        flow = self.to_flow(concat_out)
        flow_preds.append(flow)

        return flow_preds