"""Helpers for diagonal Gaussians parameterized by mean and log-variance."""

from typing import Optional

import numpy as np
import torch


def kl_from_standard_normal(
    mean: torch.Tensor, log_var: torch.Tensor
) -> torch.Tensor:
    """Return the mean KL divergence KL(N(mean, exp(log_var)) || N(0, 1)).

    The divergence is evaluated element-wise and then averaged over every
    element of the result (batch included), yielding a 0-dim tensor.

    Args:
        mean: Means of the Gaussian.
        log_var: Natural log of the variances; must broadcast with ``mean``.

    Returns:
        Scalar tensor holding the averaged KL divergence.
    """
    kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
    return kl.mean()


def sample_from_standard_normal(
    mean: torch.Tensor, log_var: torch.Tensor, num: Optional[int] = None
) -> torch.Tensor:
    """Draw samples ``mean + std * eps`` with ``eps ~ N(0, 1)``.

    Despite the name, the result is distributed as N(mean, exp(log_var));
    only the underlying noise is standard normal.

    Args:
        mean: Means of the Gaussian.
        log_var: Natural log of the variances, same shape as ``mean``.
        num: If given, draw ``num`` independent samples per input and stack
            them along a new axis inserted at position 1.

    Returns:
        Tensor shaped like ``mean`` when ``num`` is None, otherwise shaped
        ``mean.shape[:1] + (num,) + mean.shape[1:]``.
    """
    std = (0.5 * log_var).exp()
    shape = tuple(mean.shape)
    if num is not None:
        # Insert a sample axis at position 1; mean/std get a singleton axis
        # there so they broadcast against the (batch, num, ...) noise.
        shape = shape[:1] + (num,) + shape[1:]
        mean = mean[:, None, ...]
        std = std[:, None, ...]
    # Draw the noise in the input dtype so float64 inputs get full-precision
    # noise and reduced-precision inputs are not promoted to float32.
    noise = torch.randn(shape, device=mean.device, dtype=mean.dtype)
    return mean + std * noise


def ensemble_nll_normal(
    ensemble: torch.Tensor, sample: torch.Tensor, epsilon: float = 1e-5
) -> torch.Tensor:
    """Score ``sample`` under a Gaussian fitted to ``ensemble`` along axis 1.

    The Gaussian's mean and unbiased variance are estimated over axis 1 of
    ``ensemble``. The returned value is the average of
    ``log(2*pi) + log(var) + (sample - mean)**2 / var``; the conventional
    factor 1/2 is not applied, so this is twice the Gaussian negative
    log-likelihood.

    Args:
        ensemble: Ensemble members stacked along axis 1. With a single member
            the unbiased variance is undefined (NaN).
        sample: Observation to score; an axis is inserted at position 1
            before it is compared with the ensemble mean.
        epsilon: Added to the variance to keep ``log`` and the division
            finite when the ensemble has (near-)zero spread.

    Returns:
        Scalar tensor with the averaged score.
    """
    mean = ensemble.mean(dim=1)
    var = ensemble.var(dim=1, unbiased=True) + epsilon
    logvar = var.log()

    # NOTE(review): `mean` has already lost the ensemble axis, while `sample`
    # gains a new axis 1 here. The shapes line up without cross-broadcasting
    # only if `sample` has one dimension fewer than `mean` (or the batch size
    # is 1) -- confirm against callers.
    diff = sample[:, None, ...] - mean
    logtwopi = np.log(2 * np.pi)
    nll = (logtwopi + logvar + diff.square() / var).mean()
    return nll