Source code for ezflow.modules.units

import torch
import torch.nn as nn


class Conv2x(nn.Module):
    """
    Double convolutional layer with the option to perform deconvolution and concatenation

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    deconv : bool
        Whether to use a transposed convolution for the first layer instead of a convolution
    concat : bool
        Whether to concatenate the output of the first layer with the skip tensor
        (if False, the two are summed instead)
    norm : str
        Type of normalization to use. Can be "batch", "instance", "group", or "none"
    activation : str
        Type of activation to use. Can be "relu" or "leakyrelu"
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        deconv=False,
        concat=True,
        norm="batch",
        activation="relu",
    ):
        super(Conv2x, self).__init__()
        self.concat = concat
        self.deconv = deconv

        # A 4x4 kernel with stride 2 and padding 1 exactly doubles the spatial
        # resolution when transposed; a 3x3 kernel with the same stride and
        # padding halves it.
        if deconv:
            kernel = 4
        else:
            kernel = 3

        self.conv1 = ConvNormRelu(
            in_channels,
            out_channels,
            deconv,
            kernel_size=kernel,
            stride=2,
            padding=1,
        )

        if self.concat:
            self.conv2 = ConvNormRelu(
                out_channels * 2,
                out_channels,
                deconv=False,
                norm=norm,
                activation=activation,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        else:
            self.conv2 = ConvNormRelu(
                out_channels,
                out_channels,
                deconv=False,
                norm=norm,
                activation=activation,
                kernel_size=3,
                stride=1,
                padding=1,
            )
    def forward(self, x, rem):
        x = self.conv1(x)
        if self.concat:
            x = torch.cat((x, rem), 1)
        else:
            x = x + rem
        x = self.conv2(x)
        return x
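# A minimal usage sketch (not part of the original module): with the default
# 3x3 stride-2 convolution, Conv2x halves the spatial resolution, so the skip
# tensor ``rem`` must already match the downsampled shape. The tensor shapes
# below are illustrative assumptions.
#
# >>> block = Conv2x(in_channels=32, out_channels=64)
# >>> x = torch.randn(1, 32, 32, 32)
# >>> rem = torch.randn(1, 64, 16, 16)  # matches conv1's output shape
# >>> block(x, rem).shape
# torch.Size([1, 64, 16, 16])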
class ConvNormRelu(nn.Module):
    """
    Block for a convolutional layer with normalization and activation

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    deconv : bool, optional
        If True, use a transposed convolution
    norm : str, optional
        Normalization method. Can be "batch", "instance", "group", or None
    activation : str, optional
        Activation function. Can be "relu" or "leakyrelu"
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        deconv=False,
        norm="batch",
        activation="relu",
        **kwargs
    ):
        super(ConvNormRelu, self).__init__()

        # The convolution only needs a bias term when no normalization layer
        # follows it; otherwise the norm's affine parameters absorb the bias.
        bias = False

        if norm is not None:
            if norm.lower() == "group":
                self.norm = nn.GroupNorm(num_groups=8, num_channels=out_channels)
            elif norm.lower() == "batch":
                self.norm = nn.BatchNorm2d(out_channels)
            elif norm.lower() == "instance":
                self.norm = nn.InstanceNorm2d(out_channels)
            else:
                self.norm = nn.Identity()
                bias = True
        else:
            self.norm = nn.Identity()
            bias = True

        if activation is not None:
            if activation.lower() == "leakyrelu":
                self.activation = nn.LeakyReLU(0.1, inplace=True)
            else:
                self.activation = nn.ReLU(inplace=True)
        else:
            self.activation = nn.Identity()

        if "kernel_size" not in kwargs.keys():
            kwargs["kernel_size"] = 3

        if deconv:
            self.conv = nn.ConvTranspose2d(
                in_channels, out_channels, bias=bias, **kwargs
            )
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x
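# A quick shape check as a sketch (not part of the original module); extra
# keyword arguments such as ``padding`` are forwarded to the underlying
# convolution via ``**kwargs``, and the kernel size defaults to 3:
#
# >>> layer = ConvNormRelu(16, 32, norm="group", padding=1)
# >>> layer(torch.randn(1, 16, 32, 32)).shape
# torch.Size([1, 32, 32, 32])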
def conv(
    in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, norm=None
):
    """
    2D convolution layer with optional normalization, followed by an inplace
    LeakyReLU activation with a negative slope of 0.1

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    kernel_size : int, default: 3
        Size of the convolutional kernel
    stride : int, default: 1
        Stride of the convolutional kernel
    padding : int, default: 1
        Padding of the convolutional kernel
    dilation : int, default: 1
        Dilation of the convolutional kernel
    norm : str, default: None
        Type of normalization to use. Can be None, "batch", "instance", "group"
    """

    # The bias is redundant when a normalization layer follows the convolution.
    bias = False

    if norm == "group":
        norm_fn = nn.GroupNorm(num_groups=8, num_channels=out_channels)
    elif norm == "batch":
        norm_fn = nn.BatchNorm2d(out_channels)
    elif norm == "instance":
        norm_fn = nn.InstanceNorm2d(out_channels)
    else:
        norm_fn = nn.Identity()
        bias = True

    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        ),
        norm_fn,
        nn.LeakyReLU(0.1, inplace=True),
    )
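# Illustrative usage (an assumption, not from the original source): with the
# defaults kernel_size=3, stride=1, padding=1 the spatial size is preserved,
# and the convolution keeps its bias only when ``norm`` is None:
#
# >>> layer = conv(3, 16, norm="batch")
# >>> layer(torch.randn(1, 3, 64, 64)).shape
# torch.Size([1, 16, 64, 64])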
def deconv(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """
    2D transposed convolution layer

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    kernel_size : int, optional
        Size of the convolutional kernel
    stride : int, optional
        Stride of the convolutional kernel
    padding : int, optional
        Padding of the convolutional kernel
    """
    return nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size, stride, padding, bias=True
    )
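# Illustrative usage (an assumption, not from the original source): the
# default 4x4 kernel with stride 2 and padding 1 exactly doubles the spatial
# resolution, since (H - 1) * 2 - 2 * 1 + 4 = 2 * H:
#
# >>> up = deconv(64, 32)
# >>> up(torch.randn(1, 64, 16, 16)).shape
# torch.Size([1, 32, 32, 32])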