# AphasiaPred / huggingface_fc_visualization.py
# Source: SreekarB's HuggingFace repo ("Upload 6 files", commit e81f968, verified)
"""
Script to visualize FC matrices from HuggingFace dataset, comparing original FC to VAE-generated FC.
"""
import os
import numpy as np
# Configure matplotlib for headless environment
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
from datasets import load_dataset
from fc_visualization import FCVisualizer
from pathlib import Path
import tempfile
import requests
from config import DATASET_CONFIG, PREPROCESS_CONFIG, MODEL_CONFIG
from data_preprocessing import process_single_fmri
from vae_model import VariationalAutoencoder
def download_sample_fmri(dataset, temp_dir, max_samples=5):
    """
    Download sample fMRI files from a HuggingFace dataset.

    Scans the first few samples for string fields ending in ``.nii`` /
    ``.nii.gz``, downloads each referenced file into ``temp_dir``, and
    collects per-sample demographics (placeholder defaults when a field
    is missing).

    Args:
        dataset: HuggingFace dataset object (iterable/indexable of dict-like samples)
        temp_dir: Directory to save downloaded files
        max_samples: Maximum number of samples to download

    Returns:
        tuple: (list of local file paths,
                list of [age, sex, months_post_onset, wab_aq] rows,
                list of dataset keys that held NIfTI paths)
    """
    # Discover which sample keys hold NIfTI file references by inspecting
    # only the first few samples (assumes a consistent dataset schema).
    nifti_keys = []
    for i, sample in enumerate(dataset):
        if i >= 5:  # Check first 5 samples
            break
        for key, value in sample.items():
            # endswith() with a tuple replaces the original two-clause `or`.
            if isinstance(value, str) and value.endswith(('.nii', '.nii.gz')):
                if key not in nifti_keys:
                    nifti_keys.append(key)
    print(f"Found {len(nifti_keys)} NIfTI file types in the dataset: {nifti_keys}")
    if not nifti_keys:
        print("No NIfTI files found in the dataset")
        return [], [], []

    # Collect nifti files and demographics for a limited number of samples.
    nifti_files = []
    demo_data = []
    num_samples = min(max_samples, len(dataset))
    for sample_idx in range(num_samples):
        sample = dataset[sample_idx]
        for key in nifti_keys:
            try:
                file_url = sample[key]
                if not file_url or not isinstance(file_url, str):
                    continue
                print(f"Processing sample {sample_idx+1}, file: {key}")
                # Download and save the file.
                local_file = os.path.join(temp_dir, f"sample_{sample_idx}_{key}.nii.gz")
                print(f"Downloading {file_url} to {local_file}")
                response = requests.get(file_url)
                # Fail fast on HTTP errors instead of silently writing an
                # error page to disk as if it were NIfTI data; the except
                # below reports it like any other per-file failure.
                response.raise_for_status()
                with open(local_file, 'wb') as f:
                    f.write(response.content)
                nifti_files.append(local_file)
                # Demographics with placeholder defaults. dict.get(key, default)
                # already covers the missing-key case; the original
                # `... if key in sample else default` was redundant.
                demo_data.append([
                    sample.get('age', 65.0),
                    sample.get('sex', 'M'),
                    sample.get('months_post_onset', 12.0),
                    sample.get('wab_aq', 50.0),
                ])
            except Exception as e:
                print(f"Error processing sample {sample_idx}, {key}: {e}")
    return nifti_files, demo_data, nifti_keys
class VariationalAutoencoder:
    """
    Simplified conditional VAE for FC (functional connectivity) vectors.

    NOTE(review): this class deliberately shadows the
    ``VariationalAutoencoder`` imported from ``vae_model`` at the top of
    this file, so the rest of the script uses this local implementation.

    The encoder maps an FC vector to a Gaussian posterior in latent space;
    the decoder reconstructs the vector from a latent sample concatenated
    with standardized / one-hot-encoded demographics.
    """

    def __init__(self, n_features, latent_dim, demo_data, demo_types, **kwargs):
        """
        Initialize the VAE.

        Args:
            n_features: Number of input features (length of the FC vector)
            latent_dim: Dimension of latent space
            demo_data: Demographic data as a list of per-variable columns,
                e.g. ``[ages, sexes, ...]`` (used here only to size the decoder)
            demo_types: Parallel list of 'continuous' / 'categorical' labels
            **kwargs: Optional hyperparameters:
                nepochs (int): training epochs, default 100
                bsize (int): batch size, default 8
                lr (float): learning rate, default 1e-3
        """
        import torch
        import torch.nn as nn
        self.n_features = n_features
        self.latent_dim = latent_dim
        # Decoder input width depends on the one-hot-expanded demographic
        # dimension, so it must be computed before building the decoder.
        self.demo_dim = self._calculate_demo_dim(demo_data, demo_types)
        self.nepochs = kwargs.get('nepochs', 100)
        self.batch_size = kwargs.get('bsize', 8)
        self.learning_rate = kwargs.get('lr', 1e-3)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Encoder outputs mu and logvar concatenated (hence latent_dim * 2).
        self.encoder = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, latent_dim * 2)  # mu and logvar
        ).to(self.device)

        # Decoder is conditioned on demographics (concatenated to z).
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + self.demo_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, n_features)
        ).to(self.device)

        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=self.learning_rate
        )
        # Per-variable normalization / category stats; populated on the first
        # call to _prepare_demographics (i.e. during training).
        self.demo_stats = None

    def _calculate_demo_dim(self, demo_data, demo_types):
        """Return the demographic width after one-hot expansion.

        Continuous variables contribute one column; categorical variables
        contribute one column per distinct value.
        """
        demo_dim = 0
        for d, t in zip(demo_data, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                # len(set(d)) works for both string and numeric categories;
                # the original had two identical branches here.
                demo_dim += len(set(d))
        return demo_dim

    def _encode(self, x):
        """Encode input data to (mu, logvar) of the latent posterior."""
        import torch
        # as_tensor avoids the copy-construct warning when x is already a
        # tensor (e.g. when called from fit()) while still accepting numpy.
        x_tensor = torch.as_tensor(x, dtype=torch.float32).to(self.device)
        h = self.encoder(x_tensor)
        mu, logvar = h[:, :self.latent_dim], h[:, self.latent_dim:]
        return mu, logvar

    def _reparameterize(self, mu, logvar):
        """Reparameterization trick: sample z = mu + eps * std."""
        import torch
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z, demo):
        """Decode a latent code conditioned on demographics."""
        import torch
        # Concatenate latent code with demographic data along features.
        z_concat = torch.cat([z, demo], dim=1)
        return self.decoder(z_concat)

    def _prepare_demographics(self, demo_data, demo_types):
        """Convert demographic columns to one float tensor.

        Continuous variables are standardized with stats recorded on the
        first call; categorical variables are one-hot encoded against the
        value set recorded on the first call (unseen values map to all
        zeros).
        """
        import torch
        import numpy as np
        if self.demo_stats is None:
            # First call (training): record the stats so later calls encode
            # new data consistently with what the model was trained on.
            self.demo_stats = []
            for d, t in zip(demo_data, demo_types):
                if t == 'continuous':
                    self.demo_stats.append(('continuous', (np.mean(d), np.std(d))))
                elif t == 'categorical':
                    # sorted(set(...)) handles both string and numeric
                    # categories; the original had two identical branches.
                    self.demo_stats.append(('categorical', sorted(set(d))))

        demo_tensors = []
        for (d, (dtype, stats)) in zip(demo_data, self.demo_stats):
            if dtype == 'continuous':
                mean, std = stats
                # Epsilon guards division by zero for constant columns.
                standardized = (np.array(d) - mean) / (std + 1e-10)
                demo_tensors.append(torch.tensor(standardized, dtype=torch.float32).reshape(-1, 1))
            else:  # categorical
                unique_values = stats
                one_hot_vectors = []
                for val in d:
                    vec = [0.0] * len(unique_values)
                    try:
                        vec[unique_values.index(val)] = 1.0
                    except ValueError:
                        # Unseen category: keep the all-zero vector.
                        pass
                    one_hot_vectors.append(vec)
                demo_tensors.append(torch.tensor(one_hot_vectors, dtype=torch.float32))
        # Concatenate all demographic features column-wise.
        return torch.cat(demo_tensors, dim=1).to(self.device)

    def fit(self, X, demo_data, demo_types):
        """
        Train the VAE model.

        Args:
            X: Input data (FC vectors), shape (n_samples, n_features)
            demo_data: List of demographic columns (one list per variable)
            demo_types: Parallel list of 'continuous' / 'categorical'

        Returns:
            list of average per-epoch training losses
        """
        import torch
        import torch.nn.functional as F
        import numpy as np
        from torch.utils.data import DataLoader, TensorDataset

        print(f"Training VAE on {len(X)} samples for {self.nepochs} epochs...")
        demo_tensor = self._prepare_demographics(demo_data, demo_types)
        X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
        dataset = TensorDataset(X_tensor, demo_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        # Ensure training mode: reconstruct()/generate() switch the networks
        # to eval() and the original never switched them back, so a second
        # fit() would silently train with frozen BatchNorm statistics.
        self.encoder.train()
        self.decoder.train()

        self.train_losses = []
        for epoch in range(self.nepochs):
            epoch_losses = []
            for batch_x, batch_demo in dataloader:
                # BatchNorm1d cannot compute batch statistics from a single
                # sample in training mode (it raises); skip a degenerate
                # trailing batch rather than crashing mid-training.
                if batch_x.size(0) < 2:
                    continue
                # Forward pass
                mu, logvar = self._encode(batch_x)
                z = self._reparameterize(mu, logvar)
                x_recon = self._decode(z, batch_demo)
                # Reconstruction + KL divergence (beta = 0.1 weighting)
                recon_loss = F.mse_loss(x_recon, batch_x)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                kl_loss = kl_loss / batch_x.size(0)  # normalize by batch size
                loss = recon_loss + 0.1 * kl_loss
                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                epoch_losses.append(loss.item())

            # Average loss for this epoch (NaN if every batch was skipped,
            # instead of np.mean's empty-slice warning/NaN surprise).
            avg_loss = float(np.mean(epoch_losses)) if epoch_losses else float('nan')
            self.train_losses.append(avg_loss)
            # Print progress every 10 epochs
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{self.nepochs}, Loss: {avg_loss:.6f}")
        print("VAE training complete!")
        return self.train_losses

    def reconstruct(self, X, demo_data=None, demo_types=None):
        """
        Reconstruct input data through the encoder/decoder, using the
        posterior mean (no sampling noise).

        Args:
            X: Input data, shape (n_samples, n_features)
            demo_data: Demographic columns matching X (required)
            demo_types: Parallel list of variable types (required)

        Returns:
            numpy array of reconstructions, shape (n_samples, n_features)

        Raises:
            ValueError: if demo_data/demo_types are not provided
        """
        import torch
        # Evaluation mode: use BatchNorm running statistics.
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            mu, _ = self._encode(X)
            if demo_data is not None and demo_types is not None:
                demo_tensor = self._prepare_demographics(demo_data, demo_types)
            else:
                # Demographics are part of the decoder input, so they
                # cannot be omitted.
                raise ValueError("Demo data and types must be provided for reconstruction")
            recon = self._decode(mu, demo_tensor)
        return recon.cpu().numpy()

    def generate(self, n_samples, demo_data, demo_types):
        """
        Generate new samples by decoding standard-normal latent draws.

        Args:
            n_samples: Number of samples to generate
            demo_data: Demographic columns to condition on
            demo_types: Parallel list of variable types

        Returns:
            numpy array of generated samples, shape (n_samples, n_features)
        """
        import torch
        self.decoder.eval()
        with torch.no_grad():
            # Sample latent codes from the standard normal prior.
            z = torch.randn(n_samples, self.latent_dim).to(self.device)
            demo_tensor = self._prepare_demographics(demo_data, demo_types)
            # If the number of demographic rows doesn't match, tile the
            # first row so every generated sample shares its conditioning.
            if demo_tensor.shape[0] != n_samples:
                if demo_tensor.shape[0] >= 1:
                    demo_tensor = demo_tensor[0].unsqueeze(0).repeat(n_samples, 1)
            generated = self._decode(z, demo_tensor)
        return generated.cpu().numpy()
def generate_comparison():
    """Download, process and visualize FC matrices from the HuggingFace
    dataset, comparing original, VAE-reconstructed and VAE-generated
    matrices. Saves PNG comparisons and .npy matrices to the working
    directory."""
    print("Loading dataset from HuggingFace...")
    # Load the HuggingFace dataset using config (with defaults).
    dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData1')
    dataset_split = DATASET_CONFIG.get('split', 'train')
    dataset = load_dataset(dataset_name, split=dataset_split)
    print(f"Dataset loaded: {dataset}")

    # Temporary directory for downloaded NIfTI files.
    temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
    print(f"Created temp directory for NIfTI files: {temp_dir}")

    nifti_files, demo_samples, nifti_keys = download_sample_fmri(dataset, temp_dir, max_samples=5)
    if not nifti_files:
        print("No valid fMRI files were found")
        return

    # Convert each fMRI volume to an upper-triangular FC vector.
    fc_matrices = []
    demo_data = []
    for file_idx, (file_path, demo_sample) in enumerate(zip(nifti_files, demo_samples)):
        try:
            print(f"Processing file {file_idx+1}/{len(nifti_files)}: {file_path}")
            fc_triu = process_single_fmri(file_path, allow_synthetic=False)
            fc_matrices.append(fc_triu)
            demo_data.append(demo_sample)
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    if not fc_matrices:
        print("No valid FC matrices were generated")
        return

    X = np.array(fc_matrices)
    # Feature-wise standardization; epsilon guards zero-variance features
    # (the original divided by a possibly-zero std).
    std = np.std(X, axis=0)
    X = (X - np.mean(X, axis=0)) / (std + 1e-10)

    # Transpose sample rows into per-variable columns WITHOUT numpy:
    # np.array() on mixed float/str rows (sex is a string) coerces every
    # value to a string, which broke np.mean on the continuous variables
    # inside the VAE. zip(*rows) preserves the original element types.
    demo_data = [list(column) for column in zip(*demo_data)]
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    # Train a VAE on the FC matrices.
    print("Training VAE on the FC matrices...")
    n_features = X.shape[1]
    vae = VariationalAutoencoder(
        n_features=n_features,
        latent_dim=MODEL_CONFIG.get('latent_dim', 32),
        demo_data=demo_data,
        demo_types=demo_types,
        nepochs=100,  # Reduced for demo
        bsize=2,
        lr=1e-3
    )
    vae.fit(X, demo_data, demo_types)

    print("Generating reconstructed FC matrices...")
    reconstructed = vae.reconstruct(X, demo_data, demo_types)

    print("Generating a synthetic FC matrix...")
    # Condition the single synthetic sample on the first patient's demographics.
    first_demo_data = [[d[0]] for d in demo_data]
    generated = vae.generate(1, first_demo_data, demo_types)

    # Visualize original, reconstructed, and generated FC matrices.
    visualizer = FCVisualizer()
    gen_matrix = visualizer._triu_to_matrix(generated[0])
    for i in range(min(3, len(X))):
        # Convert upper-triangular vectors back to full matrices for plotting.
        original_matrix = visualizer._triu_to_matrix(X[i])
        recon_matrix = visualizer._triu_to_matrix(reconstructed[i])

        if i == 0:
            # Three-way comparison is saved once, for the first sample.
            fig = visualizer.plot_matrix_comparison(
                [original_matrix, recon_matrix, gen_matrix],
                titles=["Original FC", "Reconstructed FC", "Generated FC"]
            )
            output_file = "fc_comparison_with_generated.png"  # was a placeholder-less f-string
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close(fig)  # release the figure; headless loops leak otherwise
            print(f"Saved full comparison to {output_file}")

        # Original vs reconstructed for every sample.
        fig = visualizer.plot_matrix_comparison(
            [original_matrix, recon_matrix],
            titles=[f"Original FC (Sample {i+1})", f"Reconstructed FC (Sample {i+1})"]
        )
        output_file = f"sample_{i}_original_vs_reconstructed.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved comparison to {output_file}")

        # Persist the matrices for later analysis.
        np.save(f"sample_{i}_original_fc.npy", original_matrix)
        np.save(f"sample_{i}_reconstructed_fc.npy", recon_matrix)

    # Save the generated matrix once.
    np.save("generated_fc.npy", gen_matrix)
    print("Processing complete")
# Script entry point: run the full download -> FC -> VAE -> visualization
# pipeline only when executed directly (not when imported as a module).
if __name__ == "__main__":
    generate_comparison()