Source code for ezflow.encoder.pyramid

import torch.nn as nn

from ..config import configurable
from ..modules import conv
from .build import ENCODER_REGISTRY


[docs]@ENCODER_REGISTRY.register() class PyramidEncoder(nn.Module): """ Pyramid encoder which returns a hierarchy of features Used in **PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume** (https://arxiv.org/abs/1709.02371) Parameters ---------- in_channels : int Number of input channels config : list of int Configuration of the pyramid encoder's layers """ @configurable def __init__(self, in_channels=3, config=[16, 32, 64, 96, 128, 196]): super().__init__() if isinstance(config, tuple): config = list(config) config = [in_channels] + config self.encoder = nn.ModuleList() for i in range(len(config) - 1): self.encoder.append( nn.Sequential( conv(config[i], config[i + 1], kernel_size=3, stride=2), conv(config[i + 1], config[i + 1], kernel_size=3, stride=1), conv(config[i + 1], config[i + 1], kernel_size=3, stride=1), ) ) @classmethod def from_config(self, cfg): return { "config": cfg.CONFIG, }
[docs] def forward(self, img): """ Performs forward pass. Parameters ---------- img : torch.Tensor Input tensor Returns ------- List[torch.Tensor], List of all the output convolutions from each encoder layer """ feature_pyramid = [] x = img for i in range(len(self.encoder)): x = self.encoder[i](x) feature_pyramid.append(x) return feature_pyramid