import torch
import torch.nn as nn
import torch.nn.functional as F
from ...config import configurable
from ...modules import ConvGRU
from ..build import DECODER_REGISTRY
class FlowHead(nn.Module):
    """
    Applies two 2D convolutions over an input feature map
    to generate a flow tensor of shape N x 2 x H x W.

    Parameters
    ----------
    input_dim : int, default: 128
        Number of channels of the input feature map.
    hidden_dim : int, default: 256
        Number of channels produced by the first convolution.
    """

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()

        # Both convolutions use a 3x3 kernel with padding 1, so the
        # spatial size (H x W) of the input is preserved.
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        Performs forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape N x input_dim x H x W

        Returns
        -------
        torch.Tensor
            A tensor of shape N x 2 x H x W
        """
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class SepConvGRU(nn.Module):
    """
    Applies two Convolution GRU cells to the input signal.
    Each GRU cell uses separate convolution layers: the first cell uses
    1 x 5 kernels and the second cell uses 5 x 1 kernels.

    Parameters
    ----------
    hidden_dim : int, default: 128
        Number of hidden dimensions.
    input_dim : int, default: 192 + 128
        Number of input dimensions.
    """

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super().__init__()

        # Every gate convolution sees the hidden state concatenated with
        # the input along the channel axis.
        in_channels = hidden_dim + input_dim

        # First cell: 1 x 5 kernels, padded so H x W is preserved.
        self.convz1 = nn.Conv2d(in_channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(in_channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(in_channels, hidden_dim, (1, 5), padding=(0, 2))

        # Second cell: 5 x 1 kernels, padded so H x W is preserved.
        self.convz2 = nn.Conv2d(in_channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(in_channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(in_channels, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, conv_z, conv_r, conv_q):
        """Runs one GRU update of ``h`` given ``x`` with the supplied gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(conv_z(hx))
        r = torch.sigmoid(conv_r(hx))
        q = torch.tanh(conv_q(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        """
        Performs forward pass.

        Parameters
        ----------
        h : torch.Tensor
            A tensor of shape N x hidden_dim x H x W representing the hidden state
        x : torch.Tensor
            A tensor of shape N x input_dim x H x W representing the input

        Returns
        -------
        torch.Tensor
            A tensor of shape N x hidden_dim x H x W
        """
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
class SmallMotionEncoder(nn.Module):
    """
    Encodes motion features from the correlation levels of the pyramid
    and the input flow estimate using convolution layers.

    Parameters
    ----------
    corr_radius : int
        Correlation radius of the correlation pyramid
    corr_levels : int
        Correlation levels of the correlation pyramid
    """

    def __init__(self, corr_radius, corr_levels):
        super().__init__()

        # Channel count of the correlation input: one
        # (2 * corr_radius + 1) ** 2 block per level.
        cor_planes = corr_levels * (2 * corr_radius + 1) ** 2

        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        # Fuses the 96 correlation channels with the 32 flow channels.
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """
        Performs forward pass.

        Parameters
        ----------
        flow : torch.Tensor
            A tensor of shape N x 2 x H x W
        corr : torch.Tensor
            A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W

        Returns
        -------
        torch.Tensor
            A tensor of shape N x 82 x H x W
        """
        corr_feat = F.relu(self.convc1(corr))

        flow_feat = F.relu(self.convf1(flow))
        flow_feat = F.relu(self.convf2(flow_feat))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        # The raw flow is appended to the 80 fused channels, giving 82.
        return torch.cat([fused, flow], dim=1)
class MotionEncoder(nn.Module):
    """
    Encodes motion features from the correlation levels of the pyramid
    and the input flow estimate using convolution layers.

    Parameters
    ----------
    corr_radius : int
        Correlation radius of the correlation pyramid
    corr_levels : int
        Correlation levels of the correlation pyramid
    """

    def __init__(self, corr_radius, corr_levels):
        super().__init__()

        # Channel count of the correlation input: one
        # (2 * corr_radius + 1) ** 2 block per level.
        cor_planes = corr_levels * (2 * corr_radius + 1) ** 2

        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # Outputs 128 - 2 channels so that appending the 2-channel flow
        # in forward() yields exactly 128 channels.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        """
        Performs forward pass.

        Parameters
        ----------
        flow : torch.Tensor
            A tensor of shape N x 2 x H x W
        corr : torch.Tensor
            A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W

        Returns
        -------
        torch.Tensor
            A tensor of shape N x 128 x H x W
        """
        corr_feat = F.relu(self.convc1(corr))
        corr_feat = F.relu(self.convc2(corr_feat))

        flow_feat = F.relu(self.convf1(flow))
        flow_feat = F.relu(self.convf2(flow_feat))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        # The raw flow is appended to the 126 fused channels, giving 128.
        return torch.cat([fused, flow], dim=1)
@DECODER_REGISTRY.register()
class SmallRecurrentLookupUpdateBlock(nn.Module):
    """
    Applies an iterative lookup update on all levels of the correlation
    pyramid to estimate flow with a sequence of GRU cells.
    Used in **RAFT** (https://arxiv.org/abs/2003.12039)

    Parameters
    ----------
    corr_radius : int
        Correlation radius of the correlation pyramid
    corr_levels : int
        Correlation levels of the correlation pyramid
    hidden_dim : int, default: 96
        Number of hidden dimensions.
    input_dim : int, default: 64
        Number of input dimensions.
    """

    @configurable
    def __init__(self, corr_radius, corr_levels, hidden_dim=96, input_dim=64):
        super().__init__()

        # SmallMotionEncoder emits 80 fused channels plus the 2-channel flow.
        motion_dim = 82

        self.encoder = SmallMotionEncoder(corr_radius, corr_levels)
        # The GRU input is ``inp`` concatenated with the motion features.
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=motion_dim + input_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    @classmethod
    def from_config(cls, cfg):
        """
        Builds the constructor keyword arguments from a config node.

        Parameters
        ----------
        cfg : object
            Config exposing ``CORR_RADIUS``, ``CORR_LEVELS``, ``HIDDEN_DIM``
            and ``INPUT_DIM`` attributes.

        Returns
        -------
        dict
            Keyword arguments for ``__init__``.
        """
        return {
            "corr_radius": cfg.CORR_RADIUS,
            "corr_levels": cfg.CORR_LEVELS,
            "hidden_dim": cfg.HIDDEN_DIM,
            "input_dim": cfg.INPUT_DIM,
        }

    def forward(self, net, inp, corr, flow):
        """
        Performs forward pass.

        Parameters
        ----------
        net : torch.Tensor
            A tensor of shape N x hidden_dim x H x W
        inp : torch.Tensor
            A tensor of shape N x input_dim x H x W
        corr : torch.Tensor
            A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W
        flow : torch.Tensor
            A tensor of shape N x 2 x H x W

        Returns
        -------
        net : torch.Tensor
            A tensor of shape N x hidden_dim x H x W representing the output of the GRU cell
        mask : NoneType
            Always ``None``; this block has no mask head.
        delta_flow : torch.Tensor
            A tensor of shape N x 2 x H x W representing the delta flow
        """
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow
@DECODER_REGISTRY.register()
class RecurrentLookupUpdateBlock(nn.Module):
    """
    Applies an iterative lookup update on all levels of the correlation
    pyramid to estimate flow with a sequence of GRU cells.
    Used in **RAFT** (https://arxiv.org/abs/2003.12039)

    Parameters
    ----------
    corr_radius : int
        Correlation radius of the correlation pyramid
    corr_levels : int
        Correlation levels of the correlation pyramid
    hidden_dim : int, default: 128
        Number of hidden dimensions.
    input_dim : int, default: 128
        Number of input dimensions.
    """

    @configurable
    def __init__(self, corr_radius, corr_levels, hidden_dim=128, input_dim=128):
        super().__init__()

        # MotionEncoder always emits 128 channels (126 fused + 2 flow),
        # independent of ``hidden_dim``.
        motion_dim = 128

        self.encoder = MotionEncoder(corr_radius, corr_levels)
        # The GRU input is ``inp`` concatenated with the motion features, so
        # its channel count is input_dim + motion_dim. (Sizing it with
        # ``hidden_dim`` instead only matched when hidden_dim == 128.)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=input_dim + motion_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # The mask head consumes ``net``, which has ``hidden_dim`` channels.
        self.mask = nn.Sequential(
            nn.Conv2d(hidden_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    @classmethod
    def from_config(cls, cfg):
        """
        Builds the constructor keyword arguments from a config node.

        Parameters
        ----------
        cfg : object
            Config exposing ``CORR_RADIUS``, ``CORR_LEVELS``, ``HIDDEN_DIM``
            and ``INPUT_DIM`` attributes.

        Returns
        -------
        dict
            Keyword arguments for ``__init__``.
        """
        return {
            "corr_radius": cfg.CORR_RADIUS,
            "corr_levels": cfg.CORR_LEVELS,
            "hidden_dim": cfg.HIDDEN_DIM,
            "input_dim": cfg.INPUT_DIM,
        }

    def forward(self, net, inp, corr, flow):
        """
        Performs forward pass.

        Parameters
        ----------
        net : torch.Tensor
            A tensor of shape N x hidden_dim x H x W
        inp : torch.Tensor
            A tensor of shape N x input_dim x H x W
        corr : torch.Tensor
            A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W
        flow : torch.Tensor
            A tensor of shape N x 2 x H x W

        Returns
        -------
        net : torch.Tensor
            A tensor of shape N x hidden_dim x H x W representing the output of the SepConvGRU cell.
        mask : torch.Tensor
            A tensor of shape N x 576 x H x W
        delta_flow : torch.Tensor
            A tensor of shape N x 2 x H x W representing the delta flow
        """
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # The mask head output is scaled by a fixed factor of 0.25.
        mask = 0.25 * self.mask(net)

        return net, mask, delta_flow