Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Standalone script to visualize FC matrices using the VAE. | |
| """ | |
| import os | |
| import sys | |
| import numpy as np | |
| # Configure matplotlib for headless environment | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend | |
| import matplotlib.pyplot as plt | |
| from main import run_fc_analysis | |
| from config import PREDICTION_CONFIG | |
def _write_dummy_demographics(path, n_subjects=30):
    """Write a placeholder demographics CSV to ``path``.

    The file has the header
    ``ID,age_at_stroke,sex,months_post_stroke,wab_score`` followed by
    ``n_subjects`` deterministic rows with IDs ``P01``, ``P02``, ...
    Any existing file at ``path`` is overwritten.

    Args:
        path: Destination file name.
        n_subjects: Number of data rows to generate.
    """
    lines = ["ID,age_at_stroke,sex,months_post_stroke,wab_score\n"]
    for i in range(1, n_subjects + 1):
        sex = "MF"[i % 2]  # odd ids -> 'F', even ids -> 'M'
        lines.append(
            f"P{i:02d},{65 + i % 10},{sex},{12 + i % 24},{50 + i % 30}\n"
        )
    # The content is pure ASCII; the explicit encoding just keeps the output
    # independent of the platform's default locale.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def _report_reconstruction(results):
    """Print the reconstruction MSE for the first sample and save both arrays.

    Reads ``results['X']`` and ``results['reconstructed_fc']``; if ``results``
    is empty or either entry is missing, nothing is done. Otherwise the first
    element of each is compared and written to ``original_fc.npy`` and
    ``reconstructed_fc.npy`` in the current working directory.

    Args:
        results: Mapping returned by ``run_fc_analysis`` (may be falsy).
    """
    if not results:
        return

    X = results.get('X')
    reconstructed_fc = results.get('reconstructed_fc')
    if X is None or reconstructed_fc is None:
        return

    # Project-local helper; presumably expands a flattened FC vector into a
    # square matrix -- confirm against the visualization module.
    from visualization import vector_to_matrix

    original = X[0]
    recon = reconstructed_fc[0]

    # Convert each array based on its OWN dimensionality. Previously the
    # reconstruction was converted only when the original was 1-D, so a
    # vector/matrix mix led to a wrong conversion or a shape error below.
    if len(original.shape) == 1:
        original = vector_to_matrix(original)
    if len(recon.shape) == 1:
        recon = vector_to_matrix(recon)

    mse = np.mean((original - recon) ** 2)
    print(f"Reconstruction MSE: {mse:.6f}")

    np.save("original_fc.npy", original)
    np.save("reconstructed_fc.npy", recon)
    print("Saved matrices to original_fc.npy and reconstructed_fc.npy")


def main():
    """Run the FC analysis with fixed settings and save the outputs.

    Steps, in order:
      1. Choose the data source: ``data_dir`` is treated as a local directory
         if it exists on disk, otherwise as a HuggingFace dataset id.
      2. Enable the synthetic-data flags in ``PREDICTION_CONFIG``.
      3. Write a placeholder demographics CSV (``temp_demographics.csv``);
         the file is left on disk afterwards.
      4. Call ``run_fc_analysis`` and save the returned figure to
         ``fc_visualization.png``.
      5. If result arrays are returned, print the reconstruction MSE for the
         first sample and save both arrays as ``.npy`` files.

    Any exception raised in steps 2-5 is printed with its traceback and the
    process exits with status 1.
    """
    # Configuration
    data_dir = "SreekarB/OSFData1"  # HuggingFace dataset
    latent_dim = 16
    nepochs = 50
    batch_size = 4

    # A path that exists locally as a directory takes precedence over the
    # HuggingFace lookup.
    use_hf_dataset = not os.path.isdir(data_dir)
    if use_hf_dataset:
        print(f"Using HuggingFace dataset: {data_dir}")
    else:
        print(f"Using local directory: {data_dir}")

    print("Running FC visualization with:")
    print(f"- Data source: {data_dir}")
    print(f"- Latent dimension: {latent_dim}")
    print(f"- Training epochs: {nepochs}")
    print(f"- Batch size: {batch_size}")
    print(f"- Using HuggingFace API: {use_hf_dataset}")

    # Top-level boundary of the script: report any failure and exit non-zero.
    try:
        # Update config to allow synthetic data
        PREDICTION_CONFIG['use_synthetic_nifti'] = True
        PREDICTION_CONFIG['use_synthetic_fc'] = True
        print("Enabled synthetic data generation")

        # The demographics file is always (re)generated, overwriting any
        # previous copy.
        demo_file = "temp_demographics.csv"
        _write_dummy_demographics(demo_file, n_subjects=30)
        print(f"Created temporary demographic file: {demo_file}")

        fig, results = run_fc_analysis(
            data_dir=data_dir,
            demographic_file=demo_file,
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=batch_size,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True,
        )

        # Save the figure
        output_file = "fc_visualization.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved visualization to {output_file}")

        # If results are available, calculate some metrics
        _report_reconstruction(results)
    except Exception as e:
        print(f"Error during visualization: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    print("Visualization complete!")
# Run the pipeline only when this file is executed as a script, so that
# importing the module does not start an analysis run.
if __name__ == "__main__":
    main()