# Source: LCM/utils/nlinalg/nlinalg.py
# Author: shivrajanand
# Commit e8f4897 (verified): "Upload folder using huggingface_hub"
import torch
def logdet(x):
    """Log-determinant of a symmetric positive-definite matrix via Cholesky.

    Uses log det(A) = 2 * sum(log(diag(L))) where A = L L^T, which avoids
    forming det(A) directly (that would under/overflow for large matrices).

    Args:
        x: 2D symmetric positive-definite matrix, or a batch of such
            matrices with the matrix dimensions last (``(..., n, n)``).
            Cholesky requires strict positive definiteness; a singular
            (merely semidefinite) matrix is not supported.
    Returns: log determinant of x. A 0-dim tensor for a 2D input, or a
        tensor of shape ``x.shape[:-2]`` for batched input.
    Raises: the Cholesky routine's error (a RuntimeError in PyTorch) if x
        is not positive definite.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "cholesky"):
        chol = linalg.cholesky(x)
    else:
        # Fallback for very old PyTorch releases without torch.linalg.
        chol = x.potrf()
    # Either triangular factor works: both have the same diagonal entries.
    diag = chol.diagonal(dim1=-2, dim2=-1)
    return torch.log(diag).sum(-1) * 2
def logsumexp(x, dim=None):
    """Numerically stable log(sum(exp(x))).

    Args:
        x: A pytorch tensor (any dimension will do)
        dim: int or None, over which to perform the summation. `None`, the
            default, performs over all axes.
    Returns: The result of the log(sum(exp(...))) operation. A 0-dim tensor
        when `dim` is None, otherwise `x` with dimension `dim` removed.
        Slices that are entirely -inf yield -inf (not NaN), and slices
        containing +inf yield +inf.
    """
    if dim is None:
        # Reducing over all axes == reducing a flattened view over axis 0.
        x = x.reshape(-1)
        dim = 0
    xmax, _ = x.max(dim, keepdim=True)
    # Subtracting an infinite max would give inf - inf = NaN; shift by 0
    # instead so all -inf -> log(0) = -inf and +inf -> +inf come out right.
    shift = torch.where(torch.isfinite(xmax), xmax, torch.zeros_like(xmax))
    out = torch.log(torch.exp(x - shift).sum(dim, keepdim=True)) + shift
    return out.squeeze(dim)