Spaces:
Sleeping
Sleeping
| """ | |
| Script to visualize FC matrices from HuggingFace dataset, comparing original FC to VAE-generated FC. | |
| """ | |
| import os | |
| import numpy as np | |
| # Configure matplotlib for headless environment | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend | |
| import matplotlib.pyplot as plt | |
| from datasets import load_dataset | |
| from fc_visualization import FCVisualizer | |
| from pathlib import Path | |
| import tempfile | |
| import requests | |
| from config import DATASET_CONFIG, PREPROCESS_CONFIG, MODEL_CONFIG | |
| from data_preprocessing import process_single_fmri | |
| from vae_model import VariationalAutoencoder | |
def download_sample_fmri(dataset, temp_dir, max_samples=5, timeout=60.0):
    """
    Download sample fMRI files from HuggingFace dataset.

    The first few samples are scanned for string fields ending in ``.nii`` or
    ``.nii.gz``; every such field name is treated as a NIfTI column. For the
    first ``max_samples`` rows, each NIfTI column is downloaded into
    ``temp_dir`` and paired with that row's demographic values.

    Args:
        dataset: HuggingFace dataset object (iterable, indexable, sized;
            each row is used as a mapping)
        temp_dir: Directory to save downloaded files
        max_samples: Maximum number of samples to download
        timeout: Seconds to wait for the server before giving up on a
            download (passed to ``requests.get``)

    Returns:
        Tuple ``(nifti_files, demo_data, nifti_keys)``: the paths of the
        downloaded files, one ``[age, sex, months_post_onset, wab_aq]`` list
        per downloaded file (same order as ``nifti_files``), and the names of
        the dataset fields that hold NIfTI files. All three are empty lists
        when no NIfTI field is found.
    """
    nifti_suffixes = ('.nii', '.nii.gz')

    # Look through the first 5 samples to find which fields hold NIfTI files
    nifti_keys = []
    for i, sample in enumerate(dataset):
        if i >= 5:
            break
        for key, value in sample.items():
            is_nifti = isinstance(value, str) and value.endswith(nifti_suffixes)
            if is_nifti and key not in nifti_keys:
                nifti_keys.append(key)

    print(f"Found {len(nifti_keys)} NIfTI file types in the dataset: {nifti_keys}")
    if not nifti_keys:
        print("No NIfTI files found in the dataset")
        return [], [], []

    # Collect nifti files and demographics
    nifti_files = []
    demo_data = []

    # Process a limited number of samples
    num_samples = min(max_samples, len(dataset))
    for sample_idx in range(num_samples):
        sample = dataset[sample_idx]
        for key in nifti_keys:
            try:
                file_url = sample[key]
                if not file_url or not isinstance(file_url, str):
                    continue
                print(f"Processing sample {sample_idx+1}, file: {key}")

                # Download and save the file
                local_file = os.path.join(temp_dir, f"sample_{sample_idx}_{key}.nii.gz")
                print(f"Downloading {file_url} to {local_file}")
                response = requests.get(file_url, timeout=timeout)
                # Without this check an HTTP error page would be written to
                # disk and later handed to the fMRI pipeline as a NIfTI file.
                response.raise_for_status()
                with open(local_file, 'wb') as f:
                    f.write(response.content)

                # Extract demo data if available (or use placeholders)
                demo_sample = [
                    sample.get('age', 65.0),
                    sample.get('sex', 'M'),
                    sample.get('months_post_onset', 12.0),
                    sample.get('wab_aq', 50.0),
                ]
            except Exception as e:
                # Best effort: report the failure and move on to the next file
                print(f"Error processing sample {sample_idx}, {key}: {e}")
                continue

            # Append both together so the two lists always stay aligned
            nifti_files.append(local_file)
            demo_data.append(demo_sample)

    return nifti_files, demo_data, nifti_keys
class VariationalAutoencoder:
    """
    Simplified VAE implementation for the visualization script.

    The encoder maps an input vector to the mean and log-variance of a
    latent Gaussian; the decoder reconstructs the input from a latent sample
    concatenated with an encoded demographic vector (continuous variables
    standardised, categorical variables one-hot encoded).

    NOTE: this class has the same name as the ``VariationalAutoencoder``
    imported from ``vae_model`` at the top of this file. Because it is
    defined after that import, this definition is the one the script uses.
    """

    def __init__(self, n_features, latent_dim, demo_data, demo_types, **kwargs):
        """
        Initialize the VAE.

        Args:
            n_features: Number of input features
            latent_dim: Dimension of latent space
            demo_data: Demographic data, one sequence per variable
            demo_types: Types of demographic variables
                ('continuous' or 'categorical'), parallel to ``demo_data``
            **kwargs: Additional parameters: ``nepochs`` (default 100),
                ``bsize`` (batch size, default 8), ``lr`` (default 1e-3)
        """
        import torch
        import torch.nn as nn

        self.n_features = n_features
        self.latent_dim = latent_dim
        self.demo_dim = self._calculate_demo_dim(demo_data, demo_types)
        self.nepochs = kwargs.get('nepochs', 100)
        self.batch_size = kwargs.get('bsize', 8)
        self.learning_rate = kwargs.get('lr', 1e-3)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Build VAE model
        self.encoder = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, latent_dim * 2)  # mu and logvar
        ).to(self.device)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + self.demo_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, n_features)
        ).to(self.device)
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=self.learning_rate
        )
        # Filled in by the first call to _prepare_demographics (normally
        # during fit) and reused unchanged by every later call.
        self.demo_stats = None

    def _calculate_demo_dim(self, demo_data, demo_types):
        """Calculate dimension of demographic data after one-hot encoding.

        Each continuous variable contributes one column; each categorical
        variable contributes one column per distinct value.
        """
        demo_dim = 0
        for d, t in zip(demo_data, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                demo_dim += len(set(d))
        return demo_dim

    def _encode(self, x):
        """Encode input data to latent space.

        Returns:
            Tuple ``(mu, logvar)``: the two halves of the encoder output.
        """
        import torch

        # as_tensor avoids re-copying (and the copy-construct warning) when
        # x is already a tensor, as it is for every training batch.
        x_tensor = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        h = self.encoder(x_tensor)
        mu, logvar = h[:, :self.latent_dim], h[:, self.latent_dim:]
        return mu, logvar

    def _reparameterize(self, mu, logvar):
        """Reparameterization trick for sampling from latent space."""
        import torch

        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z, demo):
        """Decode latent representation back to input space."""
        import torch

        # Concatenate latent code with demographic data
        z_concat = torch.cat([z, demo], dim=1)
        return self.decoder(z_concat)

    def _prepare_demographics(self, demo_data, demo_types):
        """Convert demographics to tensor with one-hot encoding for categorical variables.

        On the first call the per-variable statistics (mean/std for
        continuous variables, sorted distinct values for categorical ones)
        are computed from ``demo_data`` and stored in ``self.demo_stats``;
        later calls reuse the stored statistics and ignore ``demo_types``.

        Args:
            demo_data: One sequence of values per demographic variable
            demo_types: 'continuous' or 'categorical' per variable

        Returns:
            Float tensor on ``self.device`` with one row per sample.
        """
        import torch
        import numpy as np

        if self.demo_stats is None:
            # First time - compute stats
            self.demo_stats = []
            for d, t in zip(demo_data, demo_types):
                if t == 'continuous':
                    # Cast to float so numeric values that arrive as strings
                    # (e.g. after a round trip through a mixed-type numpy
                    # array) can still be standardised.
                    values = np.asarray(d, dtype=np.float64)
                    self.demo_stats.append(
                        ('continuous', (np.mean(values), np.std(values)))
                    )
                elif t == 'categorical':
                    # Record unique values for one-hot encoding
                    self.demo_stats.append(('categorical', sorted(set(d))))

        # Process demographics based on saved stats
        demo_tensors = []
        for d, (dtype, stats) in zip(demo_data, self.demo_stats):
            if dtype == 'continuous':
                mean, std = stats
                # Standardize; the epsilon guards against a zero std
                standardized = (np.asarray(d, dtype=np.float64) - mean) / (std + 1e-10)
                demo_tensors.append(
                    torch.tensor(standardized, dtype=torch.float32).reshape(-1, 1)
                )
            else:  # categorical
                column_of = {value: col for col, value in enumerate(stats)}
                one_hot = torch.zeros(len(d), len(stats), dtype=torch.float32)
                for row, value in enumerate(d):
                    col = column_of.get(value)
                    # Unseen categories keep an all-zero row
                    if col is not None:
                        one_hot[row, col] = 1.0
                demo_tensors.append(one_hot)

        # Concatenate all demographic features
        return torch.cat(demo_tensors, dim=1).to(self.device)

    def fit(self, X, demo_data, demo_types):
        """
        Train the VAE model.

        The loss is the mean-squared reconstruction error plus 0.1 times the
        KL divergence summed over the batch and divided by the batch size.

        Args:
            X: Input data (FC matrices), one row per sample
            demo_data: List of demographic variables
            demo_types: Types of demographic variables

        Returns:
            List with the average loss of each epoch.

        Raises:
            ValueError: If fewer than two samples are given or the batch
                size is below two (BatchNorm1d cannot normalise a batch of
                one sample in training mode).
        """
        import torch
        import torch.nn.functional as F
        import numpy as np
        from torch.utils.data import DataLoader, TensorDataset

        n_samples = len(X)
        if n_samples < 2:
            raise ValueError("At least two samples are required to train the VAE")
        if self.batch_size < 2:
            raise ValueError("Batch size must be at least 2 to train the VAE")

        print(f"Training VAE on {n_samples} samples for {self.nepochs} epochs...")

        # Prepare demographic data
        demo_tensor = self._prepare_demographics(demo_data, demo_types)

        # Convert input data to tensor
        X_tensor = torch.as_tensor(X, dtype=torch.float32, device=self.device)

        # Create dataset and dataloader. A trailing batch of one sample
        # would make BatchNorm1d raise, so it is dropped; shuffling means a
        # different sample is left out each epoch.
        dataset = TensorDataset(X_tensor, demo_tensor)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=(n_samples % self.batch_size == 1),
        )

        # reconstruct()/generate() leave the networks in eval mode; make
        # sure training always runs in train mode.
        self.encoder.train()
        self.decoder.train()

        # Training loop
        self.train_losses = []
        for epoch in range(self.nepochs):
            epoch_losses = []
            for batch_x, batch_demo in dataloader:
                # Forward pass
                mu, logvar = self._encode(batch_x)
                z = self._reparameterize(mu, logvar)
                x_recon = self._decode(z, batch_demo)

                # Compute loss
                recon_loss = F.mse_loss(x_recon, batch_x)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                kl_loss = kl_loss / batch_x.size(0)  # Normalize by batch size

                # Total loss
                loss = recon_loss + 0.1 * kl_loss

                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                epoch_losses.append(loss.item())

            # Record average loss for this epoch
            avg_loss = float(np.mean(epoch_losses))
            self.train_losses.append(avg_loss)

            # Print progress every 10 epochs
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{self.nepochs}, Loss: {avg_loss:.6f}")

        print("VAE training complete!")
        return self.train_losses

    def reconstruct(self, X, demo_data=None, demo_types=None):
        """
        Reconstruct input data.

        The latent mean (not a random sample) is decoded together with the
        encoded demographics.

        Args:
            X: Input data
            demo_data: Demographic data (required despite the default)
            demo_types: Types of demographic variables (required despite
                the default)

        Returns:
            Reconstructed data as a numpy array

        Raises:
            ValueError: If ``demo_data`` or ``demo_types`` is missing.
        """
        import torch

        if demo_data is None or demo_types is None:
            raise ValueError("Demo data and types must be provided for reconstruction")

        # Set to evaluation mode
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            # Encode to latent space
            mu, _ = self._encode(X)
            demo_tensor = self._prepare_demographics(demo_data, demo_types)
            # Decode
            recon = self._decode(mu, demo_tensor)

        # Convert to numpy
        return recon.cpu().numpy()

    def generate(self, n_samples, demo_data, demo_types):
        """
        Generate new samples from the latent space.

        Args:
            n_samples: Number of samples to generate
            demo_data: Demographic data; if it does not contain exactly
                ``n_samples`` rows, the first row is repeated for every
                generated sample
            demo_types: Types of demographic variables

        Returns:
            Generated samples as a numpy array

        Raises:
            ValueError: If ``demo_data`` contains no rows while
                ``n_samples`` is non-zero.
        """
        import torch

        # Set to evaluation mode
        self.decoder.eval()
        with torch.no_grad():
            # Sample from standard normal
            z = torch.randn(n_samples, self.latent_dim, device=self.device)

            # Prepare demographic data
            demo_tensor = self._prepare_demographics(demo_data, demo_types)

            # Check dimensions
            if demo_tensor.shape[0] != n_samples:
                if demo_tensor.shape[0] < 1:
                    raise ValueError(
                        "Demographic data must contain at least one sample"
                    )
                # Handle mismatch - repeat the first demographic sample
                demo_tensor = demo_tensor[0].unsqueeze(0).repeat(n_samples, 1)

            # Generate samples
            generated = self._decode(z, demo_tensor)

        # Convert to numpy
        return generated.cpu().numpy()
def generate_comparison():
    """Download, process and visualize FC matrices from the HuggingFace dataset,
    comparing original to VAE-generated matrices.

    Steps: load the dataset named in ``DATASET_CONFIG``, download up to five
    samples' NIfTI files into a temporary directory, convert each file with
    ``process_single_fmri``, z-score the resulting vectors per feature, train
    a ``VariationalAutoencoder`` on them, then save comparison figures
    (``.png``) and matrices (``.npy``) to the current working directory.

    Returns:
        None. The function prints a message and returns early when no files
        can be downloaded or fewer than two FC matrices are produced.
    """
    print("Loading dataset from HuggingFace...")

    # Load the HuggingFace dataset using config
    dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData1')
    dataset_split = DATASET_CONFIG.get('split', 'train')
    dataset = load_dataset(dataset_name, split=dataset_split)
    print(f"Dataset loaded: {dataset}")

    fc_matrices = []
    demo_rows = []

    # The downloaded NIfTI files are only needed until they have been turned
    # into FC vectors, so the directory is removed as soon as that is done.
    with tempfile.TemporaryDirectory(prefix="hf_nifti_") as temp_dir:
        print(f"Created temp directory for NIfTI files: {temp_dir}")

        # Download and process fMRI files
        nifti_files, demo_samples, _nifti_keys = download_sample_fmri(
            dataset, temp_dir, max_samples=5
        )
        if not nifti_files:
            print("No valid fMRI files were found")
            return

        # Process all fMRI files to FC matrices
        for file_idx, (file_path, demo_sample) in enumerate(zip(nifti_files, demo_samples)):
            try:
                print(f"Processing file {file_idx+1}/{len(nifti_files)}: {file_path}")
                fc_triu = process_single_fmri(file_path, allow_synthetic=False)
                fc_matrices.append(fc_triu)
                demo_rows.append(demo_sample)
            except Exception as e:
                print(f"Error processing file {file_path}: {e}")

    if not fc_matrices:
        print("No valid FC matrices were generated")
        return
    if len(fc_matrices) < 2:
        # The VAE uses BatchNorm1d, which cannot train on a single sample
        print("At least two FC matrices are required to train the VAE")
        return

    # Convert to numpy arrays
    X = np.array(fc_matrices)

    # Normalize the data; constant features keep a divisor of 1 so they
    # become 0 instead of NaN/inf.
    feature_std = np.std(X, axis=0)
    feature_std[feature_std == 0] = 1.0
    X = (X - np.mean(X, axis=0)) / feature_std

    # Prepare demographic data
    # Transpose to get [feature_type][sample] format. zip keeps each value's
    # own type; going through np.array would turn the mixed numeric/string
    # rows into all-string columns.
    demo_data = [list(column) for column in zip(*demo_rows)]
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    # Train a VAE on the FC matrices
    print("Training VAE on the FC matrices...")
    n_features = X.shape[1]
    n_samples = X.shape[0]

    # Smallest batch size >= 2 that never leaves a trailing batch of one
    # sample (which BatchNorm1d rejects in training mode). b == n_samples
    # always qualifies, so the search terminates.
    batch_size = next(b for b in range(2, n_samples + 1) if n_samples % b != 1)

    # Configure a smaller/faster VAE for demonstration
    vae = VariationalAutoencoder(
        n_features=n_features,
        latent_dim=MODEL_CONFIG.get('latent_dim', 32),
        demo_data=demo_data,
        demo_types=demo_types,
        nepochs=100,  # Reduced for demo
        bsize=batch_size,
        lr=1e-3
    )

    # Train the VAE
    vae.fit(X, demo_data, demo_types)

    # Generate reconstructed FC matrices
    print("Generating reconstructed FC matrices...")
    reconstructed = vae.reconstruct(X, demo_data, demo_types)

    # Generate a synthetic FC matrix
    print("Generating a synthetic FC matrix...")
    # For generating a new sample, we'll use demographics from first patient
    first_demo_data = [[d[0]] for d in demo_data]
    generated = vae.generate(1, first_demo_data, demo_types)

    # Visualize original, reconstructed, and generated FC matrices
    visualizer = FCVisualizer()
    gen_matrix = visualizer._triu_to_matrix(generated[0])

    # Process each sample to generate comparisons
    for i in range(min(3, len(X))):
        # Convert upper triangular vectors to full matrices for visualization
        original_matrix = visualizer._triu_to_matrix(X[i])
        recon_matrix = visualizer._triu_to_matrix(reconstructed[i])

        # The three-way comparison is produced once, for the first sample
        if i == 0:
            fig = visualizer.plot_matrix_comparison(
                [original_matrix, recon_matrix, gen_matrix],
                titles=["Original FC", "Reconstructed FC", "Generated FC"]
            )
            output_file = "fc_comparison_with_generated.png"
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close(fig)  # release the figure once it is on disk
            print(f"Saved full comparison to {output_file}")

        # Visualize original vs reconstructed for each sample
        fig = visualizer.plot_matrix_comparison(
            [original_matrix, recon_matrix],
            titles=[f"Original FC (Sample {i+1})", f"Reconstructed FC (Sample {i+1})"]
        )
        output_file = f"sample_{i}_original_vs_reconstructed.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved comparison to {output_file}")

        # Save the matrices
        np.save(f"sample_{i}_original_fc.npy", original_matrix)
        np.save(f"sample_{i}_reconstructed_fc.npy", recon_matrix)

    # Save the generated matrix
    np.save("generated_fc.npy", gen_matrix)
    print("Processing complete")
# Script entry point: run the full download / train / visualize pipeline only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    generate_comparison()