""" Losses for meshes Borrowed from: https://github.com/ShichenLiu/SoftRas Note that I changed the implementation of laplacian matrices from dense tensor to COO sparse tensor """ import torch import torch.nn as nn import torch.nn.functional as F import torch.autograd as autograd import numpy as np import network.styleunet.conv2d_gradfix as conv2d_gradfix class SecondOrderSmoothnessLossForSequence(nn.Module): def __init__(self): super(SecondOrderSmoothnessLossForSequence, self).__init__() def forward(self, x, dim=0): assert x.shape[dim] > 3 a = x.shape[dim] a0 = torch.arange(0, a-2).long().to(x.device) a1 = torch.arange(1, a-1).long().to(x.device) a2 = torch.arange(2, a).long().to(x.device) x0 = torch.index_select(x, dim, index=a0) x1 = torch.index_select(x, dim, index=a1) x2 = torch.index_select(x, dim, index=a2) l = (2*x1 - x2 - x0).pow(2) return torch.mean(l) class WeightedMSELoss(nn.Module): def __init__(self, reduction: str = 'mean'): super(WeightedMSELoss, self).__init__() self.reduction = reduction def forward(self, pred, target, weight): return F.mse_loss(pred * weight, target * weight, reduction=self.reduction) class CosineSimilarityLoss(nn.Module): def __init__(self, reduction: str = 'mean'): super(CosineSimilarityLoss, self).__init__() self.reduction = reduction if reduction not in ['mean', 'none', 'sum']: raise RuntimeError('Unknown reduction type! It should be in ["mean", "none", "sum"]') def forward(self, pred, target, weight=None, dim=-1, normalized=True): if normalized: # assumes both ```pred``` and ```target``` have been normalized cs = 1 - torch.sum(pred*target, dim=dim) else: cs = 1 - F.cosine_similarity(pred, target, dim=dim) if weight is not None: cs = weight * cs if self.reduction == 'mean': return torch.mean(cs) else: return torch.sum(cs) class LeastMagnitudeLoss(nn.Module): def __init__(self, average=False): super(LeastMagnitudeLoss, self).__init__() self.average = average def forward(self, x): batch_size = x.size(0) dims = tuple(range(x.ndimension())[1:]) x = x.pow(2).sum(dims) if self.average: return x.sum() / batch_size else: return x.sum() class NegIOULoss(nn.Module): def __init__(self, average=False): super(NegIOULoss, self).__init__() self.average = average def forward(self, predict, target): dims = tuple(range(predict.ndimension())[1:]) intersect = (predict * target).sum(dims) union = (predict + target - predict * target).sum(dims) + 1e-6 return 1. - (intersect / union).sum() / intersect.nelement() class KLDLoss(nn.Module): def __init__(self, reduction='mean'): super(KLDLoss, self).__init__() self.reduction = reduction def forward(self, mu, logvar): d = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) if self.reduction == 'mean': return d / mu.shape[0] return d class PhaseTransitionsPotential(nn.Module): """ Refer to: Phase Transitions, Distance Functions, and Implicit Neural Representations """ def __init__(self, reduction='mean'): super(PhaseTransitionsPotential, self).__init__() self.reduction = reduction def forward(self, x): assert torch.all(x >= 0) and torch.all(x <= 1) s = 2 * x - 1 l = s ** 2 - 2 * torch.abs(s) +1 if self.reduction == 'mean': return torch.mean(l) return l class TotalVariationLoss(nn.Module): """ https://discuss.pytorch.org/t/implement-total-variation-loss-in-pytorch/55574 """ def __init__(self, scale_factor=None): super(TotalVariationLoss, self).__init__() self.scale_factor = scale_factor def forward(self, x): if self.scale_factor is not None: x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') assert len(x.shape) == 4 tv_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum() tv_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum() l = (tv_h+tv_w) / np.prod(x.shape) return l def d_logistic_loss(real_pred, fake_pred): real_loss = F.softplus(-real_pred) fake_loss = F.softplus(fake_pred) return real_loss.mean() + fake_loss.mean() def d_r1_loss(real_pred, real_img): with conv2d_gradfix.no_weight_gradients(): grad_real, = autograd.grad( outputs=real_pred.sum(), inputs=real_img, create_graph=True ) grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() return grad_penalty def g_nonsaturating_loss(fake_pred): loss = F.softplus(-fake_pred).mean() return loss