Spaces:
Sleeping
Sleeping
| import os | |
# Point Hugging Face caches at a writable directory next to this script to
# avoid permission issues on hosts where the default home cache is read-only.
# This runs before the project imports below, so any Hugging Face library they
# pull in sees these variables when it is first imported.
_HF_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
# Kept for older transformers releases that still read it (newer ones warn it is deprecated).
os.environ['TRANSFORMERS_CACHE'] = _HF_CACHE_DIR
# HF_HOME is the current, library-wide setting. Respect a value the user already exported.
os.environ.setdefault('HF_HOME', _HF_CACHE_DIR)
os.makedirs(_HF_CACHE_DIR, exist_ok=True)
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from vae_model import DemoVAE | |
| from visualization import plot_learning_curves, plot_fc_matrices | |
| from config import MODEL_CONFIG | |
# Create a small synthetic dataset with only 5 samples.
# Seed NumPy (data) and torch (model initialisation) so repeated runs of this
# smoke test are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

input_dim = 100  # features per sample
n_samples = 5
demo_dim = 4  # number of demographic variables in demo_data
print(f"Creating test dataset with {n_samples} samples...")

# Synthetic stand-in for FC data: i.i.d. standard-normal feature vectors.
# No upper-triangular structure is imposed on these values.
X = np.random.randn(n_samples, input_dim)

# Synthetic demographics: one length-n_samples array per variable
demo_data = [
    np.random.normal(60, 10, n_samples),  # age
    np.random.choice(['M', 'F'], n_samples),  # sex
    np.clip(np.random.normal(24, 12, n_samples), 0, None),  # months post stroke (cannot be negative)
    np.random.normal(50, 15, n_samples),  # WAB score
]
# Types of demographics, aligned index-by-index with demo_data
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

# Guard against the three descriptions drifting apart when variables are edited.
if not (len(demo_data) == len(demo_types) == demo_dim):
    raise ValueError(
        f"demo_data ({len(demo_data)}), demo_types ({len(demo_types)}) and "
        f"demo_dim ({demo_dim}) must all describe the same number of variables"
    )
# Echo the main training hyperparameters, then build the model by passing the
# shared MODEL_CONFIG dict straight through as keyword arguments.
print("Config settings:")
for setting_label, config_key in (
    ("Epochs", "nepochs"),
    ("Batch size", "bsize"),
    ("Latent dim", "latent_dim"),
):
    print(f"- {setting_label}: {MODEL_CONFIG[config_key]}")
print("Initializing model...")
vae = DemoVAE(**MODEL_CONFIG)
# Fit the VAE on the synthetic data. fit() is unpacked into two loss sequences;
# the last entry of each is reported.
print(f"Training model with {n_samples} samples...")
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
last_train_loss = train_losses[-1]
print(f"Training complete! Final train loss: {last_train_loss:.4f}")
last_val_loss = val_losses[-1]
print(f"Final validation loss: {last_val_loss:.4f}")
# Save the trained model and write the learning-curve plot.
# Output directories are relative to the current working directory.
os.makedirs("models", exist_ok=True)
os.makedirs("results", exist_ok=True)
print("Saving model...")
vae.save('models/vae_model_small.pt')

# Create learning curve visualization
print("Generating learning curve visualization...")
learning_curve_path = 'results/learning_curves_small.png'  # single source for save + message
learning_fig = plot_learning_curves(train_losses, val_losses)
learning_fig.savefig(learning_curve_path)
# Free the figure once it is on disk. If plot_learning_curves builds it via
# pyplot (not visible here), pyplot would otherwise keep it alive.
plt.close(learning_fig)
print(f"Learning curve saved to {learning_curve_path}")
# Generate reconstructed data
print("Generating reconstructions...")
reconstructed = vae.transform(X, demo_data, demo_types)

# Lay each input_dim-long feature vector out as a square matrix for display.
# Derive the side length instead of hard-coding 10, and fail clearly otherwise.
fc_side = int(round(np.sqrt(input_dim)))
if fc_side * fc_side != input_dim:
    raise ValueError(
        f"input_dim={input_dim} is not a perfect square; cannot reshape to a square matrix"
    )

# Get a single sample for FC visualization
original = X[0].reshape(fc_side, fc_side)
recon = reconstructed[0].reshape(fc_side, fc_side)
# NOTE(review): passing an int (1) instead of data presumably asks DemoVAE to
# generate that many new samples from the given demographics. Confirm against
# vae_model. d[:1] slices (not indexes) so each demographic stays a length-1 array.
generated = vae.transform(1, [d[:1] for d in demo_data], demo_types)[0].reshape(fc_side, fc_side)
# Create FC visualization: original, reconstructed and generated matrices.
print("Generating FC matrix visualization...")
fc_path = 'results/fc_visualization_small.png'  # single source for save + message
fc_fig = plot_fc_matrices(original, recon, generated)
fc_fig.savefig(fc_path)
# Free the figure once it is on disk (no-op if pyplot does not manage it).
plt.close(fc_fig)
print(f"FC visualization saved to {fc_path}")
print("Test with small sample size completed successfully!")