pengc02's picture
all
ec9a6bc
"""
Losses for meshes
Borrowed from: https://github.com/ShichenLiu/SoftRas
Note that I changed the implementation of the Laplacian matrices from dense tensors to COO sparse tensors.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
import network.styleunet.conv2d_gradfix as conv2d_gradfix
class SecondOrderSmoothnessLossForSequence(nn.Module):
    """Mean squared second-order finite difference of a tensor along one axis.

    For every interior position ``i`` along ``dim`` the residual
    ``2 * x[i] - x[i + 1] - x[i - 1]`` is squared, and the mean over all
    resulting elements is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, dim=0):
        """Return the smoothness penalty of ``x`` along ``dim``.

        Args:
            x: input tensor.
            dim: axis along which consecutive entries are compared.

        Returns:
            Scalar tensor: mean of the squared second differences.

        Raises:
            AssertionError: if ``x`` has 3 or fewer entries along ``dim``.
        """
        length = x.shape[dim]
        assert length > 3

        # Indices of the "previous" entry of every interior triple; shifting
        # them by one and two yields the centre and "next" entries.
        first = torch.arange(length - 2, dtype=torch.long, device=x.device)
        prev_part = torch.index_select(x, dim, index=first)
        mid_part = torch.index_select(x, dim, index=first + 1)
        next_part = torch.index_select(x, dim, index=first + 2)

        residual = 2 * mid_part - next_part - prev_part
        return torch.mean(residual.pow(2))
class WeightedMSELoss(nn.Module):
    """MSE loss evaluated after scaling both inputs by a weight.

    The weight multiplies ``pred`` and ``target`` before the squared error
    is taken, so each element's contribution is scaled by ``weight ** 2``.
    """

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        # Passed through unchanged to ``F.mse_loss``.
        self.reduction = reduction

    def forward(self, pred, target, weight):
        """Return ``F.mse_loss(pred * weight, target * weight)``.

        Args:
            pred: predicted tensor.
            target: reference tensor.
            weight: multiplier applied to both ``pred`` and ``target``.

        Returns:
            The MSE loss reduced according to ``self.reduction``.
        """
        scaled_pred = pred * weight
        scaled_target = target * weight
        return F.mse_loss(scaled_pred, scaled_target, reduction=self.reduction)
class CosineSimilarityLoss(nn.Module):
    """Cosine-distance loss ``1 - cos(pred, target)`` with optional weighting.

    Args:
        reduction: ``'mean'`` averages the per-element loss, ``'sum'`` adds it
            up, and ``'none'`` returns the per-element loss unreduced.

    Raises:
        RuntimeError: if ``reduction`` is not one of the three values above.
    """

    def __init__(self, reduction: str = 'mean'):
        super(CosineSimilarityLoss, self).__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise RuntimeError('Unknown reduction type! It should be in ["mean", "none", "sum"]')
        self.reduction = reduction

    def forward(self, pred, target, weight=None, dim=-1, normalized=True):
        """Compute the cosine-distance loss between ``pred`` and ``target``.

        Args:
            pred: predicted vectors.
            target: reference vectors.
            weight: optional factor multiplied onto the per-element loss
                (after the reduction over ``dim``); skipped when ``None``.
            dim: axis holding the vector components.
            normalized: when True, both inputs are assumed to be unit length
                already and a plain dot product is used; when False,
                ``F.cosine_similarity`` performs the normalisation.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``; for ``'none'`` the
            (weighted) loss tensor with ``dim`` reduced away.
        """
        if normalized:
            # Dot product equals cosine similarity only for unit vectors.
            cs = 1 - torch.sum(pred * target, dim=dim)
        else:
            cs = 1 - F.cosine_similarity(pred, target, dim=dim)

        if weight is not None:
            cs = weight * cs

        if self.reduction == 'mean':
            return torch.mean(cs)
        if self.reduction == 'sum':
            return torch.sum(cs)
        # 'none': previously this fell through to the sum branch, silently
        # collapsing the per-element loss the caller asked for.
        return cs
class LeastMagnitudeLoss(nn.Module):
    """Sum-of-squares penalty on a tensor.

    Args:
        average: when True the total is divided by the size of the first
            dimension of the input; otherwise the raw total is returned.
    """

    def __init__(self, average=False):
        super().__init__()
        self.average = average

    def forward(self, x):
        """Return the summed squared entries of ``x``.

        Squares are first summed over every dimension except the first, then
        over the first dimension; with ``average`` set, the result is divided
        by ``x.size(0)``.
        """
        n_first = x.size(0)
        trailing_dims = tuple(range(1, x.ndimension()))
        per_entry = x.pow(2).sum(trailing_dims)
        total = per_entry.sum()
        return total / n_first if self.average else total
class NegIOULoss(nn.Module):
    """One minus the mean soft intersection-over-union.

    Args:
        average: stored on the module; ``forward`` does not read it.
    """

    def __init__(self, average=False):
        super().__init__()
        self.average = average

    def forward(self, predict, target):
        """Return ``1 - mean(IoU)`` between ``predict`` and ``target``.

        Intersection is ``predict * target`` and union is
        ``predict + target - predict * target``, each summed over every
        dimension except the first. ``1e-6`` is added to the union so the
        division never hits zero.
        """
        reduce_dims = tuple(range(1, predict.ndimension()))
        overlap = (predict * target).sum(reduce_dims)
        combined = (predict + target - predict * target).sum(reduce_dims) + 1e-6
        iou = overlap / combined
        return 1. - iou.sum() / overlap.nelement()
class KLDLoss(nn.Module):
    """Closed-form KL term ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.

    Args:
        reduction: ``'mean'`` divides the total by ``mu.shape[0]``; any other
            value returns the total unchanged.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, mu, logvar):
        """Return the KL term for the given mean and log-variance tensors."""
        inner = 1 + logvar - mu.pow(2) - logvar.exp()
        total = -0.5 * torch.sum(inner)
        if self.reduction != 'mean':
            return total
        return total / mu.shape[0]
class PhaseTransitionsPotential(nn.Module):
    """
    Double-well potential on values in [0, 1].
    Refer to: Phase Transitions, Distance Functions, and Implicit Neural Representations

    With ``s = 2 * x - 1`` the potential is ``s^2 - 2 * |s| + 1``, which is
    zero at ``x = 0`` and ``x = 1`` and one at ``x = 0.5``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, x):
        """Evaluate the potential on ``x``.

        Returns the mean over all elements when ``reduction == 'mean'``,
        otherwise the element-wise potential tensor.

        Raises:
            AssertionError: if any element of ``x`` lies outside [0, 1].
        """
        inside_unit = (x >= 0) & (x <= 1)
        assert torch.all(inside_unit)

        signed = 2 * x - 1  # map [0, 1] onto [-1, 1]
        potential = signed ** 2 - 2 * torch.abs(signed) + 1
        if self.reduction != 'mean':
            return potential
        return torch.mean(potential)
class TotalVariationLoss(nn.Module):
    """
    Squared total-variation penalty on a 4-D tensor.
    https://discuss.pytorch.org/t/implement-total-variation-loss-in-pytorch/55574

    Args:
        scale_factor: if not None, the input is first resized with
            nearest-neighbour ``F.interpolate`` using this factor.
    """

    def __init__(self, scale_factor=None):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return the summed squared neighbour differences divided by ``x``'s size.

        Differences are taken along dimensions 2 and 3 of the (possibly
        resized) input and normalised by its total element count.

        Raises:
            AssertionError: if the (possibly resized) input is not 4-D.
        """
        if self.scale_factor is not None:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        assert x.dim() == 4

        step_dim2 = x[:, :, 1:, :] - x[:, :, :-1, :]
        step_dim3 = x[:, :, :, 1:] - x[:, :, :, :-1]
        energy = step_dim2.pow(2).sum() + step_dim3.pow(2).sum()
        return energy / np.prod(x.shape)
def d_logistic_loss(real_pred, fake_pred):
    """Logistic discriminator loss.

    Returns ``mean(softplus(-real_pred)) + mean(softplus(fake_pred))``, which
    is small when real predictions are large and fake predictions are very
    negative.
    """
    loss_on_real = F.softplus(-real_pred).mean()
    loss_on_fake = F.softplus(fake_pred).mean()
    return loss_on_real + loss_on_fake
def d_r1_loss(real_pred, real_img):
    """Gradient penalty on real samples.

    Differentiates ``real_pred.sum()`` with respect to ``real_img``, squares
    the gradient, sums it per sample (all dimensions after the first) and
    averages over the first dimension. ``create_graph=True`` keeps the
    penalty itself differentiable.

    ``real_img`` must take part in the autograd graph of ``real_pred`` for
    ``autograd.grad`` to succeed.
    """
    # NOTE(review): presumably this context suppresses weight gradients in
    # the project's custom conv op while the input gradient is taken --
    # confirm against network.styleunet.conv2d_gradfix.
    with conv2d_gradfix.no_weight_gradients():
        (input_grad,) = autograd.grad(
            outputs=real_pred.sum(),
            inputs=real_img,
            create_graph=True,
        )
    n_samples = input_grad.shape[0]
    per_sample_sq_norm = input_grad.pow(2).reshape(n_samples, -1).sum(1)
    return per_sample_sq_norm.mean()
def g_nonsaturating_loss(fake_pred):
    """Non-saturating generator loss: ``mean(softplus(-fake_pred))``."""
    per_element = F.softplus(-fake_pred)
    return per_element.mean()