Source code for ezflow.similarity.correlation.pairwise

import torch
import torch.nn.functional as F

from ...config import configurable
from ...utils import bilinear_sampler
from ..build import SIMILARITY_REGISTRY


@SIMILARITY_REGISTRY.register()
class MutliScalePairwise4DCorr:
    """
    Pairwise 4D correlation at multiple scales.
    Used in **RAFT** (https://arxiv.org/abs/2003.12039)

    The all-pairs correlation volume between ``fmap1`` and ``fmap2`` is
    computed once at construction time and repeatedly 2x average-pooled over
    the ``fmap2`` spatial dimensions to form a pyramid. Calling the instance
    with a coordinate field samples a ``(2 * corr_radius + 1) ** 2`` window
    around every coordinate at each pyramid level.

    Parameters
    ----------
    fmap1 : torch.Tensor
        First feature map of shape (batch, dim, H1, W1)
    fmap2 : torch.Tensor
        Second feature map of shape (batch, dim, H2, W2)
    num_levels : int
        Number of levels in the feature pyramid (must be >= 1)
    corr_radius : int
        Radius of the correlation window

    Raises
    ------
    ValueError
        If ``num_levels`` is smaller than 1.
    """

    @configurable
    def __init__(self, fmap1, fmap2, num_levels=4, corr_radius=4):
        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")

        self.num_levels = num_levels
        self.corr_radius = corr_radius

        corr = MutliScalePairwise4DCorr.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # Treat every fmap1 pixel as an independent sample so that pooling
        # (and later sampling) only operates on the fmap2 spatial dims.
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid = [corr]
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """
        Sample correlation windows around ``coords`` at every pyramid level.

        Parameters
        ----------
        coords : torch.Tensor
            Coordinate field of shape (batch, 2, H1, W1), expressed at the
            resolution of the first (unpooled) pyramid level. H1/W1 presumably
            match fmap1's spatial size so that rows line up with the pyramid.

        Returns
        -------
        torch.Tensor
            float32 tensor of shape (batch, C, H1, W1), where C concatenates
            the sampled window values of all levels (presumably
            ``num_levels * (2 * corr_radius + 1) ** 2``, assuming
            ``bilinear_sampler`` returns one value per window offset).
        """
        r = self.corr_radius
        window = 2 * r + 1

        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)

        # The offset grid is identical for every level: build it once,
        # directly on the target device, instead of once per level on CPU.
        dx = torch.linspace(-r, r, window, device=coords.device)
        dy = torch.linspace(-r, r, window, device=coords.device)
        # NOTE: the (dy, dx) stacking order is intentionally preserved;
        # changing it would permute the output channels consumed downstream
        # (it appears to mirror the reference RAFT implementation).
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
        delta = delta.view(1, window, window, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            # Each level is 2x pooled relative to the previous one, so the
            # level-0 coordinates are scaled down by 2**i.
            coords_lvl = centroid / 2**i + delta
            sampled = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @classmethod
    def from_config(cls, cfg):
        """Map ``cfg.NUM_LEVELS`` / ``cfg.CORR_RADIUS`` to constructor kwargs."""
        return {
            "num_levels": cfg.NUM_LEVELS,
            "corr_radius": cfg.CORR_RADIUS,
        }

    @staticmethod
    def corr(fmap1, fmap2):
        """
        Compute the all-pairs dot-product correlation of two feature maps.

        Parameters
        ----------
        fmap1 : torch.Tensor
            Feature map of shape (batch, dim, H1, W1)
        fmap2 : torch.Tensor
            Feature map of shape (batch, dim, H2, W2); H2/W2 may differ from
            H1/W1

        Returns
        -------
        torch.Tensor
            Correlation volume of shape (batch, H1, W1, 1, H2, W2), scaled by
            ``1 / sqrt(dim)``
        """
        batch, dim, h1, w1 = fmap1.shape
        h2, w2 = fmap2.shape[-2:]
        # reshape (not view) so non-contiguous inputs, e.g. sliced feature
        # maps, are accepted instead of raising.
        fmap1 = fmap1.reshape(batch, dim, h1 * w1)
        fmap2 = fmap2.reshape(batch, dim, h2 * w2)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, h1, w1, 1, h2, w2)
        # Plain scalar divisor: avoids allocating a CPU tensor on every call.
        return corr / dim**0.5