Source code for ezflow.utils.metrics

import torch


def endpointerror(pred, flow_gt, valid=None, multi_magnitude=False, **kwargs):
    """
    Compute the endpoint error (EPE) between predicted and ground truth flow.

    The per-pixel error is the L2 norm of ``pred - flow_gt`` taken over
    dimension 1.

    Parameters
    ----------
    pred : torch.Tensor or tuple or list
        Predicted flow. If a tuple or list is given, only its last element
        is evaluated.
    flow_gt : torch.Tensor
        Ground truth flow, with the same layout as the prediction.
    valid : torch.Tensor, optional
        Validity mask. After flattening, entries ``>= 0.5`` are treated as
        valid and only those pixels contribute to the returned values.
    multi_magnitude : bool, optional
        If True, return a dictionary with the mean error (``"epe"``) and the
        fraction of pixels whose error is below 1, 3 and 5 (``"1px"``,
        ``"3px"``, ``"5px"``) instead of the mean error alone.
    **kwargs
        Accepted for call-site compatibility; not used.

    Returns
    -------
    float or dict
        Mean endpoint error as a Python float, or the dictionary described
        above when ``multi_magnitude`` is True.
    tuple
        When ``valid`` is given, a tuple ``(metric, f1)`` is returned, where
        ``metric`` is the float or dict above and ``f1`` is a NumPy array of
        per-pixel outlier flags (1.0 where the error exceeds 3.0 and also
        exceeds 5% of the ground truth magnitude) restricted to valid pixels.
    """
    # When a sequence of predictions is passed, only the last entry is scored.
    if isinstance(pred, (tuple, list)):
        pred = pred[-1]

    epe = torch.norm(pred - flow_gt, p=2, dim=1)

    f1 = None
    if valid is not None:
        mag = torch.sum(flow_gt**2, dim=1).sqrt()

        epe = epe.reshape(-1)
        mag = mag.reshape(-1)
        val = valid.reshape(-1) >= 0.5

        # Outlier flags are computed on every pixel first and masked
        # afterwards, so the flattened tensors stay aligned with `val`.
        f1 = ((epe > 3.0) & ((epe / mag) > 0.05)).float()

        epe = epe[val]
        f1 = f1[val].cpu().numpy()

    if multi_magnitude:
        epe = epe.reshape(-1)
        metric = {
            "epe": epe.mean().item(),
            "1px": (epe < 1).float().mean().item(),
            "3px": (epe < 3).float().mean().item(),
            "5px": (epe < 5).float().mean().item(),
        }
    else:
        metric = epe.mean().item()

    # The outlier flags are only part of the result when a mask was supplied.
    if f1 is not None:
        return metric, f1
    return metric