Source code for ezflow.utils.resampling

import numpy as np
import torch
import torch.nn.functional as F
from scipy import interpolate


[docs]def forward_interpolate(flow): """ Forward interpolation of flow field Parameters ---------- flow : torch.Tensor Flow field to be interpolated Returns ------- torch.Tensor Forward interpolated flow field """ flow = flow.detach().cpu().numpy() dx, dy = flow[0], flow[1] ht, wd = dx.shape x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) x1 = x0 + dx y1 = y0 + dy x1 = x1.reshape(-1) y1 = y1.reshape(-1) dx = dx.reshape(-1) dy = dy.reshape(-1) valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) x1 = x1[valid] y1 = y1[valid] dx = dx[valid] dy = dy[valid] flow_x = interpolate.griddata( (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 ) flow_y = interpolate.griddata( (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 ) flow = np.stack([flow_x, flow_y], axis=0) return torch.from_numpy(flow).float()
[docs]def bilinear_sampler(img, coords, mask=False): """ Biliear sampler for images Parameters ---------- img : torch.Tensor Image to be sampled coords : torch.Tensor Coordinates to be sampled Returns ------- torch.Tensor Sampled image """ H, W = img.shape[-2:] xgrid, ygrid = coords.split([1, 1], dim=-1) xgrid = 2 * xgrid / (W - 1) - 1 ygrid = 2 * ygrid / (H - 1) - 1 grid = torch.cat([xgrid, ygrid], dim=-1) img = F.grid_sample(img, grid, align_corners=True) if mask: mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) return img, mask.float() return img
[docs]def upflow(flow, scale=8, mode="bilinear"): """ Interpolate flow field Parameters ---------- flow : torch.Tensor Flow field to be interpolated scale : int Scale of the interpolated flow field mode : str Interpolation mode Returns ------- torch.Tensor Interpolated flow field """ new_size = (scale * flow.shape[-2], scale * flow.shape[-1]) return scale * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
[docs]def convex_upsample_flow(flow, mask_logits, out_stride): # adapted from RAFT """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination Parameters ---------- flow : torch.Tensor Flow field to be upsampled mask_logits : torch.Tensor Mask logits out_stride : int Output stride Returns ------- torch.Tensor Upsampled flow field """ N, C, H, W = flow.shape mask_logits = mask_logits.view(N, 1, 9, out_stride, out_stride, H, W) mask_probs = torch.softmax(mask_logits, dim=2) up_flow = F.unfold(flow, [3, 3], padding=1) up_flow = up_flow.view(N, C, 9, 1, 1, H, W) up_flow = torch.sum(mask_probs * up_flow, dim=2) up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) return up_flow.reshape(N, C, out_stride * H, out_stride * W)