Source code for ezflow.decoder.noniterative.operators

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowEntropy(nn.Module):
    """
    Computes entropy from matching cost
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Performs forward pass.

        The cost is turned into a probability distribution over all U * V
        displacement hypotheses with a softmax, and the entropy of that
        distribution is computed per pixel. The result is divided by
        ``log(U * V)`` (the entropy of a uniform distribution), so a
        uniform cost yields 1.

        Parameters
        ----------
        x : torch.Tensor
            A tensor of shape B x U x V x H x W representing the cost.
            A tensor of shape B x 1 x U x V x H x W is also accepted; the
            singleton dimension is removed.

        Returns
        -------
        torch.Tensor
            A tensor of shape B x 1 x H x W
        """

        # Only drop dim 1 when it is an extra leading axis (6-D input).
        # Squeezing unconditionally would remove U itself when U == 1 on a
        # 5-D input and make the shape unpacking below fail.
        if x.dim() == 6:
            x = torch.squeeze(x, 1)

        B, U, V, H, W = x.shape
        num_hypotheses = U * V

        # reshape (not view) so non-contiguous cost volumes are handled too.
        probs = F.softmax(x.reshape(B, num_hypotheses, H, W), dim=1)

        # Clamp before the log so zero probabilities do not produce -inf.
        log_probs = torch.clamp(probs, 1e-9, 1 - 1e-9).log()
        global_entropy = -(probs * log_probs).sum(dim=1, keepdim=True)

        # log(1) == 0: with a single hypothesis the normaliser would divide
        # by zero and give NaN, so the (approximately zero) entropy is
        # returned without normalisation in that case.
        if num_hypotheses > 1:
            global_entropy = global_entropy / float(np.log(num_hypotheses))

        return global_entropy