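"""Data preprocessing for AphasiaPred.

Converts resting-state fMRI volumes (e.g. P01_rs.nii) into vectorized functional
connectivity matrices using 5 mm spheres around the Power 264 atlas coordinates,
and assembles the matching demographic variables (age, gender, months post stroke,
WAB-AQ). Input can be a Hugging Face dataset repository or a list of local NIfTI files.
"""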
import numpy as np
import pandas as pd
from datasets import load_dataset
# nilearn.input_data is deprecated in recent nilearn releases; prefer nilearn.maskers
# when it is available and fall back to the old module otherwise.
try:
    from nilearn import maskers as input_data
except ImportError:
    from nilearn import input_data
from nilearn import connectome
from nilearn.image import load_img
import nibabel as nib
import os
def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
"""
    Process resting-state fMRI data into functional connectivity (FC) matrices.

    Parameters:
    - dataset_or_niifiles: Either a Hugging Face dataset id (string) or a list of NIfTI file paths
    - demo_data: Optional demographic data (list of arrays); required when NIfTI files are passed
    - demo_types: Optional demographic variable types; required when NIfTI files are passed

    Returns:
    - X: Array of vectorized FC matrices (Fisher z-transformed upper triangles), one row per subject
    - demo_data: Demographic data as a list of arrays
    - demo_types: Demographic variable types ('continuous' or 'categorical')
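
    Example (illustrative sketch; the file paths and demographic arrays are placeholders):
        X, demo_data, demo_types = preprocess_fmri_to_fc(
            ["P01_rs.nii", "P02_rs.nii"],
            demo_data=[ages, genders, months_post_stroke, wab_aq_scores],
            demo_types=['continuous', 'categorical', 'continuous', 'continuous'],
        )
        # X.shape == (2, 34716) with the Power 264 atlas (264*263/2 edges per subject)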
"""
print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
# For SreekarB/OSFData dataset, the data will be loaded from dataset features
if isinstance(dataset_or_niifiles, str):
dataset_name = dataset_or_niifiles
print(f"Loading data from dataset: {dataset_name}")
try:
            # Try multiple approaches to load the dataset
            approaches = [
                lambda: load_dataset(dataset_name, split="train"),
                # Without an explicit split this returns a DatasetDict, so take "train" from it
                lambda: load_dataset(dataset_name)["train"],
                lambda: load_dataset(dataset_name, split="train", trust_remote_code=True),
                # Only try the bare repository name when the id includes a namespace
                (lambda: load_dataset(dataset_name.split("/")[-1], split="train")) if "/" in dataset_name else None,
            ]
dataset = None
last_error = None
for i, approach in enumerate(approaches):
if approach is None:
continue
try:
print(f"Attempt {i+1} to load dataset...")
dataset = approach()
print(f"Successfully loaded dataset with approach {i+1}!")
break
except Exception as e:
print(f"Attempt {i+1} failed: {e}")
last_error = e
if dataset is None:
print(f"All attempts to load dataset failed. Last error: {last_error}")
raise ValueError(f"Could not load dataset {dataset_name}")
except Exception as e:
print(f"Error during dataset loading: {e}")
raise
# Prepare demographics data from the dataset
if demo_data is None:
# Create demo_data from the dataset
demo_df = pd.DataFrame({
'age': dataset['age'],
'gender': dataset['gender'],
'mpo': dataset['mpo'],
'wab_aq': dataset['wab_aq']
})
demo_data = [
demo_df['age'].values,
demo_df['gender'].values,
demo_df['mpo'].values,
demo_df['wab_aq'].values
]
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
# Look for NIfTI files in P01_rs.nii format
print("Searching for NIfTI files in dataset columns...")
nii_files = []
# Create a temp directory for downloads
import tempfile
from huggingface_hub import hf_hub_download
import shutil
temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
print(f"Created temporary directory for NIfTI files: {temp_dir}")
try:
# First approach: Check if there are any columns containing file paths
nii_columns = []
for col in dataset.column_names:
# Check if column name suggests NIfTI files
if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
nii_columns.append(col)
# Or check if column contains file paths
elif len(dataset) > 0:
first_val = dataset[0][col]
if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
nii_columns.append(col)
if nii_columns:
print(f"Found columns that may contain NIfTI files: {nii_columns}")
for col in nii_columns:
print(f"Processing column '{col}'...")
for i, item in enumerate(dataset[col]):
if not isinstance(item, str):
print(f"Item {i} in column {col} is not a string but {type(item)}")
continue
if not (item.endswith('.nii') or item.endswith('.nii.gz')):
print(f"Item {i} in column {col} is not a NIfTI file: {item}")
continue
print(f"Downloading {item} from dataset {dataset_name}...")
try:
# Attempt to download with explicit filename
file_path = hf_hub_download(
repo_id=dataset_name,
filename=item,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Successfully downloaded {item}")
except Exception as e1:
print(f"Error downloading with explicit filename: {e1}")
# Second attempt: try with the item's basename
try:
basename = os.path.basename(item)
print(f"Trying with basename: {basename}")
file_path = hf_hub_download(
repo_id=dataset_name,
filename=basename,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Successfully downloaded {basename}")
except Exception as e2:
print(f"Error downloading with basename: {e2}")
# Third attempt: check if it's a binary blob in the dataset
try:
if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
print("Found binary data in dataset, saving to temporary file...")
binary_data = dataset[i]['bytes']
temp_file = os.path.join(temp_dir, basename)
with open(temp_file, 'wb') as f:
f.write(binary_data)
nii_files.append(temp_file)
print(f"✓ Saved binary data to {temp_file}")
except Exception as e3:
print(f"Error handling binary data: {e3}")
# Last resort: look for the file locally
local_path = os.path.join(os.getcwd(), item)
if os.path.exists(local_path):
nii_files.append(local_path)
print(f"✓ Found {item} locally")
else:
print(f"❌ Warning: Could not find {item} anywhere")
# Second approach: Try to find NIfTI files in dataset repository directly
if not nii_files:
print("No NIfTI files found in dataset columns. Trying direct repository search...")
try:
from huggingface_hub import list_repo_files, hf_hub_download
# Try to list all files in the repository
try:
print("Listing all repository files...")
all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
print(f"Found {len(all_repo_files)} files in repository")
# First prioritize P*_rs.nii files
p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
# Then include all other NIfTI files
other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
# Combine, with P*_rs.nii files first
nii_repo_files = p_rs_files + other_nii_files
if nii_repo_files:
print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
# Download each file
for nii_file in nii_repo_files:
try:
file_path = hf_hub_download(
repo_id=dataset_name,
filename=nii_file,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Downloaded {nii_file}")
except Exception as e:
print(f"Error downloading {nii_file}: {e}")
except Exception as e:
print(f"Error listing repository files: {e}")
print("Will try alternative approaches...")
# If repo listing fails, try with common NIfTI file patterns directly
if not nii_files:
print("Trying common NIfTI file patterns...")
# Focus specifically on P*_rs.nii pattern
patterns = []
# Generate P01_rs.nii through P30_rs.nii
for i in range(1, 31): # Try subjects 1-30
patterns.append(f"P{i:02d}_rs.nii")
# Also try with .nii.gz extension
for i in range(1, 31):
patterns.append(f"P{i:02d}_rs.nii.gz")
# Include a few other common patterns as fallbacks
patterns.extend([
"sub-01_task-rest_bold.nii.gz", # BIDS format
"fmri.nii.gz", "bold.nii.gz",
"rest.nii.gz"
])
for pattern in patterns:
try:
print(f"Trying to download {pattern}...")
file_path = hf_hub_download(
repo_id=dataset_name,
filename=pattern,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Successfully downloaded {pattern}")
except Exception as e:
print(f"× Failed to download {pattern}")
# If we still couldn't find any files, check if data files are nested
if not nii_files:
print("Checking for nested data files...")
nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
for path in nested_paths:
for pattern in patterns:
nested_file = f"{path}{pattern}"
try:
print(f"Trying to download {nested_file}...")
file_path = hf_hub_download(
repo_id=dataset_name,
filename=nested_file,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Successfully downloaded {nested_file}")
# If we found one file in this directory, try to find all files in it
try:
all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
for nii_file in nii_files_in_dir:
if nii_file != nested_file: # Skip the one we already downloaded
try:
file_path = hf_hub_download(
repo_id=dataset_name,
filename=nii_file,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Downloaded {nii_file}")
except Exception as e:
print(f"Error downloading {nii_file}: {e}")
except Exception as e:
print(f"Error finding additional files in {path}: {e}")
except Exception as e:
pass
except Exception as e:
print(f"Error during repository exploration: {e}")
# If we still don't have any files, try to search for P*_rs.nii pattern specifically
if not nii_files:
print("Trying to find files matching P*_rs.nii pattern specifically...")
try:
# List all files in the repository (if we haven't already)
if not 'all_repo_files' in locals():
from huggingface_hub import list_repo_files
try:
all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
except Exception as e:
print(f"Error listing repo files: {e}")
all_repo_files = []
# Look for files matching the pattern exactly (P*_rs.nii)
pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
# If we don't find any exact matches, try a more relaxed pattern
if not pattern_files:
pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
if pattern_files:
print(f"Found {len(pattern_files)} files matching rs.nii pattern")
# Download each file
for pattern_file in pattern_files:
try:
file_path = hf_hub_download(
repo_id=dataset_name,
filename=pattern_file,
repo_type="dataset",
cache_dir=temp_dir
)
nii_files.append(file_path)
print(f"✓ Downloaded {pattern_file}")
except Exception as e:
print(f"Error downloading {pattern_file}: {e}")
except Exception as e:
print(f"Error searching for pattern files: {e}")
print(f"Found total of {len(nii_files)} NIfTI files")
except Exception as e:
print(f"Unexpected error during NIfTI file search: {e}")
import traceback
traceback.print_exc()
# If we found NIfTI files, process them to FC matrices
if nii_files:
print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
# Load Power 264 atlas
from nilearn import datasets
power = datasets.fetch_coords_power_2011()
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
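        # One mean BOLD time series per ROI: 5 mm spheres around each Power coordinate,
        # detrended, standardized, and band-pass filtered to 0.01-0.1 Hz (t_r assumed 2.0 s).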
masker = input_data.NiftiSpheresMasker(
coords, radius=5,
standardize=True,
memory='nilearn_cache', memory_level=1,
verbose=0,
detrend=True,
low_pass=0.1,
high_pass=0.01,
t_r=2.0 # Adjust TR according to your data
)
# Process fMRI data and compute FC matrices
fc_matrices = []
valid_files = 0
total_files = len(nii_files)
for nii_file in nii_files:
try:
print(f"Processing {nii_file}...")
fmri_img = load_img(nii_file)
# Check image dimensions
if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
continue
try:
# Explicitly handle warnings about empty spheres
import warnings
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message='.*empty.*')
time_series = masker.fit_transform(fmri_img)
except Exception as e:
if "empty" in str(e):
print(f"Warning: Some spheres are empty in {nii_file}. Using a different sphere radius.")
# Extract the list of empty spheres for logging
import re
empty_spheres = re.findall(r"\[(.*?)\]", str(e))
if empty_spheres:
print(f"Empty spheres: {empty_spheres[0]}")
# Try with a different radius
alternate_masker = input_data.NiftiSpheresMasker(
coords, radius=8, # Larger radius
standardize=True,
memory='nilearn_cache', memory_level=1,
verbose=0,
detrend=True,
low_pass=0.1,
high_pass=0.01,
t_r=2.0
)
try:
time_series = alternate_masker.fit_transform(fmri_img)
print(f"Successfully extracted time series with larger radius")
except Exception as e2:
print(f"Error with alternate masker: {e2}")
print(f"Skipping this file due to empty spheres")
continue # Skip this file entirely
else:
print(f"Unknown error in masker: {e}")
continue # Skip this file if there's any other error
# Validate time series data
if np.isnan(time_series).any() or np.isinf(time_series).any():
print(f"Warning: {nii_file} contains NaN or Inf values after masking")
# Replace NaNs with zeros for this file
time_series = np.nan_to_num(time_series)
correlation_measure = connectome.ConnectivityMeasure(
kind='correlation',
vectorize=False,
discard_diagonal=False
)
fc_matrix = correlation_measure.fit_transform([time_series])[0]
# Check for invalid correlation values
if np.isnan(fc_matrix).any():
print(f"Warning: {nii_file} produced NaN correlation values")
continue
triu_indices = np.triu_indices_from(fc_matrix, k=1)
fc_triu = fc_matrix[triu_indices]
# Fisher z-transform with proper bounds check
# Clip correlation values to valid range for arctanh
fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
fc_triu = np.arctanh(fc_triu_clipped)
fc_matrices.append(fc_triu)
valid_files += 1
print(f"Successfully processed {nii_file} to FC matrix")
except Exception as e:
print(f"Error processing {nii_file}: {e}")
if fc_matrices:
print(f"Successfully processed {valid_files} out of {total_files} files")
# Ensure all matrices have the same dimensions
dims = [m.shape[0] for m in fc_matrices]
if len(set(dims)) > 1:
print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
# Use the most common dimension
from collections import Counter
most_common_dim = Counter(dims).most_common(1)[0][0]
print(f"Using most common dimension: {most_common_dim}")
fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
X = np.array(fc_matrices)
# Normalize the FC data
mean_x = np.mean(X, axis=0)
std_x = np.std(X, axis=0)
# Handle zero standard deviation
std_x[std_x == 0] = 1.0
X = (X - mean_x) / std_x
print(f"Created FC matrices with shape {X.shape}")
# Make sure demo_data matches the number of FC matrices
if len(demo_data[0]) != X.shape[0]:
print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
f"doesn't match number of FC matrices ({X.shape[0]})")
# Adjust demo_data to match FC matrices
indices = list(range(min(len(demo_data[0]), X.shape[0])))
X = X[indices]
demo_data = [d[indices] for d in demo_data]
return X, demo_data, demo_types
print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
# Return a placeholder with the right demographics but empty FC
n_subjects = len(dataset)
n_rois = 264
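        # Upper triangle of a 264x264 correlation matrix: 264*263/2 = 34,716 features per subject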
fc_dim = (n_rois * (n_rois - 1)) // 2
X = np.zeros((n_subjects, fc_dim))
print(f"Created placeholder FC matrices with shape {X.shape}")
return X, demo_data, demo_types
    elif isinstance(dataset_or_niifiles, str):
        # Handle a dataset whose rows point directly at fMRI NIfTI files.
        # NOTE: unreachable in practice, because the string branch above always returns.
        dataset = load_dataset(dataset_or_niifiles, split="train")
# Load Power 264 atlas
from nilearn import datasets
power = datasets.fetch_coords_power_2011()
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
masker = input_data.NiftiSpheresMasker(
coords, radius=5,
standardize=True,
memory='nilearn_cache', memory_level=1,
verbose=0,
detrend=True,
low_pass=0.1,
high_pass=0.01,
t_r=2.0 # Adjust TR according to your data
)
# Load demographic data if needed
if demo_data is None:
if 'demographics' in dataset.features:
demo_df = pd.DataFrame(dataset['demographics'])
demo_data = [
demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
demo_df['sex'].values if 'sex' in demo_df.columns else [],
demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
]
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
# Process fMRI data and compute FC matrices
fc_matrices = []
for nii_file in dataset['nii_files']:
fmri_img = load_img(nii_file)
time_series = masker.fit_transform(fmri_img)
correlation_measure = connectome.ConnectivityMeasure(
kind='correlation', vectorize=False, discard_diagonal=False
)
fc_matrix = correlation_measure.fit_transform([time_series])[0]
triu_indices = np.triu_indices_from(fc_matrix, k=1)
fc_triu = fc_matrix[triu_indices]
fc_triu = np.arctanh(fc_triu) # Fisher z-transform
fc_matrices.append(fc_triu)
X = np.array(fc_matrices)
elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
# Handle a list of NIfTI files
# Similar processing as above but with local files
print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
# Load Power 264 atlas
from nilearn import datasets
power = datasets.fetch_coords_power_2011()
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
masker = input_data.NiftiSpheresMasker(
coords, radius=5,
standardize=True,
memory='nilearn_cache', memory_level=1,
verbose=0,
detrend=True,
low_pass=0.1,
high_pass=0.01,
t_r=2.0
)
fc_matrices = []
for nii_file in dataset_or_niifiles:
fmri_img = load_img(nii_file)
time_series = masker.fit_transform(fmri_img)
correlation_measure = connectome.ConnectivityMeasure(
kind='correlation', vectorize=False, discard_diagonal=False
)
            fc_matrix = correlation_measure.fit_transform([time_series])[0]
            triu_indices = np.triu_indices_from(fc_matrix, k=1)
            fc_triu = fc_matrix[triu_indices]
            # Fisher z-transform; clip to (-1, 1) so arctanh stays finite, as in the dataset branch
            fc_triu = np.arctanh(np.clip(fc_triu, -0.999, 0.999))
            fc_matrices.append(fc_triu)
X = np.array(fc_matrices)
else:
raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
    # Normalize the FC features across subjects, guarding against zero variance
    std_x = np.std(X, axis=0)
    std_x[std_x == 0] = 1.0
    X = (X - np.mean(X, axis=0)) / std_x
    return X, demo_data, demo_types
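

if __name__ == "__main__":
    # Minimal smoke-test sketch: the dataset id below is the one referenced in the
    # comments above; swap in local NIfTI paths plus demographics to run offline.
    X, demo_data, demo_types = preprocess_fmri_to_fc("SreekarB/OSFData")
    print("FC feature matrix:", X.shape)
    print("Demographic variable types:", demo_types)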