Iterative Decoders

Recurrent Lookup

class ezflow.decoder.iterative.recurrent_lookup.FlowHead(input_dim=128, hidden_dim=256)

Applies two 2D convolutions over an input feature map to generate a flow tensor of shape N x 2 x H x W.

Parameters
  • input_dim (int, default: 128) – Number of channels in the input feature map.

  • hidden_dim (int, default: 256) – Number of output channels of the first convolution.

forward(x)

Performs forward pass.

Parameters

x (torch.Tensor) – Input tensor of shape N x input_dim x H x W

Returns

A tensor of shape N x 2 x H x W

Return type

torch.Tensor
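A minimal PyTorch sketch of what FlowHead describes, assuming 3 x 3 kernels with padding 1 and a ReLU between the two convolutions (the layout used in the RAFT reference code; ezflow's exact layers may differ). FlowHeadSketch is a hypothetical name used only for illustration.

```python
import torch
import torch.nn as nn


class FlowHeadSketch(nn.Module):
    """Hypothetical re-implementation for illustration; not ezflow's source.

    Assumes 3x3 convolutions with padding 1 and a ReLU in between,
    modelled on the RAFT reference code.
    """

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # N x input_dim x H x W -> N x hidden_dim x H x W -> N x 2 x H x W
        return self.conv2(self.relu(self.conv1(x)))


head = FlowHeadSketch(input_dim=128, hidden_dim=256)
x = torch.randn(2, 128, 46, 62)
flow = head(x)
print(flow.shape)  # torch.Size([2, 2, 46, 62])
```

Because padding 1 preserves H and W for a 3 x 3 kernel, only the channel dimension changes, from input_dim to the two flow components.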

class ezflow.decoder.iterative.recurrent_lookup.MotionEncoder(corr_radius, corr_levels)

Encodes motion features from the correlation levels of the pyramid and the input flow estimate using convolution layers.

Parameters
  • corr_radius (int) – Correlation radius of the correlation pyramid

  • corr_levels (int) – Correlation levels of the correlation pyramid

forward(flow, corr)
Parameters
  • flow (torch.Tensor) – A tensor of shape N x 2 x H x W

  • corr (torch.Tensor) – A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W

Returns

A tensor of shape N x 128 x H x W

Return type

torch.Tensor
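The channel count of corr follows from the shape formula above: each of the corr_levels pyramid levels contributes a (2 * corr_radius + 1) x (2 * corr_radius + 1) window of correlation values per pixel. The helper below is hypothetical (not part of ezflow) and only makes the arithmetic explicit; radius 4 over 4 levels is the setting of the full RAFT model in its reference code.

```python
def corr_channels(corr_radius: int, corr_levels: int) -> int:
    """Channels of the `corr` input: one (2r+1) x (2r+1) window per pyramid level."""
    window = 2 * corr_radius + 1
    return corr_levels * window ** 2


# Full RAFT setting: radius 4 over 4 pyramid levels -> 4 * 9 * 9
print(corr_channels(corr_radius=4, corr_levels=4))  # 324
```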

class ezflow.decoder.iterative.recurrent_lookup.RecurrentLookupUpdateBlock(corr_radius, corr_levels, hidden_dim=128, input_dim=128)

Applies an iterative lookup update on all levels of the correlation pyramid to estimate flow with a sequence of GRU cells. Used in RAFT (https://arxiv.org/abs/2003.12039)

Parameters
  • corr_radius (int) – Correlation radius of the correlation pyramid

  • corr_levels (int) – Correlation levels of the correlation pyramid

  • hidden_dim (int, default: 128) – Number of channels of the GRU hidden state (net).

  • input_dim (int, default: 128) – Number of channels of the input features (inp).

forward(net, inp, corr, flow)

Performs forward pass.

Parameters
  • net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W

  • inp (torch.Tensor) – A tensor of shape N x input_dim x H x W

  • corr (torch.Tensor) – A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W

  • flow (torch.Tensor) – A tensor of shape N x 2 x H x W

Returns

  • net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W representing the output of the SepConvGRU cell.

  • mask (torch.Tensor) – A tensor of shape N x 576 x H x W holding the weights used to upsample the flow estimate

  • delta_flow (torch.Tensor) – A tensor of shape N x 2 x H x W representing the delta flow
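The 576 mask channels match RAFT's convex upsampling: 8 x 8 fine-grid positions per coarse pixel, each a softmax-weighted combination of the 9 pixels in the surrounding 3 x 3 coarse window (8 * 8 * 9 = 576). The sketch below follows the RAFT reference code; it assumes that ezflow consumes the mask the same way and that flow is predicted at 1/8 resolution, neither of which the entries above state. convex_upsample is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def convex_upsample(flow, mask, factor=8):
    """Upsample N x 2 x H x W flow to N x 2 x (factor*H) x (factor*W).

    mask: N x (factor*factor*9) x H x W logits (576 for factor=8), i.e.
    9 weights over the 3x3 coarse neighbourhood for every fine-grid pixel.
    """
    n, _, h, w = flow.shape
    mask = mask.view(n, 1, 9, factor, factor, h, w)
    mask = torch.softmax(mask, dim=2)  # weights over each 3x3 window sum to 1

    # Gather each pixel's 3x3 neighbourhood; scale flow to fine-grid pixels
    neigh = F.unfold(factor * flow, kernel_size=3, padding=1)
    neigh = neigh.view(n, 2, 9, 1, 1, h, w)

    up = torch.sum(mask * neigh, dim=2)  # N x 2 x factor x factor x H x W
    up = up.permute(0, 1, 4, 2, 5, 3)    # N x 2 x H x factor x W x factor
    return up.reshape(n, 2, factor * h, factor * w)


flow = torch.ones(1, 2, 4, 4)     # coarse flow of 1 px everywhere
mask = torch.zeros(1, 576, 4, 4)  # zero logits -> uniform weights after softmax
up = convex_upsample(flow, mask)
print(up.shape)  # torch.Size([1, 2, 32, 32])
```

Since the weights in each window sum to 1, a spatially constant coarse flow is reproduced (scaled by 8) away from the image border, where zero padding enters the window.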

class ezflow.decoder.iterative.recurrent_lookup.SepConvGRU(hidden_dim=128, input_dim=320)

Applies two convolutional GRU cells to the input signal; each cell uses its own convolution layers.

Parameters
  • hidden_dim (int, default: 128) – Number of channels of the hidden state.

  • input_dim (int, default: 192 + 128) – Number of input channels.

forward(h, x)

Performs forward pass.

Parameters
  • h (torch.Tensor) – A tensor of shape N x hidden_dim x H x W representing the hidden state

  • x (torch.Tensor) – A tensor of shape N x input_dim x H x W representing the input

Returns

A tensor of shape N x hidden_dim x H x W

Return type

torch.Tensor
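One plausible reading of "two convolutional GRU cells with separate convolution layers", modelled on the SepConvGRU in the RAFT reference code: the first cell uses 1 x 5 kernels, the second 5 x 1, and inside each cell the hidden state is concatenated with the input, so the gate convolutions see hidden_dim + input_dim channels. The kernel sizes and the class name SepConvGRUSketch are assumptions, not taken from ezflow.

```python
import torch
import torch.nn as nn


class SepConvGRUSketch(nn.Module):
    """Illustrative separable ConvGRU (assumed layout, modelled on RAFT)."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super().__init__()
        cat_dim = hidden_dim + input_dim
        # First cell: horizontal 1x5 kernels
        self.convz1 = nn.Conv2d(cat_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(cat_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(cat_dim, hidden_dim, (1, 5), padding=(0, 2))
        # Second cell: vertical 5x1 kernels
        self.convz2 = nn.Conv2d(cat_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(cat_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(cat_dim, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, convz, convr, convq):
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))  # update gate
        r = torch.sigmoid(convr(hx))  # reset gate
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))  # candidate state
        return (1 - z) * h + z * q

    def forward(self, h, x):
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h


gru = SepConvGRUSketch(hidden_dim=128, input_dim=320)
h = torch.zeros(1, 128, 16, 16)
x = torch.randn(1, 320, 16, 16)
print(gru(h, x).shape)  # torch.Size([1, 128, 16, 16])
```

Splitting a 5 x 5 receptive field into a 1 x 5 pass followed by a 5 x 1 pass keeps the parameter count of each cell close to that of a 1-D convolution while still mixing information in both spatial directions.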

class ezflow.decoder.iterative.recurrent_lookup.SmallMotionEncoder(corr_radius, corr_levels)

Encodes motion features from the correlation levels of the pyramid and the input flow estimate using convolution layers.

Parameters
  • corr_radius (int) – Correlation radius of the correlation pyramid

  • corr_levels (int) – Correlation levels of the correlation pyramid

forward(flow, corr)
Parameters
  • flow (torch.Tensor) – A tensor of shape N x 2 x H x W

  • corr (torch.Tensor) – A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W

Returns

A tensor of shape N x 82 x H x W

Return type

torch.Tensor
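To smoke-test either motion encoder in isolation, random inputs only need the shapes listed above. A minimal sketch; make_motion_inputs is a hypothetical helper, and radius 3 over 4 levels is the setting of the small RAFT model in its reference code, assumed here rather than taken from ezflow.

```python
import torch


def make_motion_inputs(n, h, w, corr_radius, corr_levels):
    """Random `flow` and `corr` tensors with the documented shapes.

    Hypothetical helper for shape checks; the values are meaningless.
    """
    corr_dim = corr_levels * (2 * corr_radius + 1) ** 2
    flow = torch.randn(n, 2, h, w)
    corr = torch.randn(n, corr_dim, h, w)
    return flow, corr


# Radius 3 over 4 levels -> 4 * 7 * 7 = 196 correlation channels
flow, corr = make_motion_inputs(n=1, h=32, w=48, corr_radius=3, corr_levels=4)
print(flow.shape, corr.shape)
# torch.Size([1, 2, 32, 48]) torch.Size([1, 196, 32, 48])
```

With ezflow installed, these tensors could be passed to SmallMotionEncoder(corr_radius=3, corr_levels=4), whose output should then have 82 channels according to the entry above.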

class ezflow.decoder.iterative.recurrent_lookup.SmallRecurrentLookupUpdateBlock(corr_radius, corr_levels, hidden_dim=96, input_dim=64)

Applies an iterative lookup update on all levels of the correlation pyramid to estimate flow with a sequence of GRU cells. Used in RAFT (https://arxiv.org/abs/2003.12039)

Parameters
  • corr_radius (int) – Correlation radius of the correlation pyramid

  • corr_levels (int) – Correlation levels of the correlation pyramid

  • hidden_dim (int, default: 96) – Number of channels of the GRU hidden state (net).

  • input_dim (int, default: 64) – Number of channels of the input features (inp).

forward(net, inp, corr, flow)

Performs forward pass.

Parameters
  • net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W

  • inp (torch.Tensor) – A tensor of shape N x input_dim x H x W

  • corr (torch.Tensor) – A tensor of shape N x (corr_levels * (2 * corr_radius + 1) ** 2) x H x W

  • flow (torch.Tensor) – A tensor of shape N x 2 x H x W

Returns

  • net (torch.Tensor) – A tensor of shape N x hidden_dim x H x W representing the output of the GRU cell

  • mask (NoneType) – Always None; unlike RecurrentLookupUpdateBlock, this block returns no mask.

  • delta_flow (torch.Tensor) – A tensor of shape N x 2 x H x W representing the delta flow
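Both update blocks return a residual delta_flow rather than a new flow field, so a model presumably calls them repeatedly and accumulates the steps. The loop below is a hypothetical sketch of that pattern, modelled on RAFT's refinement loop; refine_flow, lookup and ConstantStep are illustrative names, not ezflow APIs, and the stub only stands in for RecurrentLookupUpdateBlock or SmallRecurrentLookupUpdateBlock.

```python
import torch


def refine_flow(update_block, lookup, net, inp, flow, iters=12):
    """Hypothetical driver loop that applies an update block repeatedly.

    update_block(net, inp, corr, flow) -> (net, mask, delta_flow), as in the
    forward signatures documented above; `lookup(flow)` is an assumed
    callable that samples the correlation pyramid around the current flow.
    """
    predictions, mask = [], None
    for _ in range(iters):
        flow = flow.detach()  # RAFT stops gradients through the lookup position
        corr = lookup(flow)
        net, mask, delta_flow = update_block(net, inp, corr, flow)
        flow = flow + delta_flow  # residual update
        predictions.append(flow)
    return predictions, mask


# Toy stand-in for an update block: always proposes a +0.5 px step.
class ConstantStep(torch.nn.Module):
    def forward(self, net, inp, corr, flow):
        return net, None, torch.full_like(flow, 0.5)


preds, mask = refine_flow(
    ConstantStep(),
    lookup=lambda f: torch.zeros(1, 196, 8, 8),
    net=torch.zeros(1, 96, 8, 8),
    inp=torch.zeros(1, 64, 8, 8),
    flow=torch.zeros(1, 2, 8, 8),
    iters=4,
)
print(len(preds), preds[-1].mean().item())  # 4 2.0
```

Keeping every intermediate prediction allows a loss to supervise each iteration, which is how RAFT trains its update operator.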