import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import to_torch, to_cuda, to_numpy, demo_to_torch
from sklearn.base import BaseEstimator


class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda

        # Encoder
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)

        # Decoder (conditioned on the demographic vector)
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

        # Batch normalization layers
        self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)

    def enc(self, x):
        # Encode inputs to latent codes
        x = self.bn1(F.relu(self.enc1(x)))
        z = self.enc2(x)
        return z

    def gen(self, n):
        # Sample n latent codes from the standard normal prior
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        # Decode latent codes concatenated with the demographic vector
        z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        x = self.bn2(F.relu(self.dec1(z)))
        x = self.dec2(x)
        return x


class DemoVAE(BaseEstimator):
    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        return dict(
            latent_dim=32,
            use_cuda=True,
            nepochs=1000,
            pperiod=100,
            bsize=16,
            loss_C_mult=1,
            loss_mu_mult=1,
            loss_rec_mult=100,
            loss_decor_mult=10,
            loss_pred_mult=0.001,
            alpha=100,
            LR_C=100,
            lr=1e-4,
            weight_decay=0,
        )

    def get_params(self, deep=True):
        return {k: getattr(self, k) for k in self.get_default_params().keys()}

    def set_params(self, **params):
        # Fall back to the default for any parameter not supplied
        for k, v in self.get_default_params().items():
            setattr(self, k, params.get(k, v))
        return self

    def fit(self, x, demo, demo_types):
        from utils import train_vae

        # Demographic width: one input per continuous field,
        # one-hot width per categorical field
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                demo_dim += len(set(d))
            else:
                raise ValueError(f'Demographic type "{t}" not supported')

        # Initialize VAE
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

        # Train VAE (train_vae receives self and populates self.pred_stats,
        # which transform() and save() rely on)
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self
        )
        return self

    def transform(self, x, demo, demo_types):
        # An integer x requests that many samples from the prior;
        # otherwise x is an array to encode
        if isinstance(x, int):
            z = self.vae.gen(x)
        else:
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def get_latents(self, x):
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim,
        }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.set_params(**checkpoint['params'])
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
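

# A minimal usage sketch, not part of the original module: the data below is
# synthetic and the hyperparameters are illustrative only. It assumes the
# local `utils` module provides `train_vae` and the conversion helpers
# imported above, and that `train_vae` sets `pred_stats` on the passed
# DemoVAE instance before transform() is called.
if __name__ == '__main__':
    import numpy as np

    rng = np.random.default_rng(0)
    n, p = 200, 64
    x = rng.standard_normal((n, p)).astype(np.float32)  # feature matrix
    demo = [
        rng.standard_normal(n),      # e.g. age (continuous)
        rng.integers(0, 3, n),       # e.g. site label (categorical)
    ]
    demo_types = ['continuous', 'categorical']

    model = DemoVAE(latent_dim=16, nepochs=100, use_cuda=False)
    model.fit(x, demo, demo_types)

    # Reconstruct the training data conditioned on its own demographics
    x_hat = model.transform(x, demo, demo_types)

    # Generate 50 synthetic samples: pass an integer plus demographics
    # of matching length
    demo_new = [rng.standard_normal(50), rng.integers(0, 3, 50)]
    x_gen = model.transform(50, demo_new, demo_types)

    # Latent codes for downstream analysis
    z = model.get_latents(x)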