import numpy as np
import pandas as pd
from datasets import load_dataset
from nilearn import input_data, connectome
from nilearn.image import load_img
import nibabel as nib
import os

def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
    """
    Process fMRI data to generate functional connectivity matrices.

    Parameters:
    - dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
    - demo_data: Optional demographic data, required if providing NIfTI files
    - demo_types: Optional demographic data types, required if providing NIfTI files

    Returns:
    - X: Array of FC matrices
    - demo_data: Demographic data
    - demo_types: Demographic data types
    """
    print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")

    # For SreekarB/OSFData dataset, the data will be loaded from dataset features
    if isinstance(dataset_or_niifiles, str):
        dataset_name = dataset_or_niifiles
        print(f"Loading data from dataset: {dataset_name}")
        try:
            # Try multiple approaches to load the dataset
            approaches = [
                lambda: load_dataset(dataset_name, split="train"),
                lambda: load_dataset(dataset_name),  # Try without split
                lambda: load_dataset(dataset_name, split="train", trust_remote_code=True),  # Try with trust_remote_code
                # Only applicable when the name has a namespace; otherwise this entry is None and skipped below
                (lambda: load_dataset(dataset_name.split("/")[-1], split="train")) if "/" in dataset_name else None
            ]
            dataset = None
            last_error = None
            for i, approach in enumerate(approaches):
                if approach is None:
                    continue
                try:
                    print(f"Attempt {i+1} to load dataset...")
                    dataset = approach()
                    print(f"Successfully loaded dataset with approach {i+1}!")
                    break
                except Exception as e:
                    print(f"Attempt {i+1} failed: {e}")
                    last_error = e
            if dataset is None:
                print(f"All attempts to load dataset failed. Last error: {last_error}")
                raise ValueError(f"Could not load dataset {dataset_name}")
        except Exception as e:
            print(f"Error during dataset loading: {e}")
            raise
        # Prepare demographics data from the dataset
        if demo_data is None:
            # Create demo_data from the dataset
            demo_df = pd.DataFrame({
                'age': dataset['age'],
                'gender': dataset['gender'],
                'mpo': dataset['mpo'],
                'wab_aq': dataset['wab_aq']
            })
            demo_data = [
                demo_df['age'].values,
                demo_df['gender'].values,
                demo_df['mpo'].values,
                demo_df['wab_aq'].values
            ]
            demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
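            # Assumed dataset schema (not verified here): 'mpo' is likely months post onset
            # and 'wab_aq' the Western Aphasia Battery Aphasia Quotient; gender is treated
            # as the only categorical covariate in demo_types above.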

        # Look for NIfTI files in P01_rs.nii format
        print("Searching for NIfTI files in dataset columns...")
        nii_files = []

        # Create a temp directory for downloads
        import tempfile
        from huggingface_hub import hf_hub_download
        import shutil
        temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
        print(f"Created temporary directory for NIfTI files: {temp_dir}")
        try:
            # First approach: Check if there are any columns containing file paths
            nii_columns = []
            for col in dataset.column_names:
                # Check if column name suggests NIfTI files
                if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
                    nii_columns.append(col)
                # Or check if column contains file paths
                elif len(dataset) > 0:
                    first_val = dataset[0][col]
                    if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
                        nii_columns.append(col)

            if nii_columns:
                print(f"Found columns that may contain NIfTI files: {nii_columns}")
                for col in nii_columns:
                    print(f"Processing column '{col}'...")
                    for i, item in enumerate(dataset[col]):
                        if not isinstance(item, str):
                            print(f"Item {i} in column {col} is not a string but {type(item)}")
                            continue
                        if not (item.endswith('.nii') or item.endswith('.nii.gz')):
                            print(f"Item {i} in column {col} is not a NIfTI file: {item}")
                            continue
                        print(f"Downloading {item} from dataset {dataset_name}...")
                        try:
                            # Attempt to download with explicit filename
                            file_path = hf_hub_download(
                                repo_id=dataset_name,
                                filename=item,
                                repo_type="dataset",
                                cache_dir=temp_dir
                            )
                            nii_files.append(file_path)
                            print(f"✓ Successfully downloaded {item}")
                        except Exception as e1:
                            print(f"Error downloading with explicit filename: {e1}")
                            # Second attempt: try with the item's basename
                            try:
                                basename = os.path.basename(item)
                                print(f"Trying with basename: {basename}")
                                file_path = hf_hub_download(
                                    repo_id=dataset_name,
                                    filename=basename,
                                    repo_type="dataset",
                                    cache_dir=temp_dir
                                )
                                nii_files.append(file_path)
                                print(f"✓ Successfully downloaded {basename}")
                            except Exception as e2:
                                print(f"Error downloading with basename: {e2}")
                                # Third attempt: check if it's a binary blob in the dataset
                                try:
                                    if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
                                        print("Found binary data in dataset, saving to temporary file...")
                                        binary_data = dataset[i]['bytes']
                                        temp_file = os.path.join(temp_dir, basename)
                                        with open(temp_file, 'wb') as f:
                                            f.write(binary_data)
                                        nii_files.append(temp_file)
                                        print(f"✓ Saved binary data to {temp_file}")
                                except Exception as e3:
                                    print(f"Error handling binary data: {e3}")
                                    # Last resort: look for the file locally
                                    local_path = os.path.join(os.getcwd(), item)
                                    if os.path.exists(local_path):
                                        nii_files.append(local_path)
                                        print(f"✓ Found {item} locally")
                                    else:
                                        print(f"❌ Warning: Could not find {item} anywhere")

            # Second approach: Try to find NIfTI files in dataset repository directly
            if not nii_files:
                print("No NIfTI files found in dataset columns. Trying direct repository search...")
                try:
                    from huggingface_hub import list_repo_files, hf_hub_download
                    # Try to list all files in the repository
                    try:
                        print("Listing all repository files...")
                        all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
                        print(f"Found {len(all_repo_files)} files in repository")
                        # First prioritize P*_rs.nii files
                        p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
                        # Then include all other NIfTI files
                        other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
                        # Combine, with P*_rs.nii files first
                        nii_repo_files = p_rs_files + other_nii_files
                        if nii_repo_files:
                            print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
                            # Download each file
                            for nii_file in nii_repo_files:
                                try:
                                    file_path = hf_hub_download(
                                        repo_id=dataset_name,
                                        filename=nii_file,
                                        repo_type="dataset",
                                        cache_dir=temp_dir
                                    )
                                    nii_files.append(file_path)
                                    print(f"✓ Downloaded {nii_file}")
                                except Exception as e:
                                    print(f"Error downloading {nii_file}: {e}")
                    except Exception as e:
                        print(f"Error listing repository files: {e}")
                        print("Will try alternative approaches...")

                    # If repo listing fails, try with common NIfTI file patterns directly
                    if not nii_files:
                        print("Trying common NIfTI file patterns...")
                        # Focus specifically on P*_rs.nii pattern
                        patterns = []
                        # Generate P01_rs.nii through P30_rs.nii
                        for i in range(1, 31):  # Try subjects 1-30
                            patterns.append(f"P{i:02d}_rs.nii")
                        # Also try with .nii.gz extension
                        for i in range(1, 31):
                            patterns.append(f"P{i:02d}_rs.nii.gz")
                        # Include a few other common patterns as fallbacks
                        patterns.extend([
                            "sub-01_task-rest_bold.nii.gz",  # BIDS format
                            "fmri.nii.gz", "bold.nii.gz",
                            "rest.nii.gz"
                        ])
                        for pattern in patterns:
                            try:
                                print(f"Trying to download {pattern}...")
                                file_path = hf_hub_download(
                                    repo_id=dataset_name,
                                    filename=pattern,
                                    repo_type="dataset",
                                    cache_dir=temp_dir
                                )
                                nii_files.append(file_path)
                                print(f"✓ Successfully downloaded {pattern}")
                            except Exception as e:
                                print(f"× Failed to download {pattern}")

                        # If we still couldn't find any files, check if data files are nested
                        if not nii_files:
                            print("Checking for nested data files...")
                            nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
                            for path in nested_paths:
                                for pattern in patterns:
                                    nested_file = f"{path}{pattern}"
                                    try:
                                        print(f"Trying to download {nested_file}...")
                                        file_path = hf_hub_download(
                                            repo_id=dataset_name,
                                            filename=nested_file,
                                            repo_type="dataset",
                                            cache_dir=temp_dir
                                        )
                                        nii_files.append(file_path)
                                        print(f"✓ Successfully downloaded {nested_file}")
                                        # If we found one file in this directory, try to find all files in it
                                        try:
                                            all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
                                            nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
                                            print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
                                            for nii_file in nii_files_in_dir:
                                                if nii_file != nested_file:  # Skip the one we already downloaded
                                                    try:
                                                        file_path = hf_hub_download(
                                                            repo_id=dataset_name,
                                                            filename=nii_file,
                                                            repo_type="dataset",
                                                            cache_dir=temp_dir
                                                        )
                                                        nii_files.append(file_path)
                                                        print(f"✓ Downloaded {nii_file}")
                                                    except Exception as e:
                                                        print(f"Error downloading {nii_file}: {e}")
                                        except Exception as e:
                                            print(f"Error finding additional files in {path}: {e}")
                                    except Exception as e:
                                        pass
                except Exception as e:
                    print(f"Error during repository exploration: {e}")

            # If we still don't have any files, try to search for P*_rs.nii pattern specifically
            if not nii_files:
                print("Trying to find files matching P*_rs.nii pattern specifically...")
                try:
                    # List all files in the repository (if we haven't already)
                    if 'all_repo_files' not in locals():
                        from huggingface_hub import list_repo_files
                        try:
                            all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
                        except Exception as e:
                            print(f"Error listing repo files: {e}")
                            all_repo_files = []
                    # Look for files matching the pattern exactly (P*_rs.nii)
                    pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
                    # If we don't find any exact matches, try a more relaxed pattern
                    if not pattern_files:
                        pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
                    if pattern_files:
                        print(f"Found {len(pattern_files)} files matching rs.nii pattern")
                        # Download each file
                        for pattern_file in pattern_files:
                            try:
                                file_path = hf_hub_download(
                                    repo_id=dataset_name,
                                    filename=pattern_file,
                                    repo_type="dataset",
                                    cache_dir=temp_dir
                                )
                                nii_files.append(file_path)
                                print(f"✓ Downloaded {pattern_file}")
                            except Exception as e:
                                print(f"Error downloading {pattern_file}: {e}")
                except Exception as e:
                    print(f"Error searching for pattern files: {e}")

            print(f"Found total of {len(nii_files)} NIfTI files")
        except Exception as e:
            print(f"Unexpected error during NIfTI file search: {e}")
            import traceback
            traceback.print_exc()

        # If we found NIfTI files, process them to FC matrices
        if nii_files:
            print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
            # Load Power 264 atlas
            from nilearn import datasets
            power = datasets.fetch_coords_power_2011()
            coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
            masker = input_data.NiftiSpheresMasker(
                coords, radius=5,
                standardize=True,
                memory='nilearn_cache', memory_level=1,
                verbose=0,
                detrend=True,
                low_pass=0.1,
                high_pass=0.01,
                t_r=2.0  # Adjust TR according to your data
            )
            # Process fMRI data and compute FC matrices
            fc_matrices = []
            valid_files = 0
            total_files = len(nii_files)
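            # Note: the loop below reduces each file to the upper triangle of a 264x264
            # correlation matrix, i.e. a vector of 264*263/2 = 34716 Fisher-z values.
            # The 5 mm sphere radius, 0.01-0.1 Hz band-pass and t_r=2.0 above are
            # assumptions about the acquisition, not values read from the data.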
            for nii_file in nii_files:
                try:
                    print(f"Processing {nii_file}...")
                    fmri_img = load_img(nii_file)
                    # Check image dimensions
                    if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
                        print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
                        continue
                    try:
                        # Explicitly handle warnings about empty spheres
                        import warnings
                        with warnings.catch_warnings():
                            warnings.filterwarnings('ignore', message='.*empty.*')
                            time_series = masker.fit_transform(fmri_img)
                    except Exception as e:
                        if "empty" in str(e):
                            print(f"Warning: Some spheres are empty in {nii_file}. Using a different sphere radius.")
                            # Extract the list of empty spheres for logging
                            import re
                            empty_spheres = re.findall(r"\[(.*?)\]", str(e))
                            if empty_spheres:
                                print(f"Empty spheres: {empty_spheres[0]}")
                            # Try with a different radius
                            alternate_masker = input_data.NiftiSpheresMasker(
                                coords, radius=8,  # Larger radius
                                standardize=True,
                                memory='nilearn_cache', memory_level=1,
                                verbose=0,
                                detrend=True,
                                low_pass=0.1,
                                high_pass=0.01,
                                t_r=2.0
                            )
                            try:
                                time_series = alternate_masker.fit_transform(fmri_img)
                                print(f"Successfully extracted time series with larger radius")
                            except Exception as e2:
                                print(f"Error with alternate masker: {e2}")
                                print(f"Skipping this file due to empty spheres")
                                continue  # Skip this file entirely
                        else:
                            print(f"Unknown error in masker: {e}")
                            continue  # Skip this file if there's any other error
                    # Validate time series data
                    if np.isnan(time_series).any() or np.isinf(time_series).any():
                        print(f"Warning: {nii_file} contains NaN or Inf values after masking")
                        # Replace NaNs with zeros for this file
                        time_series = np.nan_to_num(time_series)
                    correlation_measure = connectome.ConnectivityMeasure(
                        kind='correlation',
                        vectorize=False,
                        discard_diagonal=False
                    )
                    fc_matrix = correlation_measure.fit_transform([time_series])[0]
                    # Check for invalid correlation values
                    if np.isnan(fc_matrix).any():
                        print(f"Warning: {nii_file} produced NaN correlation values")
                        continue
                    triu_indices = np.triu_indices_from(fc_matrix, k=1)
                    fc_triu = fc_matrix[triu_indices]
                    # Fisher z-transform with proper bounds check
                    # Clip correlation values to valid range for arctanh
                    fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
                    fc_triu = np.arctanh(fc_triu_clipped)
                    fc_matrices.append(fc_triu)
                    valid_files += 1
                    print(f"Successfully processed {nii_file} to FC matrix")
                except Exception as e:
                    print(f"Error processing {nii_file}: {e}")

            if fc_matrices:
                print(f"Successfully processed {valid_files} out of {total_files} files")
                # Ensure all matrices have the same dimensions
                dims = [m.shape[0] for m in fc_matrices]
                if len(set(dims)) > 1:
                    print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
                    # Use the most common dimension
                    from collections import Counter
                    most_common_dim = Counter(dims).most_common(1)[0][0]
                    print(f"Using most common dimension: {most_common_dim}")
                    fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
                X = np.array(fc_matrices)
                # Normalize the FC data
                mean_x = np.mean(X, axis=0)
                std_x = np.std(X, axis=0)
                # Handle zero standard deviation
                std_x[std_x == 0] = 1.0
                X = (X - mean_x) / std_x
                print(f"Created FC matrices with shape {X.shape}")
                # Make sure demo_data matches the number of FC matrices
                if len(demo_data[0]) != X.shape[0]:
                    print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
                          f"doesn't match number of FC matrices ({X.shape[0]})")
                    # Adjust demo_data to match FC matrices
                    indices = list(range(min(len(demo_data[0]), X.shape[0])))
                    X = X[indices]
                    demo_data = [d[indices] for d in demo_data]
                return X, demo_data, demo_types

        print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
        # Return a placeholder with the right demographics but empty FC
        n_subjects = len(dataset)
        n_rois = 264
        fc_dim = (n_rois * (n_rois - 1)) // 2
        X = np.zeros((n_subjects, fc_dim))
        print(f"Created placeholder FC matrices with shape {X.shape}")
        return X, demo_data, demo_types

    elif isinstance(dataset_or_niifiles, str):
        # NOTE: this branch is unreachable as written, because the string case is already
        # handled by the first branch above. It is kept as a simpler reference path for a
        # dataset that exposes 'nii_files' and 'demographics' columns directly.
        # Handle real dataset with actual fMRI data
        dataset = load_dataset(dataset_or_niifiles, split="train")
        # Load Power 264 atlas
        from nilearn import datasets
        power = datasets.fetch_coords_power_2011()
        coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
        masker = input_data.NiftiSpheresMasker(
            coords, radius=5,
            standardize=True,
            memory='nilearn_cache', memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=0.1,
            high_pass=0.01,
            t_r=2.0  # Adjust TR according to your data
        )
        # Load demographic data if needed
        if demo_data is None:
            if 'demographics' in dataset.features:
                demo_df = pd.DataFrame(dataset['demographics'])
                demo_data = [
                    demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
                    demo_df['sex'].values if 'sex' in demo_df.columns else [],
                    demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
                    demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
                ]
                demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
        # Process fMRI data and compute FC matrices
        fc_matrices = []
        for nii_file in dataset['nii_files']:
            fmri_img = load_img(nii_file)
            time_series = masker.fit_transform(fmri_img)
            correlation_measure = connectome.ConnectivityMeasure(
                kind='correlation', vectorize=False, discard_diagonal=False
            )
            fc_matrix = correlation_measure.fit_transform([time_series])[0]
            triu_indices = np.triu_indices_from(fc_matrix, k=1)
            fc_triu = fc_matrix[triu_indices]
            # Clip to the open interval (-1, 1) so arctanh stays finite
            fc_triu = np.arctanh(np.clip(fc_triu, -0.999, 0.999))  # Fisher z-transform
            fc_matrices.append(fc_triu)
        X = np.array(fc_matrices)

    elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
        # Handle a list of NIfTI files
        # Similar processing as above but with local files
        print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
        # Load Power 264 atlas
        from nilearn import datasets
        power = datasets.fetch_coords_power_2011()
        coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
        masker = input_data.NiftiSpheresMasker(
            coords, radius=5,
            standardize=True,
            memory='nilearn_cache', memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=0.1,
            high_pass=0.01,
            t_r=2.0
        )
        fc_matrices = []
        for nii_file in dataset_or_niifiles:
            fmri_img = load_img(nii_file)
            time_series = masker.fit_transform(fmri_img)
            correlation_measure = connectome.ConnectivityMeasure(
                kind='correlation', vectorize=False, discard_diagonal=False
            )
            fc_matrix = correlation_measure.fit_transform([time_series])[0]
            triu_indices = np.triu_indices_from(fc_matrix, k=1)
            fc_triu = fc_matrix[triu_indices]
            # Clip to the open interval (-1, 1) so arctanh stays finite
            fc_triu = np.arctanh(np.clip(fc_triu, -0.999, 0.999))  # Fisher z-transform
            fc_matrices.append(fc_triu)
        X = np.array(fc_matrices)

    else:
        raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")

    # Normalize the FC data (guard against zero standard deviation)
    std_x = np.std(X, axis=0)
    std_x[std_x == 0] = 1.0
    X = (X - np.mean(X, axis=0)) / std_x
    return X, demo_data, demo_types
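

# --- Usage sketch (illustrative only) ---
# Minimal example of calling preprocess_fmri_to_fc. "SreekarB/OSFData" is the dataset
# id referenced in the comments above; the local file names and demographic values in
# the commented-out alternative are hypothetical placeholders, not real data.
if __name__ == "__main__":
    # Pull NIfTI files and demographics straight from the dataset repository
    # (requires network access and the 'age'/'gender'/'mpo'/'wab_aq' columns assumed above).
    X, demo_data, demo_types = preprocess_fmri_to_fc("SreekarB/OSFData")
    print(f"FC feature matrix: {X.shape}, demographic variables: {len(demo_data)}")

    # Alternatively, with local resting-state files and known demographics:
    # demo_data = [np.array([61.0, 54.0]),   # age
    #              np.array([0, 1]),         # gender
    #              np.array([12.0, 30.0]),   # months post onset
    #              np.array([72.5, 88.1])]   # WAB-AQ
    # demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
    # X, demo_data, demo_types = preprocess_fmri_to_fc(
    #     ["P01_rs.nii", "P02_rs.nii"], demo_data, demo_types)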