# Source code for ezflow.similarity.learnable_cost

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from spatial_correlation_sampler import SpatialCorrelationSampler
except ImportError:
    # Narrowed from a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit and unrelated errors raised while
    # importing the compiled extension. Fall back to the slower
    # pure-PyTorch implementation when the package is unavailable.
    from .correlation import IterSpatialCorrelationSampler as SpatialCorrelationSampler

from ..config import configurable
from ..modules import ConvNormRelu
from .build import SIMILARITY_REGISTRY


@SIMILARITY_REGISTRY.register()
class Conv2DMatching(nn.Module):
    """
    Convolutional matching/filtering network for cost volume learning

    Parameters
    ----------
    config : tuple of int or list of int
        Configuration of the convolutional layers in the network
    """

    @configurable
    def __init__(self, config=(64, 96, 128, 64, 32, 1)):
        super(Conv2DMatching, self).__init__()

        c0, c1, c2, c3, c4, c5 = config[:6]

        # Encoder-decoder style filtering: one stride-2 downsampling step,
        # several 3x3 stages, then a deconvolution back up, ending in a
        # plain Conv2d that maps down to c5 output channels.
        stages = [
            ConvNormRelu(c0, c1, kernel_size=3, padding=1, dilation=1),
            ConvNormRelu(c1, c2, kernel_size=3, stride=2, padding=1),
            ConvNormRelu(c2, c2, kernel_size=3, padding=1, dilation=1),
            ConvNormRelu(c2, c3, kernel_size=3, padding=1, dilation=1),
            ConvNormRelu(c3, c4, kernel_size=4, padding=1, stride=2, deconv=True),
            nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1, bias=True),
        ]
        self.matching_net = nn.Sequential(*stages)

    @classmethod
    def from_config(cls, cfg):
        return {
            "config": cfg.CONFIG,
        }

    def forward(self, x):
        """Filter the input cost-volume slice through the matching network."""
        return self.matching_net(x)
@SIMILARITY_REGISTRY.register()
class Custom2DConvMatching(nn.Module):
    """
    Convolutional matching/filtering network for cost volume learning
    with custom convolutions

    Parameters
    ----------
    config : tuple of int or list of int
        Configuration of the convolutional layers in the network
    kernel_size : int
        Kernel size of the convolutional layers
    **kwargs
        Additional keyword arguments for the convolutional layers
    """

    @configurable
    def __init__(self, config=(16, 32, 16, 1), kernel_size=3, **kwargs):
        super(Custom2DConvMatching, self).__init__()

        # One ConvNormRelu per consecutive channel pair except the last pair,
        # which is handled by a 1x1 projection to the final channel count.
        layers = [
            ConvNormRelu(in_ch, out_ch, kernel_size=kernel_size, **kwargs)
            for in_ch, out_ch in zip(config[:-2], config[1:-1])
        ]
        layers.append(nn.Conv2d(config[-2], config[-1], kernel_size=1))

        self.matching_net = nn.Sequential(*layers)

    @classmethod
    def from_config(cls, cfg):
        return {
            "config": cfg.CONFIG,
            "kernel_size": cfg.KERNEL_SIZE,
        }

    def forward(self, x):
        """Filter the input cost-volume slice through the matching network."""
        return self.matching_net(x)
@SIMILARITY_REGISTRY.register()
class LearnableMatchingCost(nn.Module):
    """
    Learnable matching cost network for cost volume learning. Used in
    **DICL** (https://arxiv.org/abs/2010.14851)

    Parameters
    ----------
    max_u : int, optional
        Maximum displacement in the horizontal direction
    max_v : int, optional
        Maximum displacement in the vertical direction
    config : tuple of int or list of int, optional
        Configuration of the convolutional layers (matching net) in the network
    remove_warp_hole : bool, optional
        Whether to remove the warp holes in the cost volume
    cuda_cost_compute : bool, optional
        Whether to compute the cost volume with a correlation sampler
    matching_net : Optional[nn.Module], optional
        Custom matching network, by default None, which uses a Conv2DMatching network
    """

    @configurable
    def __init__(
        self,
        max_u=3,
        max_v=3,
        config=(64, 96, 128, 64, 32, 1),
        remove_warp_hole=True,
        cuda_cost_compute=False,
        matching_net=None,
    ):
        super(LearnableMatchingCost, self).__init__()

        if matching_net is not None:
            self.matching_net = matching_net
        else:
            self.matching_net = Conv2DMatching(config=config)

        self.max_u = max_u
        self.max_v = max_v
        self.remove_warp_hole = remove_warp_hole
        self.cuda_cost_compute = cuda_cost_compute

    @classmethod
    def from_config(cls, cfg):
        return {
            "max_u": cfg.MAX_U,
            "max_v": cfg.MAX_V,
            "config": cfg.CONFIG,
            "remove_warp_hole": cfg.REMOVE_WARP_HOLE,
        }

    def forward(self, x, y):
        """
        Build and filter a matching cost volume between feature maps x and y.

        Parameters
        ----------
        x : torch.Tensor
            Reference feature map of shape (B, C, H, W)
        y : torch.Tensor
            Target feature map of shape (B, C, H, W)

        Returns
        -------
        torch.Tensor
            Learned cost volume of shape (B, 1, 2*max_u+1, 2*max_v+1, H, W)
        """
        size_u = 2 * self.max_u + 1
        size_v = 2 * self.max_v + 1
        _, c, height, width = x.shape

        with torch.cuda.device_of(x):
            # Raw cost volume: for each candidate displacement (u, v) we store
            # the x features and the correspondingly shifted y features,
            # concatenated along the channel dim (hence 2 * c channels).
            cost = (
                x.new()
                .resize_(x.size()[0], 2 * c, size_u, size_v, height, width)
                .zero_()
            )

        if self.cuda_cost_compute:
            # BUG FIX: patch_size was previously hard-coded to
            # (1 + 2 * 3, 1 + 2 * 3), which is only correct for the default
            # max_u = max_v = 3. Use the configured displacement range.
            # NOTE(review): assumes the sampler's patch_size dims follow the
            # same (u, v) ordering as the loop below — confirm if max_u != max_v.
            corr = SpatialCorrelationSampler(
                kernel_size=1,
                patch_size=(size_u, size_v),
                stride=1,
                padding=0,
                dilation_patch=1,
            )
            cost = corr(x, y)
        else:
            # For displacement (ind, indd), write x into the first c channels
            # and the shifted y into the last c channels over the overlapping
            # region; out-of-bounds cells remain zero.
            for i in range(size_u):
                ind = i - self.max_u
                for j in range(size_v):
                    indd = j - self.max_v
                    cost[
                        :,
                        :c,
                        i,
                        j,
                        max(0, -indd) : height - indd,
                        max(0, -ind) : width - ind,
                    ] = x[
                        :, :, max(0, -indd) : height - indd, max(0, -ind) : width - ind
                    ]
                    cost[
                        :,
                        c:,
                        i,
                        j,
                        max(0, -indd) : height - indd,
                        max(0, -ind) : width - ind,
                    ] = y[
                        :, :, max(0, +indd) : height + indd, max(0, ind) : width + ind
                    ]

        if self.remove_warp_hole:
            # Mask out displacements whose warped y features are entirely
            # zero padding (their channel sum is exactly zero).
            valid_mask = cost[:, c:, ...].sum(dim=1) != 0
            valid_mask = valid_mask.detach()
            cost = cost * valid_mask.unsqueeze(1).float()

        # Fold the displacement dims into the batch so the 2D matching net
        # filters every (u, v) slice independently, then restore the layout.
        cost = cost.permute([0, 2, 3, 1, 4, 5]).contiguous()
        cost = cost.view(x.size()[0] * size_u * size_v, c * 2, x.size()[2], x.size()[3])
        cost = self.matching_net(cost)
        cost = cost.view(x.size()[0], size_u, size_v, 1, x.size()[2], x.size()[3])
        cost = cost.permute([0, 3, 1, 2, 4, 5]).contiguous()

        return cost
@SIMILARITY_REGISTRY.register()
class MatryoshkaDilatedCostVolume(nn.Module):
    """
    Cost Volume with concentric offset dilations used in
    `DCVNet: Dilated Cost Volume Networks for Fast Optical Flow
    <https://jianghz.me/files/DCVNet_camera_ready_wacv2023.pdf>`_

    Parameters
    ----------
    num_groups : int, default 1
        Divides channels into groups of batches for batch processing of similarity computation.
    max_displacement : int, default 4
        Determines the cost volume search range/patch size.
    stride : int, default 1
        Stride of the spatial sampler
    dilations : List[int], default [1, 2, 3, 5, 9, 16]
        List of steps for every shift in patch.
    use_relu : bool, default False
        If True, applies ReLU activation to the cost volume output.
    """

    @configurable
    def __init__(
        self,
        num_groups=1,
        max_displacement=4,
        stride=1,
        dilations=[1, 2, 3, 5, 9, 16],
        use_relu=False,
    ):
        super(MatryoshkaDilatedCostVolume, self).__init__()

        self.num_groups = num_groups
        self.use_relu = use_relu
        self._set_concentric_offsets(dilations=dilations, radius=max_displacement)

        # One correlation sampler per dilation; all share the same patch size,
        # so larger dilations probe proportionally larger displacements.
        patch_size = 2 * max_displacement + 1
        self.corr_layers = nn.ModuleList(
            [
                SpatialCorrelationSampler(
                    patch_size=patch_size,
                    stride=stride,
                    padding=0,
                    dilation_patch=dilation,
                )
                for dilation in dilations
            ]
        )

    @classmethod
    def from_config(cls, cfg):
        return {
            "num_groups": cfg.NUM_GROUPS,
            "max_displacement": cfg.MAX_DISPLACEMENT,
            "stride": cfg.STRIDE,
            "dilations": cfg.DILATIONS,
            "use_relu": cfg.USE_RELU,
        }

    def _set_concentric_offsets(self, dilations, radius):
        # Per-dilation 1D displacement grids: [-radius..radius] * dilation.
        base = np.arange(-radius, radius + 1)
        offsets = np.stack([base * dilation for dilation in dilations])
        self.register_buffer("offsets", torch.Tensor(offsets).float())

    def get_relative_offsets(self):
        """Return the (num_dilations, 2*radius+1) offset table."""
        return self.offsets

    def get_search_range(self):
        """Return the number of candidate displacements along one axis."""
        return self.offsets.shape[1]

    def forward(self, x1, x2):
        """
        Compute grouped correlation cost volumes for every dilation and
        concatenate them along the group dimension.
        """
        b, c, h, w = x1.shape
        assert c % self.num_groups == 0
        channels_per_group = c // self.num_groups

        # Fold groups into the batch dim so each group correlates separately.
        x1 = x1.view(b * self.num_groups, channels_per_group, h, w)
        x2 = x2.view(b * self.num_groups, channels_per_group, h, w)

        per_dilation_costs = []
        for sampler in self.corr_layers:
            raw = sampler(x1, x2)
            _, u, v, out_h, out_w = raw.shape
            per_dilation_costs.append(raw.view(b, self.num_groups, u, v, out_h, out_w))

        cost = torch.cat(per_dilation_costs, dim=1)

        if self.use_relu:
            cost = F.leaky_relu(cost, negative_slope=0.1)
        return cost
@SIMILARITY_REGISTRY.register()
class MatryoshkaDilatedCostVolumeList(nn.Module):
    """
    A List of Cost Volume with concentric offset dilations used in
    `DCVNet: Dilated Cost Volume Networks for Fast Optical Flow
    <https://jianghz.me/files/DCVNet_camera_ready_wacv2023.pdf>`_

    Parameters
    ----------
    num_groups : int, default 1
        Divides channels into groups of batches for batch processing of similarity computation.
    max_displacement : int, default 4
        Determines the cost volume search range/patch size.
    encoder_output_strides : List[int], default [2, 8]
        Stride of the feature maps from the encoder output, will modify output height and width.
    dilations : List[int], default [1, 2, 3, 5, 9, 16]
        List of steps for every shift in patch.
    normalize_feat_l2 : bool, default False
        If True, normalizes input feature maps.
    use_relu: bool, default False
        If True, applies ReLU activation to the cost volume output.
    """

    @configurable
    def __init__(
        self,
        num_groups=1,
        max_displacement=4,
        encoder_output_strides=[2, 8],
        dilations=[[1], [1, 2, 3, 5, 9, 16]],
        normalize_feat_l2=False,
        use_relu=False,
    ):
        super(MatryoshkaDilatedCostVolumeList, self).__init__()
        self.normalize_feat_l2 = normalize_feat_l2
        self.cost_volume_list = nn.ModuleList()
        offsets = None
        # Build one dilated cost volume per (dilation list, feature stride)
        # pair; sampler stride compensates so all outputs share 1/8 resolution.
        for idx, (dilations_i, feat_stride_i) in enumerate(
            zip(dilations, encoder_output_strides)
        ):
            assert feat_stride_i <= 8
            cost_volume_i = MatryoshkaDilatedCostVolume(
                num_groups=num_groups,
                max_displacement=max_displacement,
                dilations=dilations_i,
                stride=8 // feat_stride_i,
                use_relu=use_relu,
            )
            self.cost_volume_list.append(cost_volume_i)
            # Scale relative offsets by the feature stride so they are
            # expressed in input-image pixel units, and stack them across
            # all cost volumes along dim 0.
            if offsets is None:
                offsets = cost_volume_i.get_relative_offsets() * feat_stride_i
            else:
                offsets = torch.cat(
                    (offsets, cost_volume_i.get_relative_offsets() * feat_stride_i),
                    dim=0,
                )
        self.offsets = offsets
        self._set_global_flow_offsets()

    @classmethod
    def from_config(cls, cfg):
        return {
            "num_groups": cfg.NUM_GROUPS,
            "max_displacement": cfg.MAX_DISPLACEMENT,
            "encoder_output_strides": cfg.ENCODER_OUTPUT_STRIDES,
            "dilations": cfg.DILATIONS,
            "normalize_feat_l2": cfg.NORMALIZE_FEAT_L2,
            "use_relu": cfg.USE_RELU,
        }

    def _set_global_flow_offsets(self):
        # process offsets
        # Expand each 1D offset row into a dense 2D (search_range x
        # search_range) grid of candidate flow vectors, stored as (y, x).
        num_dilations, search_range = self.offsets.shape
        offsets_2d = torch.zeros((num_dilations, search_range, search_range, 2))
        for idx in range(num_dilations):
            offsets_i, offsets_j = torch.meshgrid(
                self.offsets[idx], self.offsets[idx], indexing="ij"
            )
            offsets_2d[idx, :, :, 0] = offsets_i  # y
            offsets_2d[idx, :, :, 1] = offsets_j  # x
        self.register_buffer("offsets_2d", torch.Tensor(offsets_2d).float())

    def get_global_flow_offsets(self):
        """Return the dense (num_dilations, U, V, 2) flow offset grid."""
        return self.offsets_2d

    def get_search_range(self):
        """Return the per-axis search range of the first cost volume."""
        return self.cost_volume_list[0].get_search_range()

    def forward(self, x1, x2):
        """
        Compute and concatenate cost volumes for each feature pyramid level.

        Parameters
        ----------
        x1 : sequence of torch.Tensor
            Per-level feature maps of the first image (one per cost volume).
        x2 : sequence of torch.Tensor
            Per-level feature maps of the second image.

        Returns
        -------
        torch.Tensor
            Concatenated cost volume (levels stacked along dim 1).
        """
        # B, C, U, V, H, W
        cost_list = []
        for idx in range(len(x1)):
            x1_i = x1[idx]
            x2_i = x2[idx]
            if self.normalize_feat_l2:
                # L2-normalize features channel-wise; epsilon avoids
                # division by zero for all-zero feature vectors.
                x1_i = x1_i / (x1_i.norm(dim=1, keepdim=True) + 1e-9)
                x2_i = x2_i / (x2_i.norm(dim=1, keepdim=True) + 1e-9)
            cost_i = self.cost_volume_list[idx](x1_i, x2_i)
            cost_list.append(cost_i)
        cost = torch.cat(cost_list, dim=1)
        return cost