Source code for ezflow.encoder.residual

import torch
import torch.nn as nn

from ..config import configurable
from ..modules import BasicBlock, BottleneckBlock
from .build import ENCODER_REGISTRY


@ENCODER_REGISTRY.register()
class BasicEncoder(nn.Module):
    """
    ResNet-style encoder with basic residual blocks

    Parameters
    ----------
    in_channels : int
        Number of input channels
    norm : str
        Normalization layer to use. One of "batch", "instance", "group", or None
    p_dropout : float
        Dropout probability
    layer_config : list of int or tuple of int
        Number of output features per layer
    num_residual_layers : list of int or tuple of int
        Number of residual blocks per layer
    intermediate_features : bool, default False
        Whether to return intermediate features to get a feature hierarchy
    """

    @configurable
    def __init__(
        self,
        in_channels=3,
        norm="batch",
        p_dropout=0.0,
        layer_config=(64, 96, 128),
        num_residual_layers=(2, 2, 2),
        intermediate_features=False,
    ):
        super(BasicEncoder, self).__init__()

        self.intermediate_features = intermediate_features

        norm = norm.lower()
        assert norm in ("group", "batch", "instance", "none")

        start_channels = layer_config[0]

        if norm == "group":
            norm_fn = nn.GroupNorm(num_groups=8, num_channels=start_channels)
        elif norm == "batch":
            norm_fn = nn.BatchNorm2d(start_channels)
        elif norm == "instance":
            norm_fn = nn.InstanceNorm2d(start_channels)
        elif norm == "none":
            norm_fn = nn.Identity()

        # Stem: strided 7x7 convolution followed by normalization and ReLU
        layers = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels, start_channels, kernel_size=7, stride=2, padding=3
                ),
                norm_fn,
                nn.ReLU(inplace=True),
            ]
        )

        # First residual stage keeps stride 1; later stages downsample by 2
        for i in range(len(layer_config)):
            stride = 1 if i == 0 else 2
            layers.append(
                self._make_layer(
                    start_channels,
                    layer_config[i],
                    stride,
                    norm,
                    num_residual_layers[i],
                )
            )
            start_channels = layer_config[i]

        self.dropout = nn.Identity()
        if p_dropout > 0:
            self.dropout = nn.Dropout2d(p=p_dropout)

        self.encoder = layers
        if self.intermediate_features is False:
            self.encoder = nn.Sequential(*self.encoder)

        self._init_weights()

    def _make_layer(self, in_channels, out_channels, stride, norm, num_layers=2):
        layers = [BasicBlock(in_channels, out_channels, stride, norm)]
        for _ in range(num_layers - 1):
            layers.append(BasicBlock(out_channels, out_channels, stride=1, norm=norm))

        return nn.Sequential(*layers)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @classmethod
    def from_config(cls, cfg):
        return {
            "in_channels": cfg.IN_CHANNELS,
            "norm": cfg.NORM,
            "p_dropout": cfg.P_DROPOUT,
            "layer_config": cfg.LAYER_CONFIG,
            "num_residual_layers": cfg.NUM_RESIDUAL_LAYERS,
            "intermediate_features": cfg.INTERMEDIATE_FEATURES,
        }
    def forward(self, x):

        if self.intermediate_features:

            features = []

            for i in range(len(self.encoder)):

                x = self.encoder[i](x)

                # Collect an intermediate feature map after each residual stage
                if isinstance(self.encoder[i], nn.Sequential):
                    x = self.dropout(x)
                    features.append(x)

            return features

        out = self.encoder(x)
        out = self.dropout(out)

        return out
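
A minimal usage sketch for BasicEncoder (illustrative, not part of the module): it imports the class from ezflow.encoder.residual and assumes the @configurable wrapper accepts explicit keyword arguments, as in detectron2-style configurables. With intermediate_features=True the forward pass returns one feature map per residual stage.

import torch

from ezflow.encoder.residual import BasicEncoder

# Illustrative sketch; layer_config and num_residual_layers mirror the defaults above.
encoder = BasicEncoder(
    in_channels=3,
    norm="instance",
    layer_config=(64, 96, 128),
    num_residual_layers=(2, 2, 2),
    intermediate_features=True,
)
encoder.eval()

x = torch.randn(1, 3, 256, 256)
features = encoder(x)  # list with one tensor per residual stage

# Expected shapes, given the stem stride of 2 and stage strides (1, 2, 2):
# (1, 64, 128, 128), (1, 96, 64, 64), (1, 128, 32, 32)
for f in features:
    print(f.shape)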
@ENCODER_REGISTRY.register()
class BottleneckEncoder(nn.Module):
    """
    ResNet-style encoder with bottleneck residual blocks

    Parameters
    ----------
    in_channels : int
        Number of input channels
    norm : str
        Normalization layer to use. One of "batch", "instance", "group", or None
    p_dropout : float
        Dropout probability
    layer_config : list of int or tuple of int
        Number of output features per layer
    num_residual_layers : list of int or tuple of int
        Number of residual blocks per layer
    intermediate_features : bool, default False
        Whether to return intermediate features to get a feature hierarchy
    """

    @configurable
    def __init__(
        self,
        in_channels=3,
        norm="batch",
        p_dropout=0.0,
        layer_config=(32, 64, 96),
        num_residual_layers=(2, 2, 2),
        intermediate_features=False,
    ):
        super(BottleneckEncoder, self).__init__()

        self.intermediate_features = intermediate_features

        norm = norm.lower()
        assert norm in ("group", "batch", "instance", "none")

        start_channels = layer_config[0]

        if norm == "group":
            norm_fn = nn.GroupNorm(num_groups=8, num_channels=start_channels)
        elif norm == "batch":
            norm_fn = nn.BatchNorm2d(start_channels)
        elif norm == "instance":
            norm_fn = nn.InstanceNorm2d(start_channels)
        elif norm == "none":
            norm_fn = nn.Identity()

        # Stem: strided 7x7 convolution followed by normalization and ReLU
        layers = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels, start_channels, kernel_size=7, stride=2, padding=3
                ),
                norm_fn,
                nn.ReLU(inplace=True),
            ]
        )

        # First residual stage keeps stride 1; later stages downsample by 2
        for i in range(len(layer_config)):
            stride = 1 if i == 0 else 2
            layers.append(
                self._make_layer(
                    start_channels,
                    layer_config[i],
                    stride,
                    norm,
                    num_residual_layers[i],
                )
            )
            start_channels = layer_config[i]

        self.dropout = nn.Identity()
        if p_dropout > 0:
            self.dropout = nn.Dropout2d(p=p_dropout)

        self.encoder = layers
        if self.intermediate_features is False:
            self.encoder = nn.Sequential(*self.encoder)

        self._init_weights()

    def _make_layer(self, in_channels, out_channels, stride, norm, num_layers=2):
        layers = [BottleneckBlock(in_channels, out_channels, stride, norm)]
        for _ in range(num_layers - 1):
            layers.append(
                BottleneckBlock(out_channels, out_channels, stride=1, norm=norm)
            )

        return nn.Sequential(*layers)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @classmethod
    def from_config(cls, cfg):
        return {
            "in_channels": cfg.IN_CHANNELS,
            "norm": cfg.NORM,
            "p_dropout": cfg.P_DROPOUT,
            "layer_config": cfg.LAYER_CONFIG,
            "num_residual_layers": cfg.NUM_RESIDUAL_LAYERS,
            "intermediate_features": cfg.INTERMEDIATE_FEATURES,
        }
    def forward(self, x):

        if self.intermediate_features:

            features = []

            for i in range(len(self.encoder)):

                x = self.encoder[i](x)

                # Collect an intermediate feature map after each residual stage
                if isinstance(self.encoder[i], nn.Sequential):
                    x = self.dropout(x)
                    features.append(x)

            return features

        out = self.encoder(x)
        out = self.dropout(out)

        return out
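
A similar sketch for BottleneckEncoder, again illustrative rather than part of the module: it constructs the encoder directly with keyword arguments (assuming the @configurable wrapper permits this) and runs a single forward pass. The output channel count assumes BottleneckBlock produces the requested number of output channels.

import torch

from ezflow.encoder.residual import BottleneckEncoder

encoder = BottleneckEncoder(
    in_channels=3,
    norm="batch",
    p_dropout=0.1,
    layer_config=(32, 64, 96),
    num_residual_layers=(2, 2, 2),
    intermediate_features=False,
)
encoder.eval()  # disables Dropout2d during inference

out = encoder(torch.randn(2, 3, 224, 224))
# With the stem stride of 2 and stage strides (1, 2, 2), the input is
# downsampled 8x overall, so out.shape should be roughly (2, 96, 28, 28).
print(out.shape)

Both classes also expose from_config, so they can equally be built from a config node (cfg.IN_CHANNELS, cfg.NORM, and so on) through the encoder registry's build machinery.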