Source code for ezflow.decoder.separable_conv

import torch.nn.functional as F
from torch import nn

from ..config import configurable
from .build import DECODER_REGISTRY


class FeatureProjection4D(nn.Module):
    """
    Projects a 6-D feature volume with a kernel-size-1 3D convolution.

    The last two dimensions of the input are flattened so that the
    volume can be fed to ``nn.Conv3d``; afterwards the result is
    reshaped back to six dimensions.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : int
        Stride applied along the first two non-channel dimensions of the
        flattened volume (the third one always uses a stride of 1)
    norm : bool, default : True
        If True, applies Batch Norm 3D after the convolution and the
        convolution is created without a bias term
    groups : int, default : 1
        Number of groups for the 3D convolution, in_channels and
        out_channels must be divisible by groups
    """

    def __init__(self, in_channels, out_channels, stride, norm=True, groups=1):
        super(FeatureProjection4D, self).__init__()

        self.norm = norm
        self.stride = stride

        # A bias is only needed when no batch norm follows the convolution.
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            1,
            (stride, stride, 1),
            padding=0,
            bias=not norm,
            groups=groups,
        )
        # Registered regardless of ``norm``; only used in forward when
        # ``norm`` is True.
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, C, U, V, H, W)

        Returns
        -------
        torch.Tensor
            Projected tensor of shape (B, C', U', V', H, W), where C', U'
            and V' are taken from the convolution output
        """
        batch, chans, dim_u, dim_v, height, width = x.size()

        # Merge H and W so the 6-D volume fits a 3D convolution.
        flat = x.view(batch, chans, dim_u, dim_v, height * width)
        projected = self.conv1(flat)
        if self.norm:
            projected = self.bn(projected)

        # Channel, U and V sizes may have changed; H and W have not.
        _, new_chans, new_u, new_v, _ = projected.shape
        return projected.view(batch, new_chans, new_u, new_v, height, width)
@DECODER_REGISTRY.register()
class SeparableConv4D(nn.Module):
    """
    Applies two 3D convolutions followed by an optional 2D projection to
    the input feature map.

    The first convolution uses a (k_size, k_size, 1) kernel on the volume
    with H and W flattened together, the second a (1, k_size, k_size)
    kernel on the volume with U and V flattened together. A 1x1 2D
    convolution changes the channel count when in_channels differs from
    out_channels.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : tuple, default : (1, 1, 1)
        Stride of the convolution; only the first element is used
    norm : bool, default : True
        If True, every convolution is followed by Batch Normalization and
        is created without a bias term
    k_size : int, default : 3
        Size of the kernel
    full : bool, default : True
        If True, the (1, k_size, k_size) convolution is strided as well;
        otherwise it uses a stride of 1
    groups : int, default : 1
        Number of groups for the convolutions, in_channels and
        out_channels must both be divisible by groups
    """

    @configurable
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=(1, 1, 1),
        norm=True,
        k_size=3,
        full=True,
        groups=1,
    ):
        super(SeparableConv4D, self).__init__()

        bias = not norm
        self.is_proj = False
        self.stride = stride[0]
        expand = 1
        pad = k_size // 2

        def _with_norm(conv, norm_cls, channels):
            # With norm the layer is ``Sequential(conv, bn)``, without it
            # the bare convolution; this keeps the parameter names of
            # both modes exactly as they were.
            if norm:
                return nn.Sequential(conv, norm_cls(channels))
            return conv

        if in_channels != out_channels:
            self.is_proj = True
            self.proj = _with_norm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    1,
                    bias=bias,
                    padding=0,
                    groups=groups,
                ),
                nn.BatchNorm2d,
                out_channels,
            )

        # Only the "full" variant strides the (1, k, k) convolution.
        conv1_stride = (1, self.stride, self.stride) if full else 1
        self.conv1 = _with_norm(
            nn.Conv3d(
                in_channels * expand,
                in_channels,
                (1, k_size, k_size),
                stride=conv1_stride,
                bias=bias,
                padding=(0, pad, pad),
                groups=groups,
            ),
            nn.BatchNorm3d,
            in_channels,
        )
        self.conv2 = _with_norm(
            nn.Conv3d(
                in_channels,
                in_channels * expand,
                (k_size, k_size, 1),
                stride=(self.stride, self.stride, 1),
                bias=bias,
                padding=(pad, pad, 0),
                groups=groups,
            ),
            nn.BatchNorm3d,
            in_channels * expand,
        )
        self.relu = nn.ReLU(inplace=True)

    @classmethod
    def from_config(cls, cfg):
        # Every sibling key is upper-case; accept ``K_SIZE`` and fall back
        # to the lower-case ``k_size`` spelling so either config works.
        k_size = cfg.K_SIZE if hasattr(cfg, "K_SIZE") else cfg.k_size
        return {
            "in_channels": cfg.IN_CHANNELS,
            "out_channels": cfg.OUT_CHANNELS,
            "stride": cfg.STRIDE,
            "norm": cfg.NORM,
            "k_size": k_size,
            "full": cfg.FULL,
            "groups": cfg.GROUPS,
        }

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, C, U, V, H, W)

        Returns
        -------
        torch.Tensor
            Output of shape (B, C', U', V', H', W'); C' is out_channels
            when the projection is active, otherwise the channel count of
            the second convolution's output
        """
        B, C, U, V, H, W = x.shape

        # (k, k, 1) convolution with H and W merged into one axis.
        x = self.conv2(x.view(B, C, U, V, -1))
        B, C, U, V, _ = x.shape
        x = self.relu(x)

        # (1, k, k) convolution with U and V merged into one axis.
        x = self.conv1(x.view(B, C, -1, H, W))
        B, C, _, H, W = x.shape

        if self.is_proj:
            # The channel axis must stay in dim 1 for the 1x1 Conv2d;
            # everything between channels and W is folded into one axis.
            x = self.proj(x.view(B, C, -1, W))

        x = x.view(B, -1, U, V, H, W)
        return x
class SeparableConv4DBlock(nn.Module):
    """
    Residual block made of two SeparableConv4D layers.

    The input is added to the output of the second layer. When the
    channel count or the stride would make the shapes differ, the input
    is first passed through a shortcut layer.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : tuple, default : (1, 1, 1)
        Stride of the first SeparableConv4D layer and of the shortcut
    norm : bool, default : True
        If True, applies Batch Normalization
    full : bool, default : True
        If True, the shortcut is a SeparableConv4D with a kernel size of
        1, otherwise a FeatureProjection4D
    groups : int, default : 1
        Number of groups for 3D convolution, in_channels and out_channels
        must both be divisible by groups
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=(1, 1, 1),
        norm=True,
        full=True,
        groups=1,
    ):
        super(SeparableConv4DBlock, self).__init__()

        # No shortcut layer is needed when the block leaves the shape of
        # its input untouched.
        keeps_shape = in_channels == out_channels and stride == (1, 1, 1)
        if keeps_shape:
            self.downsample = None
        elif full:
            self.downsample = SeparableConv4D(
                in_channels,
                out_channels,
                stride,
                norm=norm,
                k_size=1,
                full=full,
                groups=groups,
            )
        else:
            self.downsample = FeatureProjection4D(
                in_channels, out_channels, stride[0], norm=norm, groups=groups
            )

        shared = {"norm": norm, "full": full, "groups": groups}
        self.conv1 = SeparableConv4D(in_channels, out_channels, stride, **shared)
        self.conv2 = SeparableConv4D(out_channels, out_channels, (1, 1, 1), **shared)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input feature map

        Returns
        -------
        torch.Tensor
            ReLU of the sum of the (optionally projected) input and the
            output of the two stacked layers
        """
        branch = self.relu1(self.conv1(x))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu2(shortcut + self.conv2(branch))
@DECODER_REGISTRY.register()
class Butterfly4D(nn.Module):
    """
    Applies a FeatureProjection4D followed by five SeparableConv4DBlock
    layers arranged as two strided blocks, and three stride-1 blocks
    whose outputs are resized and added back to the earlier, larger
    feature maps.

    Parameters
    ----------
    f_dim_1 : int
        Number of input channels
    f_dim_2 : int
        Number of output channels
    norm : bool, default : True
        If True, applies Batch Normalization
    full : bool, default : True
        Passed to every SeparableConv4DBlock
    groups : int, default : 1
        Number of groups for 3D convolution, in_channels and out_channels
        must both be divisible by groups
    """

    @configurable
    def __init__(self, f_dim_1, f_dim_2, norm=True, full=True, groups=1):
        super(Butterfly4D, self).__init__()

        def make_block(stride):
            # Every block maps f_dim_2 channels to f_dim_2 channels.
            return SeparableConv4DBlock(
                f_dim_2,
                f_dim_2,
                norm=norm,
                stride=stride,
                full=full,
                groups=groups,
            )

        self.proj = nn.Sequential(
            FeatureProjection4D(f_dim_1, f_dim_2, 1, norm=norm, groups=groups),
            nn.ReLU(inplace=True),
        )

        # Strided blocks.
        self.conva1 = make_block((2, 1, 1))
        self.conva2 = make_block((2, 1, 1))

        # Stride-1 blocks, listed from the smallest feature map upwards.
        self.convb3 = make_block((1, 1, 1))
        self.convb2 = make_block((1, 1, 1))
        self.convb1 = make_block((1, 1, 1))

    @classmethod
    def from_config(cls, cfg):
        return {
            "f_dim_1": cfg.F_DIM_1,
            "f_dim_2": cfg.F_DIM_2,
            "norm": cfg.NORM,
            "full": cfg.FULL,
            "groups": cfg.GROUPS,
        }

    @staticmethod
    def _resize(feat, size_uv, size_hw):
        # Resize a (B, C, U, V, H, W) tensor with two trilinear
        # interpolations: first (U, V) with H*W merged into one axis, then
        # (H, W) with U*V merged into one axis.
        batch, chans, u_in, v_in, h_in, w_in = feat.shape
        u_out, v_out = size_uv
        h_out, w_out = size_hw

        feat = F.interpolate(
            feat.view(batch, chans, u_in, v_in, -1),
            (u_out, v_out, h_in * w_in),
            mode="trilinear",
            align_corners=True,
        ).view(batch, chans, u_out, v_out, h_in, w_in)

        feat = F.interpolate(
            feat.view(batch, chans, -1, h_in, w_in),
            (u_out * v_out, h_out, w_out),
            mode="trilinear",
            align_corners=True,
        ).view(batch, chans, u_out, v_out, h_out, w_out)

        return feat

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input feature map; the projection expects six dimensions
            (B, C, U, V, H, W)

        Returns
        -------
        torch.Tensor
            Feature map with the same shape as the projected input
        """
        level0 = self.proj(x)
        level1 = self.conva1(level0)
        level2 = self.convb3(self.conva2(level1))

        # Bring the smallest map to the size of level 1 and merge.
        level1 = self._resize(level2, level1.shape[2:4], level1.shape[4:6]) + level1
        level1 = self.convb2(level1)

        # Bring the merged map to the size of level 0 and merge.
        level0 = self._resize(level1, level0.shape[2:4], level0.shape[4:6]) + level0
        return self.convb1(level0)