import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.base import BaseEstimator

from utils import to_torch, to_cuda, to_numpy, demo_to_torch
class VAE(nn.Module):
    """Conditional VAE whose decoder is conditioned on demographic features."""

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        # Encoder: input -> 1000 hidden units -> latent code
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
        # Decoder: latent code concatenated with demographics -> 1000 -> input
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)
        # Batch normalization on the encoder and decoder hidden layers
        self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)

    def enc(self, x):
        # Encode inputs into latent codes
        x = self.bn1(F.relu(self.enc1(x)))
        z = self.enc2(x)
        return z

    def gen(self, n):
        # Sample n latent codes from the standard normal prior
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        # Decode latent codes conditioned on demographic features
        z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        x = self.bn2(F.relu(self.dec1(z)))
        x = self.dec2(x)
        return x
class DemoVAE(BaseEstimator):
    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        return dict(
            latent_dim=32,
            use_cuda=True,
            nepochs=1000,
            pperiod=100,
            bsize=16,
            loss_C_mult=1,
            loss_mu_mult=1,
            loss_rec_mult=100,
            loss_decor_mult=10,
            loss_pred_mult=0.001,
            alpha=100,
            LR_C=100,
            lr=1e-4,
            weight_decay=0,
        )

    def get_params(self, deep=True):
        return {k: getattr(self, k) for k in self.get_default_params().keys()}

    def set_params(self, **params):
        for k, v in self.get_default_params().items():
            setattr(self, k, params.get(k, v))
        return self
    def fit(self, x, demo, demo_types):
        from utils import train_vae

        # Size of the conditioning vector: one slot per continuous demographic,
        # one-hot width for each categorical demographic
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                demo_dim += len(set(d))
            else:
                raise ValueError(f'Demographic type "{t}" not supported')

        # Initialize VAE
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

        # Train VAE
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self
        )
        return self
    def transform(self, x, demo, demo_types):
        # If x is an int, sample that many latent codes from the prior;
        # otherwise encode the given data, then decode conditioned on demo
        if isinstance(x, int):
            z = self.vae.gen(x)
        else:
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def get_latents(self, x):
        # Return latent codes for the given data as a numpy array
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)
    def save(self, path):
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim
        }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.set_params(**checkpoint['params'])
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
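

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It uses synthetic
# data with assumed shapes, and assumes that utils.train_vae and
# utils.demo_to_torch behave as the methods above expect (in particular,
# that train_vae populates pred_stats on the estimator passed to it).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # 200 subjects, 50 input features, one continuous and one categorical demographic
    x = np.random.randn(200, 50).astype(np.float32)
    age = np.random.uniform(20, 80, size=200).astype(np.float32)
    sex = np.random.randint(0, 2, size=200)
    demo = [age, sex]
    demo_types = ['continuous', 'categorical']

    model = DemoVAE(latent_dim=16, nepochs=100, use_cuda=False)
    model.fit(x, demo, demo_types)

    # Reconstruct the training data conditioned on its own demographics
    x_rec = model.transform(x, demo, demo_types)

    # Or pass an int to sample that many new subjects from the prior,
    # conditioned on user-specified demographics
    x_gen = model.transform(
        100,
        [np.full(100, 50.0, dtype=np.float32), np.zeros(100, dtype=int)],
        demo_types,
    )

    z = model.get_latents(x)
    model.save('demovae_checkpoint.pt')  # hypothetical output path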