Source code for ezflow.utils.warp

import torch
import torch.nn as nn


def warp(x, flow):
    """
    Warps an image x according to the optical flow field specified

    Parameters
    ----------
    x : torch.Tensor
        Image to be warped, a 4-D tensor of shape (B, C, H, W)
    flow : torch.Tensor
        Optical flow field, broadcastable to shape (B, 2, H, W). Channel 0
        is the horizontal (column) displacement and channel 1 the vertical
        (row) displacement, both in pixels. Must be on the same device as x

    Returns
    -------
    torch.Tensor
        Warped image with the same shape and dtype as x. Output locations
        whose sampling position is not fully inside x are set to zero
    """
    B, _, H, W = x.size()

    # Pixel-coordinate grid built directly in x's dtype and on x's device:
    # grid_sample requires the input and the grid to share a dtype, so a
    # hard-coded float32 grid breaks for float64 (or mixed-dtype) inputs.
    cols = torch.arange(W, dtype=x.dtype, device=x.device).view(1, 1, 1, W)
    rows = torch.arange(H, dtype=x.dtype, device=x.device).view(1, 1, H, 1)
    grid = torch.cat(
        (cols.expand(B, 1, H, W), rows.expand(B, 1, H, W)), dim=1
    )

    # Displace every pixel by its flow vector; cast back in case flow's
    # dtype promoted the sum away from x's dtype.
    vgrid = (grid + flow).to(x.dtype)

    # Normalise pixel coordinates to [-1, 1] as expected by grid_sample with
    # align_corners=True. max(..., 1) guards against division by zero for
    # single-row / single-column images.
    extent = torch.tensor(
        [max(W - 1, 1), max(H - 1, 1)], dtype=x.dtype, device=x.device
    ).view(1, 2, 1, 1)
    vgrid = 2.0 * vgrid / extent - 1.0

    # grid_sample wants the coordinate pair in the last dimension: (B, H, W, 2)
    vgrid = vgrid.permute(0, 2, 3, 1)
    output = nn.functional.grid_sample(x, vgrid, align_corners=True)

    # Warp an all-ones image the same way: with grid_sample's default zero
    # padding, any sample that touches the outside of x comes back < 1.
    # Binarise that into a validity mask and zero out the invalid pixels.
    mask = nn.functional.grid_sample(
        torch.ones_like(x), vgrid, align_corners=True
    )
    mask[mask < 0.9999] = 0
    mask[mask > 0] = 1

    return output * mask