# AphasiaPred/vae_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import to_torch, to_cuda, to_numpy, demo_to_torch
from sklearn.base import BaseEstimator


class VAE(nn.Module):
    """Conditional VAE: encodes features to a latent code and decodes the
    code, concatenated with demographic covariates, back to feature space."""

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        # Encoder
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
        # Decoder (conditioned on demographics via concatenation)
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)
        # Batch normalization layers
        self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)

    def enc(self, x):
        # Encode features to a latent code
        x = self.bn1(F.relu(self.enc1(x)))
        z = self.enc2(x)
        return z

    def gen(self, n):
        # Sample n latent codes from the standard-normal prior
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        # Decode a latent code conditioned on demographic covariates
        z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        x = self.bn2(F.relu(self.dec1(z)))
        x = self.dec2(x)
        return x
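
# Shape sketch for VAE (illustrative numbers only, not taken from the
# training pipeline):
#   x (batch, input_dim)  --enc-->  z (batch, latent_dim)
#   [z | demo] (batch, latent_dim + demo_dim)  --dec-->  x_hat (batch, input_dim)
# e.g., assuming input_dim=100, latent_dim=32, demo_dim=5:
#   vae = VAE(100, 32, 5, use_cuda=False)
#   x_hat = vae.dec(vae.enc(x), demo)   # x: (16, 100) tensor, demo: (16, 5) tensor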

class DemoVAE(BaseEstimator):
    """scikit-learn-style wrapper around VAE, holding training hyperparameters."""

    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        return dict(
            latent_dim=32,
            use_cuda=True,
            nepochs=1000,
            pperiod=100,
            bsize=16,
            loss_C_mult=1,
            loss_mu_mult=1,
            loss_rec_mult=100,
            loss_decor_mult=10,
            loss_pred_mult=0.001,
            alpha=100,
            LR_C=100,
            lr=1e-4,
            weight_decay=0,
        )

    def get_params(self, deep=True):
        return {k: getattr(self, k) for k in self.get_default_params().keys()}

    def set_params(self, **params):
        # Note: unlike sklearn's set_params, any key not passed here is
        # reset to its default value
        for k, v in self.get_default_params().items():
            setattr(self, k, params.get(k, v))
        return self
    def fit(self, x, demo, demo_types):
        from utils import train_vae
        # Total width of the demographic conditioning vector: one column per
        # continuous variable, one one-hot column per level of a categorical
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                demo_dim += len(set(d))
            else:
                raise ValueError(f'Demographic type "{t}" not supported')
        # Initialize VAE
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)
        # Train VAE; `self` is passed so train_vae can attach fitted state
        # (e.g. self.pred_stats, which transform() and save() rely on)
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self
        )
        return self
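
    # Expected input layout (hypothetical example; shapes inferred from the
    # call sites in fit() and enc() above):
    #   x          : array of shape (n_subjects, n_features)
    #   demo       : list of 1-D arrays, one per demographic, e.g. [age, sex]
    #   demo_types : matching list of types, e.g. ['continuous', 'categorical']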

    def transform(self, x, demo, demo_types):
        # If x is an int, draw x samples from the prior; otherwise encode the
        # given data, then decode under the supplied demographics
        if isinstance(x, int):
            z = self.vae.gen(x)
        else:
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def get_latents(self, x):
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim
        }, path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.set_params(**checkpoint['params'])
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        # Rebuild the network, then restore the trained weights
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
        return self
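

# Minimal usage sketch with synthetic data (illustrative only: it assumes
# utils.train_vae and demo_to_torch behave as their call sites above suggest,
# and the array shapes and epoch count here are made up for the demo).
if __name__ == "__main__":
    n_subjects, n_features = 64, 100
    x = np.random.randn(n_subjects, n_features).astype(np.float32)
    demo = [
        np.random.uniform(20, 80, n_subjects),  # continuous, e.g. age
        np.random.randint(0, 2, n_subjects),    # categorical, e.g. sex
    ]
    demo_types = ['continuous', 'categorical']

    model = DemoVAE(nepochs=10, use_cuda=torch.cuda.is_available())
    model.fit(x, demo, demo_types)

    # Re-decode the data under the fitted model's demographic conditioning
    x_hat = model.transform(x, demo, demo_types)
    z = model.get_latents(x)
    print(x_hat.shape, z.shape)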