# Source code for ezflow.modules.dap

import torch.nn as nn

from ..config import configurable
from .build import MODULE_REGISTRY
from .units import ConvNormRelu


@MODULE_REGISTRY.register()
class DisplacementAwareProjection(nn.Module):
    """
    Displacement-aware projection layer

    Applies a 1x1 convolution over the flattened displacement channels
    (``du * dv``) of its input, which is presumably a 4D cost volume.

    Parameters
    ----------
    max_displacement : int, optional
        Maximum displacement. The projection is built for
        ``(2 * max_displacement + 1) ** 2`` displacement channels.
    temperature : bool, optional
        If True, use temperature scaling: the projection predicts a single
        channel, ``temp_factor`` is added to it, and the result multiplies
        every displacement channel of the input. If False, the projection
        maps the displacement channels to the same number of channels.
    temp_factor : float, optional
        Constant offset added to the predicted temperature before it is
        multiplied onto the input (only used when ``temperature`` is True).
    """

    @configurable
    def __init__(self, max_displacement=3, temperature=False, temp_factor=1e-6):
        super().__init__()
        self.temperature = temperature
        self.temp_factor = temp_factor

        # Number of flattened displacement channels: (2 * md + 1) ** 2.
        dim_c = (2 * max_displacement + 1) ** 2
        # Temperature mode predicts one value per pixel that is broadcast over
        # all displacement channels; otherwise the displacement channels are
        # mixed into the same number of output channels.
        out_channels = 1 if self.temperature else dim_c
        self.dap_layer = ConvNormRelu(
            dim_c,
            out_channels,
            kernel_size=1,
            padding=0,
            stride=1,
            norm=None,
            activation=None,
        )

    @classmethod
    def from_config(cls, cfg):
        """
        Build the ``__init__`` keyword arguments from a config node.

        Presumably consumed by the ``configurable`` decorator on ``__init__``.
        """
        return {
            "max_displacement": cfg.MAX_DISPLACEMENT,
            "temperature": cfg.TEMPERATURE,
            "temp_factor": cfg.TEMP_FACTOR,
        }

    def forward(self, x):
        """
        Project the displacement channels of the input.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(B, 1, DU, DV, H, W)``. A 5D tensor
            ``(B, DU, DV, H, W)`` whose second dimension is not 1 is also
            accepted, because only a size-1 dim 1 is squeezed. ``DU * DV`` is
            expected to equal ``(2 * max_displacement + 1) ** 2``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(B, 1, DU, DV, H, W)``.

        Raises
        ------
        ValueError
            If the input is not 5D after squeezing dimension 1.
        """
        cost = x.squeeze(1)
        if cost.dim() != 5:
            raise ValueError(
                "DisplacementAwareProjection expects an input of shape "
                f"(B, 1, DU, DV, H, W); got {tuple(x.shape)}"
            )
        bs, du, dv, h, w = cost.shape

        # reshape (not view) so non-contiguous inputs, e.g. produced by a
        # permute, are accepted; contiguous inputs still get a zero-copy view.
        cost = cost.reshape(bs, du * dv, h, w)
        if self.temperature:
            # Per-pixel temperature (assumed single-channel) broadcast over
            # all displacement channels.
            temp = self.dap_layer(cost) + self.temp_factor
            cost = cost * temp
        else:
            cost = self.dap_layer(cost)

        return cost.reshape(bs, du, dv, h, w).unsqueeze(1)