# AphasiaPred / test_small_sample.py
# SreekarB's picture
# Upload 10 files
# 763369a verified
import os
# Redirect the Hugging Face cache to an 'hf_cache' folder located beside this
# script, to avoid permission issues. The variable is set here, ahead of the
# imports that follow it in the file.
# NOTE(review): newer transformers releases reportedly prefer HF_HOME over
# TRANSFORMERS_CACHE -- confirm against the installed version.
_script_dir = os.path.dirname(os.path.abspath(__file__))
os.environ['TRANSFORMERS_CACHE'] = os.path.join(_script_dir, 'hf_cache')
# Create the folder up front; exist_ok makes repeated runs harmless.
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
import numpy as np
import torch
import matplotlib.pyplot as plt
from vae_model import DemoVAE
from visualization import plot_learning_curves, plot_fc_matrices
from config import MODEL_CONFIG
# Build a deliberately tiny synthetic dataset (5 samples) for this test run.
n_samples = 5
input_dim = 100
demo_dim = 4

print(f"Creating test dataset with {n_samples} samples...")

# Synthetic FC features (upper-triangular values): one row per sample,
# drawn from a standard normal distribution.
X = np.random.standard_normal((n_samples, input_dim))

# Synthetic demographics: one array of length n_samples per variable.
demo_data = [
    np.random.normal(loc=60, scale=10, size=n_samples),   # age
    np.random.choice(['M', 'F'], size=n_samples),         # sex
    np.random.normal(loc=24, scale=12, size=n_samples),   # months post stroke
    np.random.normal(loc=50, scale=15, size=n_samples),   # WAB score
]

# Kind of each demographic variable, listed in the same order as demo_data.
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
# Echo the hyperparameters read from MODEL_CONFIG before building the model.
print("Config settings:")
for _label, _key in (
    ("Epochs", 'nepochs'),
    ("Batch size", 'bsize'),
    ("Latent dim", 'latent_dim'),
):
    print(f"- {_label}: {MODEL_CONFIG[_key]}")

# Every MODEL_CONFIG entry is forwarded to the constructor as a keyword argument.
print("Initializing model...")
vae = DemoVAE(**MODEL_CONFIG)

# Fit on the synthetic data; the result of fit() is unpacked into two
# sequences of losses (training and validation).
print(f"Training model with {n_samples} samples...")
train_losses, val_losses = vae.fit(X, demo_data, demo_types)

# Report the last recorded value of each loss sequence.
print(f"Training complete! Final train loss: {train_losses[-1]:.4f}")
print(f"Final validation loss: {val_losses[-1]:.4f}")
# Ensure both output folders exist before anything is written to them.
for _out_dir in ("models", "results"):
    os.makedirs(_out_dir, exist_ok=True)

# Persist the trained model.
print("Saving model...")
vae.save('models/vae_model_small.pt')

# Plot the training and validation losses and write the figure to disk.
print("Generating learning curve visualization...")
learning_fig = plot_learning_curves(train_losses, val_losses)
learning_fig.savefig('results/learning_curves_small.png')
print("Learning curve saved to results/learning_curves_small.png")
# Pass the full synthetic dataset through the trained model to obtain
# reconstructions.
print("Generating reconstructions...")
reconstructed = vae.transform(X, demo_data, demo_types)

# The plot needs square matrices. Derive the side length from the number of
# features instead of hard-coding 10x10, so that changing the feature count
# does not silently break the reshape; fail early with a clear message when
# the feature count is not a perfect square.
n_features = X.shape[1]
fc_side = int(round(n_features ** 0.5))
if fc_side * fc_side != n_features:
    raise ValueError(
        f"Cannot reshape {n_features} features into a square matrix "
        "for visualization"
    )

# Use the first sample for the side-by-side FC visualization.
original = X[0].reshape(fc_side, fc_side)
recon = reconstructed[0].reshape(fc_side, fc_side)
# NOTE(review): the integer 1 is passed where the data array normally goes,
# together with the first sample's demographics -- presumably this asks the
# model to generate one new sample; confirm against DemoVAE.transform.
generated = vae.transform(
    1, [d[:1] for d in demo_data], demo_types
)[0].reshape(fc_side, fc_side)

# Create FC visualization
print("Generating FC matrix visualization...")
fc_fig = plot_fc_matrices(original, recon, generated)
fc_fig.savefig('results/fc_visualization_small.png')
print("FC visualization saved to results/fc_visualization_small.png")
print("Test with small sample size completed successfully!")