| import torch | |
| def logdet(x): | |
| """ | |
| Args: | |
| x: 2D positive semidefinite matrix. | |
| Returns: log determinant of x | |
| """ | |
| # TODO for pytorch 2.0.4, use inside potrf for variable. | |
| print(torch.log(torch.eig(x.data)[0])) | |
| print(x) | |
| u_chol = x.potrf() | |
| return torch.sum(torch.log(u_chol.diag())) * 2 | |
| def logsumexp(x, dim=None): | |
| """ | |
| Args: | |
| x: A pytorch tensor (any dimension will do) | |
| dim: int or None, over which to perform the summation. `None`, the | |
| default, performs over all axes. | |
| Returns: The result of the log(sum(exp(...))) operation. | |
| """ | |
| if dim is None: | |
| xmax = x.max() | |
| xmax_ = x.max() | |
| return xmax_ + torch.log(torch.exp(x - xmax).sum()) | |
| else: | |
| xmax, _ = x.max(dim, keepdim=True) | |
| xmax_, _ = x.max(dim) | |
| return xmax_ + torch.log(torch.exp(x - xmax).sum(dim)) | |