# AphasiaPred / visualize_fc.py
# SreekarB's picture
# Upload 6 files
# e81f968 verified
#!/usr/bin/env python
"""
Standalone script to visualize FC matrices using the VAE.
"""
import os
import sys
import numpy as np
# Configure matplotlib for headless environment
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
from main import run_fc_analysis
from config import PREDICTION_CONFIG
def _write_dummy_demographics(path, n_subjects=30):
    """Write a placeholder demographics CSV to *path*.

    The file has the header ``ID,age_at_stroke,sex,months_post_stroke,wab_score``
    followed by one row per subject (IDs ``P01``, ``P02``, ...). Every value is a
    deterministic function of the subject index, i.e. dummy data only.
    """
    lines = ["ID,age_at_stroke,sex,months_post_stroke,wab_score\n"]
    lines.extend(
        f"P{i:02d},{65 + i % 10},{['M', 'F'][i % 2]},{12 + i % 24},{50 + i % 30}\n"
        for i in range(1, n_subjects + 1)
    )
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def _remove_temp_file(path):
    """Delete the temporary file at *path*; a failed cleanup only prints a warning."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        print(f"Warning: could not remove temporary file {path}: {err}", file=sys.stderr)


def _report_reconstruction(results):
    """Print the first subject's reconstruction MSE and save both FC matrices.

    *results* is the second value returned by ``run_fc_analysis(..., return_data=True)``.
    It is read with ``.get``, so a dict-like object is assumed. Nothing happens unless
    it is truthy and has non-None ``'X'`` and ``'reconstructed_fc'`` entries.
    Writes ``original_fc.npy`` and ``reconstructed_fc.npy`` to the working directory.

    Raises:
        ValueError: if the original and reconstructed matrices differ in shape,
            which would otherwise let numpy broadcasting produce a meaningless MSE.
    """
    if not results:
        return
    X = results.get('X')
    reconstructed_fc = results.get('reconstructed_fc')
    if X is None or reconstructed_fc is None:
        return

    from visualization import vector_to_matrix

    original = X[0]
    recon = reconstructed_fc[0]
    # Assumes 1-D entries are flattened FC vectors that vector_to_matrix expands
    # to square matrices -- TODO confirm. Each side is converted based on its own
    # rank, so a vector/matrix mix still ends up comparable.
    if np.ndim(original) == 1:
        original = vector_to_matrix(original)
    if np.ndim(recon) == 1:
        recon = vector_to_matrix(recon)

    if np.shape(original) != np.shape(recon):
        raise ValueError(
            f"Shape mismatch between original {np.shape(original)} "
            f"and reconstructed {np.shape(recon)} FC"
        )

    mse = np.mean((original - recon) ** 2)
    print(f"Reconstruction MSE: {mse:.6f}")

    np.save("original_fc.npy", original)
    np.save("reconstructed_fc.npy", recon)
    print("Saved matrices to original_fc.npy and reconstructed_fc.npy")


def main():
    """Run the FC analysis pipeline and save its figure and reconstruction outputs.

    A local directory is used when one exists at ``data_dir``; otherwise the value
    is treated as a Hugging Face dataset id. The synthetic-data flags are switched
    on in the shared PREDICTION_CONFIG. A throwaway demographics CSV is generated
    for run_fc_analysis and deleted as soon as that call finishes.

    Outputs (working directory): ``fc_visualization.png`` and, when the returned
    results carry data, ``original_fc.npy`` / ``reconstructed_fc.npy``.
    Exits with status 1 after printing the error and traceback if any step fails.
    """
    # Configuration
    data_dir = "SreekarB/OSFData1"  # HuggingFace dataset id (or a local dir of that name)
    latent_dim = 16
    nepochs = 50
    batch_size = 4

    # Prefer local data when a directory with this name exists.
    use_hf_dataset = not os.path.isdir(data_dir)
    if use_hf_dataset:
        print(f"Using HuggingFace dataset: {data_dir}")
    else:
        print(f"Using local directory: {data_dir}")

    print("Running FC visualization with:")
    print(f"- Data source: {data_dir}")
    print(f"- Latent dimension: {latent_dim}")
    print(f"- Training epochs: {nepochs}")
    print(f"- Batch size: {batch_size}")
    print(f"- Using HuggingFace API: {use_hf_dataset}")

    demo_file = "temp_demographics.csv"
    try:
        # Allow synthetic NIfTI/FC data. This mutates the module-level
        # PREDICTION_CONFIG dict imported from config.
        PREDICTION_CONFIG['use_synthetic_nifti'] = True
        PREDICTION_CONFIG['use_synthetic_fc'] = True
        print("Enabled synthetic data generation")

        try:
            # run_fc_analysis takes a demographic_file argument, so always
            # supply a placeholder one.
            _write_dummy_demographics(demo_file)
            print(f"Created temporary demographic file: {demo_file}")
            fig, results = run_fc_analysis(
                data_dir=data_dir,
                demographic_file=demo_file,
                latent_dim=latent_dim,
                nepochs=nepochs,
                bsize=batch_size,
                save_model=True,
                use_hf_dataset=use_hf_dataset,
                return_data=True,
            )
        finally:
            # The CSV is only an input to run_fc_analysis; don't leave it behind,
            # even if writing it or the analysis itself failed.
            _remove_temp_file(demo_file)

        if fig is None:
            raise RuntimeError("run_fc_analysis did not return a figure")
        output_file = "fc_visualization.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved visualization to {output_file}")

        _report_reconstruction(results)
    except Exception as e:
        # Top-level script boundary: report on stderr (next to the traceback) and exit non-zero.
        import traceback
        print(f"Error during visualization: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)
    print("Visualization complete!")
if __name__ == "__main__":
    # Execute the pipeline only when run as a script, never on import.
    main()