Source code for ezflow.encoder.pspnet

import math

import torch
import torch.nn.functional as F
from torch import nn

from ..config import configurable
from ..modules import ConvNormRelu
from .build import ENCODER_REGISTRY


class ResidualBlock(nn.Module):
    """
    Residual block of two 3x3 ConvNormRelu layers with a skip connection

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : int
        Stride of the first convolution
    downsample : nn.Module or None
        Module applied to the input to match the block output; defaults
        to nn.Identity()
    dilation : int
        Dilation of the first convolution
    norm : bool
        Whether to use batch normalization
    """

    expansion = 1

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample=None,
        dilation=1,
        norm=True,
    ):
        super(ResidualBlock, self).__init__()

        padding = dilation if dilation > 1 else 1
        self.downsample = nn.Identity() if downsample is None else downsample
        norm = "batch" if norm else None

        self.block = nn.Sequential(
            ConvNormRelu(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                norm=norm,
            ),
            ConvNormRelu(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                norm=norm,
                activation=None,
            ),
        )

    def forward(self, x):
        out = self.block(x)
        residual = self.downsample(x)  # identity unless a downsample module was given
        out += residual
        out = F.leaky_relu(out, 0.1)

        return out
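
A minimal usage sketch for ResidualBlock (illustrative, not part of the module). It assumes ConvNormRelu wraps a standard nn.Conv2d, so kernel_size=3 with padding=1 preserves spatial size at stride=1 and halves it at stride=2; the strided case needs an explicit downsample module, mirroring what PSPNetBackbone._make_layer constructs below:

    import torch
    from torch import nn

    from ezflow.encoder.pspnet import ResidualBlock

    x = torch.randn(1, 32, 64, 64)

    # stride=1, matching channels: the skip path defaults to nn.Identity()
    block = ResidualBlock(32, 32)
    assert block(x).shape == (1, 32, 64, 64)

    # stride=2 with a channel change: supply a 1x1 downsample so the
    # residual matches the block output, as _make_layer does
    down = nn.Sequential(
        nn.Conv2d(32, 64, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(64),
    )
    strided = ResidualBlock(32, 64, stride=2, downsample=down)
    assert strided(x).shape == (1, 64, 32, 32)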

class PyramidPooling(nn.Module):
    """
    Pyramid pooling module for the **PSPNet** feature extractor

    Parameters
    ----------
    in_channels : int
        Number of input channels
    levels : int
        Number of levels in the pyramid
    norm : bool
        Whether to use batch normalization
    """

    def __init__(self, in_channels, levels=4, norm=True):
        super(PyramidPooling, self).__init__()

        self.levels = levels
        norm = "batch" if norm else None

        self.paths = []
        for _ in range(levels):
            self.paths.append(
                ConvNormRelu(
                    in_channels,
                    in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    dilation=1,
                    norm=norm,
                )
            )

        self.path_module_list = nn.ModuleList(self.paths)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        H, W = x.shape[2:]

        # Kernel sizes range from global pooling (pool_size=1) down to
        # roughly 2x2 windows; the lists are reversed so the finest
        # pooling is applied first
        k_sizes = []
        strides = []
        for pool_size in torch.linspace(1, min(H, W) // 2, self.levels):
            k_sizes.append((int(H / pool_size), int(W / pool_size)))
            strides.append((int(H / pool_size), int(W / pool_size)))

        k_sizes = k_sizes[::-1]
        strides = strides[::-1]

        pp_sum = x

        for i, module in enumerate(self.path_module_list):
            out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
            out = module(out)
            out = F.interpolate(out, size=(H, W), mode="bilinear", align_corners=True)
            pp_sum = pp_sum + 1.0 / self.levels * out

        pp_sum = self.relu(pp_sum / 2.0)

        return pp_sum
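
A quick shape check for PyramidPooling (illustrative, not part of the module). Each level average-pools the input at a different scale, projects it with a 1x1 ConvNormRelu, interpolates back to (H, W), and averages it into the input, so the output shape always matches the input:

    import torch

    from ezflow.encoder.pspnet import PyramidPooling

    pool = PyramidPooling(in_channels=128, levels=3).eval()
    feat = torch.randn(1, 128, 32, 48)

    with torch.no_grad():
        out = pool(feat)

    assert out.shape == feat.shape  # (1, 128, 32, 48)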

@ENCODER_REGISTRY.register()
class PSPNetBackbone(nn.Module):
    """
    PSPNet feature extractor backbone (https://arxiv.org/abs/1612.01105)
    Used in **Volumetric Correspondence Networks for Optical Flow** (https://papers.nips.cc/paper/2019/hash/bbf94b34eb32268ada57a3be5062fe7d-Abstract.html)

    Parameters
    ----------
    is_proj : bool
        Whether to use projection pooling or not
    groups : int
        Number of groups in the convolutional layers
    in_channels : int
        Number of input channels
    norm : bool
        Whether to use batch normalization
    """

    @configurable
    def __init__(self, is_proj=True, groups=1, in_channels=3, norm=True):
        super(PSPNetBackbone, self).__init__()

        self.is_proj = is_proj
        self.inplanes = 32

        norm = "batch" if norm else None

        self.convbnrelu1_1 = ConvNormRelu(
            in_channels, 16, kernel_size=3, padding=1, stride=2, norm=norm
        )
        self.convbnrelu1_2 = ConvNormRelu(
            16, 16, kernel_size=3, padding=1, stride=1, norm=norm
        )
        self.convbnrelu1_3 = ConvNormRelu(
            16, 32, kernel_size=3, padding=1, stride=1, norm=norm
        )

        self.res_block3 = self._make_layer(ResidualBlock, 64, 1, stride=2)
        self.res_block5 = self._make_layer(ResidualBlock, 128, 1, stride=2)
        self.res_block6 = self._make_layer(ResidualBlock, 128, 1, stride=2)
        self.res_block7 = self._make_layer(ResidualBlock, 128, 1, stride=2)
        self.pyramid_pooling = PyramidPooling(128, levels=3, norm=norm)

        self.upconv6 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            ConvNormRelu(128, 64, kernel_size=3, padding=1, stride=1, norm=norm),
        )
        self.iconv5 = ConvNormRelu(
            192, 128, kernel_size=3, padding=1, stride=1, norm=norm
        )
        self.upconv5 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            ConvNormRelu(128, 64, kernel_size=3, padding=1, stride=1, norm=norm),
        )
        self.iconv4 = ConvNormRelu(
            192, 128, kernel_size=3, padding=1, stride=1, norm=norm
        )
        self.upconv4 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            ConvNormRelu(128, 64, kernel_size=3, padding=1, stride=1, norm=norm),
        )
        self.iconv3 = ConvNormRelu(
            128, 64, kernel_size=3, padding=1, stride=1, norm=norm
        )
        self.upconv3 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            ConvNormRelu(64, 32, kernel_size=3, padding=1, stride=1, norm=norm),
        )
        self.iconv2 = ConvNormRelu(
            64, 64, kernel_size=3, padding=1, stride=1, norm=norm
        )

        if self.is_proj:
            self.proj6 = ConvNormRelu(
                128, 128 // groups, kernel_size=1, padding=0, stride=1
            )
            self.proj5 = ConvNormRelu(
                128, 128 // groups, kernel_size=1, padding=0, stride=1
            )
            self.proj4 = ConvNormRelu(
                128, 128 // groups, kernel_size=1, padding=0, stride=1
            )
            self.proj3 = ConvNormRelu(
                64, 64 // groups, kernel_size=1, padding=0, stride=1
            )
            self.proj2 = ConvNormRelu(
                64, 64 // groups, kernel_size=1, padding=0, stride=1
            )

        self._init_weights()

    def _init_weights(self):
        # Kaiming-style initialization for all convolutional layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if hasattr(m.bias, "data"):
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 strided convolution so the residual matches the block output
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    @classmethod
    def from_config(cls, cfg):
        return {
            "is_proj": cfg.IS_PROJ,
            "groups": cfg.GROUPS,
            "in_channels": cfg.IN_CHANNELS,
            "norm": cfg.NORM,
        }

    def forward(self, x):
        # Encoder: stem convolutions followed by strided residual blocks
        conv1 = self.convbnrelu1_1(x)
        conv1 = self.convbnrelu1_2(conv1)
        conv1 = self.convbnrelu1_3(conv1)

        pool1 = F.max_pool2d(conv1, 3, 2, 1)

        rconv3 = self.res_block3(pool1)
        conv4 = self.res_block5(rconv3)
        conv5 = self.res_block6(conv4)
        conv6 = self.res_block7(conv5)
        conv6 = self.pyramid_pooling(conv6)

        # Decoder: upsample, fuse with the matching encoder feature map, and
        # refine. Only the conv part of each upconv block (index [1]) is
        # applied; the upsampling is done explicitly with F.interpolate.
        conv6x = F.interpolate(
            conv6,
            [conv5.size()[2], conv5.size()[3]],
            mode="bilinear",
            align_corners=True,
        )
        concat5 = torch.cat((conv5, self.upconv6[1](conv6x)), dim=1)
        conv5 = self.iconv5(concat5)

        conv5x = F.interpolate(
            conv5,
            [conv4.size()[2], conv4.size()[3]],
            mode="bilinear",
            align_corners=True,
        )
        concat4 = torch.cat((conv4, self.upconv5[1](conv5x)), dim=1)
        conv4 = self.iconv4(concat4)

        conv4x = F.interpolate(
            conv4,
            [rconv3.size()[2], rconv3.size()[3]],
            mode="bilinear",
            align_corners=True,
        )
        concat3 = torch.cat((rconv3, self.upconv4[1](conv4x)), dim=1)
        conv3 = self.iconv3(concat3)

        conv3x = F.interpolate(
            conv3,
            [pool1.size()[2], pool1.size()[3]],
            mode="bilinear",
            align_corners=True,
        )
        concat2 = torch.cat((pool1, self.upconv3[1](conv3x)), dim=1)
        conv2 = self.iconv2(concat2)

        if self.is_proj:
            # Optional 1x1 projections, coarsest to finest
            proj6 = self.proj6(conv6)
            proj5 = self.proj5(conv5)
            proj4 = self.proj4(conv4)
            proj3 = self.proj3(conv3)
            proj2 = self.proj2(conv2)

            return [proj6, proj5, proj4, proj3, proj2]

        return [conv6, conv5, conv4, conv3, conv2]
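
A usage sketch for PSPNetBackbone (illustrative, not part of the module), assuming the @configurable decorator also permits direct keyword construction. The forward pass returns five feature maps ordered coarse to fine, at 1/64, 1/32, 1/16, 1/8, and 1/4 of the input resolution:

    import torch

    from ezflow.encoder.pspnet import PSPNetBackbone

    encoder = PSPNetBackbone(is_proj=True, groups=1, in_channels=3).eval()
    img = torch.randn(1, 3, 256, 256)

    with torch.no_grad():
        pyramid = encoder(img)

    for level in pyramid:
        print(level.shape)
    # torch.Size([1, 128, 4, 4])    1/64 resolution
    # torch.Size([1, 128, 8, 8])    1/32
    # torch.Size([1, 128, 16, 16])  1/16
    # torch.Size([1, 64, 32, 32])   1/8
    # torch.Size([1, 64, 64, 64])   1/4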