Iterative Decoders
Recurrent Lookup
- class ezflow.decoder.iterative.recurrent_lookup.FlowHead(input_dim=128, hidden_dim=256)[source]
Applies two 2D convolutions over an input feature map to generate a flow tensor of shape N x 2 x H x W.
- Parameters
input_dim (int, default: 128) – Number of input dimensions.
hidden_dim (int, default: 256) – Number of hidden dimensions.
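As a rough illustration of the description above, here is a minimal PyTorch sketch of a two-convolution flow head. The class name `FlowHeadSketch`, the 3 x 3 kernels, the padding of 1 and the ReLU between the two layers are assumptions made for this example; they are not taken from the ezflow source.

```python
import torch
import torch.nn as nn


class FlowHeadSketch(nn.Module):
    """Hypothetical stand-in for a two-convolution flow head."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        # Assumed: 3x3 kernels with padding 1, so H and W are preserved.
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # N x input_dim x H x W -> N x 2 x H x W
        return self.conv2(self.relu(self.conv1(x)))


features = torch.randn(1, 128, 8, 12)
flow = FlowHeadSketch()(features)
print(tuple(flow.shape))
```

The two output channels hold the horizontal and vertical flow components for each pixel.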
- class ezflow.decoder.iterative.recurrent_lookup.MotionEncoder(corr_radius, corr_levels)[source]
Encodes motion features from the correlation levels of the pyramid and the input flow estimate using convolution layers.
- Parameters
corr_radius (int) – Correlation radius of the correlation pyramid
corr_levels (int) – Correlation levels of the correlation pyramid
- class ezflow.decoder.iterative.recurrent_lookup.RecurrentLookupUpdateBlock(corr_radius, corr_levels, hidden_dim=128, input_dim=128)[source]
Applies an iterative lookup update on all levels of the correlation pyramid to estimate flow with a sequence of GRU cells. Used in RAFT (https://arxiv.org/abs/2003.12039)
- Parameters
corr_radius (int) – Correlation radius of the correlation pyramid
corr_levels (int) – Correlation levels of the correlation pyramid
hidden_dim (int, default: 128) – Number of hidden dimensions.
input_dim (int, default: 128) – Number of input dimensions.
- forward(net, inp, corr, flow)[source]
Performs forward pass.
- Parameters
net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W
inp (torch.Tensor) – A tensor of shape N x input_dim x H x W
corr (torch.Tensor) – A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W
flow (torch.Tensor) – A tensor of shape N x 2 x H x W
- Returns
net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W representing the output of the SepConvGRU cell.
mask (torch.Tensor) – A tensor of shape N x 576 x H x W
delta_flow (torch.Tensor) – A tensor of shape N x 2 x H x W representing the delta flow
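The channel count of `corr` follows directly from the formula above. The helper below is a small sketch that evaluates the documented input shapes for a given configuration; the function name `update_block_shapes` and the example values (radius 4, 4 levels, a 46 x 62 feature map) are illustrative assumptions, not part of ezflow.

```python
def update_block_shapes(n, h, w, corr_radius, corr_levels,
                        hidden_dim=128, input_dim=128):
    """Return the documented input shapes for one update-block call."""
    # Each pyramid level contributes a (2r + 1) x (2r + 1) lookup window.
    corr_channels = corr_levels * (2 * corr_radius + 1) ** 2
    return {
        "net": (n, hidden_dim, h, w),
        "inp": (n, input_dim, h, w),
        "corr": (n, corr_channels, h, w),
        "flow": (n, 2, h, w),
    }


# Illustrative values only: 4 levels with radius 4 give 4 * 9 ** 2 = 324 channels.
shapes = update_block_shapes(n=1, h=46, w=62, corr_radius=4, corr_levels=4)
print(shapes["corr"])  # (1, 324, 46, 62)
```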
- class ezflow.decoder.iterative.recurrent_lookup.SepConvGRU(hidden_dim=128, input_dim=320)[source]
Applies two Convolution GRU cells to the input signal. Each GRU cell uses separate convolution layers.
- Parameters
hidden_dim (int, default: 128) – Number of hidden dimensions.
input_dim (int, default: 192 + 128) – Number of input dimensions.
- forward(h, x)[source]
Performs forward pass.
- Parameters
h (torch.Tensor) – A tensor of shape N x hidden_dim x H x W representing the hidden state
x (torch.Tensor) – A tensor of shape N x input_dim + hidden_dim x H x W representing the input
- Returns
A tensor of shape N x hidden_dim x H x W representing the updated hidden state
- Return type
torch.Tensor
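To make the gating concrete, the following is a minimal sketch of a convolutional GRU cell, applied twice with separate convolution layers. The class name `ConvGRUSketch`, the 1 x 5 and 5 x 1 kernel shapes, the small channel counts, and the choice to give `x` exactly `input_dim` channels are all assumptions of this example rather than details confirmed by the ezflow source.

```python
import torch
import torch.nn as nn


class ConvGRUSketch(nn.Module):
    """Hypothetical single convolutional GRU cell (not the ezflow class)."""

    def __init__(self, hidden_dim, input_dim, kernel_size, padding):
        super().__init__()
        in_ch = hidden_dim + input_dim  # h and x are concatenated on channels
        self.convz = nn.Conv2d(in_ch, hidden_dim, kernel_size, padding=padding)
        self.convr = nn.Conv2d(in_ch, hidden_dim, kernel_size, padding=padding)
        self.convq = nn.Conv2d(in_ch, hidden_dim, kernel_size, padding=padding)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))  # update gate
        r = torch.sigmoid(self.convr(hx))  # reset gate
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))  # candidate
        return (1 - z) * h + z * q  # blend old state and candidate


# Assumed arrangement: a horizontal (1 x 5) cell followed by a vertical (5 x 1) cell.
horizontal = ConvGRUSketch(16, 32, kernel_size=(1, 5), padding=(0, 2))
vertical = ConvGRUSketch(16, 32, kernel_size=(5, 1), padding=(2, 0))

h = torch.randn(1, 16, 8, 12)  # hidden state
x = torch.randn(1, 32, 8, 12)  # input features (input_dim channels in this sketch)
h_new = vertical(horizontal(h, x), x)
print(tuple(h_new.shape))
```

Because each cell maps the hidden state back to `hidden_dim` channels at the same resolution, the cells can be chained and the result fed into the next iteration.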
- class ezflow.decoder.iterative.recurrent_lookup.SmallMotionEncoder(corr_radius, corr_levels)[source]
Encodes motion features from the correlation levels of the pyramid and the input flow estimate using convolution layers.
- Parameters
corr_radius (int) – Correlation radius of the correlation pyramid
corr_levels (int) – Correlation levels of the correlation pyramid
- class ezflow.decoder.iterative.recurrent_lookup.SmallRecurrentLookupUpdateBlock(corr_radius, corr_levels, hidden_dim=96, input_dim=64)[source]
Applies an iterative lookup update on all levels of the correlation pyramid to estimate flow with a sequence of GRU cells. Used in RAFT (https://arxiv.org/abs/2003.12039)
- Parameters
corr_radius (int) – Correlation radius of the correlation pyramid
corr_levels (int) – Correlation levels of the correlation pyramid
hidden_dim (int, default: 96) – Number of hidden dimensions.
input_dim (int, default: 64) – Number of input dimensions.
- forward(net, inp, corr, flow)[source]
Performs forward pass.
- Parameters
net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W
inp (torch.Tensor) – A tensor of shape N x input_dim x H x W
corr (torch.Tensor) – A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W
flow (torch.Tensor) – A tensor of shape N x 2 x H x W
- Returns
net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W representing the output of the GRU cell
mask (NoneType) – Always None; this block does not produce a mask.
delta_flow (torch.Tensor) – A tensor of shape N x 2 x H x W representing the delta flow
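The update blocks are meant to be called repeatedly, with each call's `delta_flow` added to the running flow estimate. The loop below is a sketch of that pattern under stated assumptions: `refine_flow`, `lookup_corr` and `dummy_update_block` are hypothetical names, and the dummy block only mimics the documented return signature so the example runs without ezflow installed.

```python
import torch


def refine_flow(update_block, net, inp, lookup_corr, flow, iters=4):
    """Repeatedly apply an update block and accumulate its delta_flow."""
    mask = None
    for _ in range(iters):
        # Hypothetical callable: samples correlation features at the current flow.
        corr = lookup_corr(flow)
        net, mask, delta_flow = update_block(net, inp, corr, flow)
        flow = flow + delta_flow
    return flow, mask  # mask stays None for blocks that return no mask


def dummy_update_block(net, inp, corr, flow):
    """Stand-in with the documented return order: net, mask, delta_flow."""
    return net, None, torch.ones_like(flow)


net = torch.zeros(1, 96, 4, 4)   # N x hidden_dim x H x W
inp = torch.zeros(1, 64, 4, 4)   # N x input_dim x H x W
flow0 = torch.zeros(1, 2, 4, 4)  # initial flow estimate
corr_stub = lambda f: torch.zeros(1, 324, 4, 4)  # placeholder correlation features

flow, mask = refine_flow(dummy_update_block, net, inp, corr_stub, flow0, iters=3)
```

Callers that support both block variants should check whether `mask` is None before using it.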