Spaces:
Sleeping
Sleeping
Upload 20 files
Browse files- README.md +1 -1
- app.py +431 -189
- config.py +0 -1
- demo_fc_visualization.py +73 -0
- demovae/model.py +225 -0
- demovae/sklearn.py +124 -0
- fc_visualization.py +349 -0
- huggingface_fc_visualization.py +489 -0
- main.py +4 -12
- rcf_prediction.py +34 -93
- src/.DS_Store +0 -0
- vae_model.py +89 -12
README.md
CHANGED
|
@@ -4,7 +4,7 @@ emoji: 🧠
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
|
|
| 7 |
from visualization import plot_fc_matrices, plot_learning_curves
|
| 8 |
import os
|
| 9 |
import glob
|
| 10 |
-
from sklearn.metrics import mean_squared_error, r2_score
|
| 11 |
import json
|
| 12 |
import pickle
|
| 13 |
import pandas as pd
|
|
@@ -24,6 +24,7 @@ class AphasiaPredictionApp:
|
|
| 24 |
self.predictor = None
|
| 25 |
self.trained = False
|
| 26 |
self.latent_dim = MODEL_CONFIG['latent_dim']
|
|
|
|
| 27 |
|
| 28 |
def train_models(self, data_dir, latent_dim, nepochs, bsize):
|
| 29 |
"""
|
|
@@ -34,9 +35,8 @@ class AphasiaPredictionApp:
|
|
| 34 |
logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
|
| 35 |
|
| 36 |
# Default prediction parameters from our config
|
| 37 |
-
prediction_type = PREDICTION_CONFIG.get('prediction_type', 'regression')
|
| 38 |
outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
|
| 39 |
-
logger.info(f"Prediction: type=
|
| 40 |
|
| 41 |
figures = {}
|
| 42 |
|
|
@@ -323,6 +323,8 @@ class AphasiaPredictionApp:
|
|
| 323 |
try:
|
| 324 |
real_treatment_file = process_behavioral_data_to_outcomes(csv_path)
|
| 325 |
treatment_file = real_treatment_file # Use the real treatment file if processing succeeded
|
|
|
|
|
|
|
| 326 |
logger.info(f"Using processed behavioral data for treatment outcomes")
|
| 327 |
except Exception as proc_err:
|
| 328 |
logger.warning(f"Couldn't process behavioral data: {proc_err}, using standard outcomes")
|
|
@@ -338,6 +340,8 @@ class AphasiaPredictionApp:
|
|
| 338 |
|
| 339 |
# Use the found file
|
| 340 |
treatment_file = real_treatment_file
|
|
|
|
|
|
|
| 341 |
logger.info(f"Using real treatment outcomes file")
|
| 342 |
except Exception as find_err:
|
| 343 |
logger.warning(f"Couldn't find treatment outcomes file: {find_err}, using standard outcomes")
|
|
@@ -754,54 +758,50 @@ class AphasiaPredictionApp:
|
|
| 754 |
# Plot predicted vs actual values
|
| 755 |
ax1 = fig.add_subplot(gs[0, 0])
|
| 756 |
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
else: # classification
|
| 803 |
-
# Convert to integer classes if they're strings
|
| 804 |
-
if isinstance(y_true[0], str) or isinstance(y_pred[0], str):
|
| 805 |
# Create mapping of class labels to integers
|
| 806 |
classes = sorted(list(set(list(y_true) + list(y_pred))))
|
| 807 |
class_to_int = {c: i for i, c in enumerate(classes)}
|
|
@@ -911,76 +911,39 @@ class AphasiaPredictionApp:
|
|
| 911 |
"""Create learning curve plots from cross-validation results"""
|
| 912 |
fig = plt.figure(figsize=(12, 6))
|
| 913 |
|
| 914 |
-
#
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
ax2.set_xticks(range(1, len(fold_metrics)+1))
|
| 948 |
-
ax2.legend()
|
| 949 |
-
|
| 950 |
-
else: # classification
|
| 951 |
-
# For classification, show accuracy and F1
|
| 952 |
-
ax1 = plt.subplot(1, 2, 1)
|
| 953 |
-
ax2 = plt.subplot(1, 2, 2)
|
| 954 |
-
|
| 955 |
-
# Plot accuracy for each fold
|
| 956 |
-
for i, metrics in enumerate(fold_metrics):
|
| 957 |
-
ax1.plot(i+1, metrics['accuracy'], 'bo')
|
| 958 |
-
|
| 959 |
-
# Plot average accuracy
|
| 960 |
-
avg_acc = np.mean([m['accuracy'] for m in fold_metrics])
|
| 961 |
-
ax1.axhline(y=avg_acc, color='r', linestyle='--',
|
| 962 |
-
label=f'Average Accuracy = {avg_acc:.4f}')
|
| 963 |
-
|
| 964 |
-
ax1.set_xlabel('Fold')
|
| 965 |
-
ax1.set_ylabel('Accuracy')
|
| 966 |
-
ax1.set_title('Accuracy by Fold')
|
| 967 |
-
ax1.set_xticks(range(1, len(fold_metrics)+1))
|
| 968 |
-
ax1.legend()
|
| 969 |
-
|
| 970 |
-
# Plot F1 for each fold
|
| 971 |
-
for i, metrics in enumerate(fold_metrics):
|
| 972 |
-
ax2.plot(i+1, metrics['f1'], 'go')
|
| 973 |
-
|
| 974 |
-
# Plot average F1
|
| 975 |
-
avg_f1 = np.mean([m['f1'] for m in fold_metrics])
|
| 976 |
-
ax2.axhline(y=avg_f1, color='r', linestyle='--',
|
| 977 |
-
label=f'Average F1 = {avg_f1:.4f}')
|
| 978 |
-
|
| 979 |
-
ax2.set_xlabel('Fold')
|
| 980 |
-
ax2.set_ylabel('F1 Score')
|
| 981 |
-
ax2.set_title('F1 Score by Fold')
|
| 982 |
-
ax2.set_xticks(range(1, len(fold_metrics)+1))
|
| 983 |
-
ax2.legend()
|
| 984 |
|
| 985 |
plt.tight_layout()
|
| 986 |
return fig
|
|
@@ -1350,6 +1313,8 @@ def process_behavioral_data_to_outcomes(behavioral_file):
|
|
| 1350 |
outcomes_df = pd.DataFrame(outcome_data)
|
| 1351 |
outcomes_df.to_csv(outcomes_file, index=False)
|
| 1352 |
logger.info(f"Created treatment outcomes file with {len(outcomes_df)} patients")
|
|
|
|
|
|
|
| 1353 |
return outcomes_file
|
| 1354 |
else:
|
| 1355 |
# If we couldn't extract outcomes per patient, try a simpler approach
|
|
@@ -1375,6 +1340,8 @@ def process_behavioral_data_to_outcomes(behavioral_file):
|
|
| 1375 |
])
|
| 1376 |
outcomes_df.to_csv(outcomes_file, index=False)
|
| 1377 |
logger.warning(f"Created simplified treatment outcomes with group improvement: {improvement:.2f}")
|
|
|
|
|
|
|
| 1378 |
return outcomes_file
|
| 1379 |
except Exception as e:
|
| 1380 |
logger.error(f"Could not create even simplified outcomes: {e}")
|
|
@@ -1858,8 +1825,8 @@ def create_interface():
|
|
| 1858 |
gr.Markdown("# Aphasia Treatment Trajectory Prediction")
|
| 1859 |
|
| 1860 |
with gr.Tabs():
|
| 1861 |
-
# Training
|
| 1862 |
-
with gr.Tab("
|
| 1863 |
with gr.Row():
|
| 1864 |
with gr.Column(scale=1):
|
| 1865 |
data_dir = gr.Textbox(
|
|
@@ -1889,39 +1856,72 @@ def create_interface():
|
|
| 1889 |
use_hf_dataset = gr.Checkbox(
|
| 1890 |
label="Use HuggingFace Dataset", value=True
|
| 1891 |
)
|
| 1892 |
-
|
| 1893 |
-
|
| 1894 |
-
label="Prediction Type",
|
| 1895 |
-
choices=["regression", "classification"],
|
| 1896 |
-
value="regression"
|
| 1897 |
-
)
|
| 1898 |
-
outcome_variable = gr.Dropdown(
|
| 1899 |
-
label="Outcome Variable",
|
| 1900 |
-
choices=["wab_aq", "age", "mpo", "education"],
|
| 1901 |
-
value="wab_aq"
|
| 1902 |
-
)
|
| 1903 |
skip_behavioral = gr.Checkbox(
|
| 1904 |
label="Skip Behavioral Data Processing",
|
| 1905 |
value=PREDICTION_CONFIG.get('skip_behavioral_data', True),
|
| 1906 |
info="Use pre-defined treatment outcomes instead of processing behavioral data"
|
| 1907 |
)
|
| 1908 |
-
|
| 1909 |
-
|
| 1910 |
-
|
| 1911 |
-
|
| 1912 |
-
|
| 1913 |
-
|
| 1914 |
-
|
| 1915 |
-
|
| 1916 |
-
|
| 1917 |
-
|
| 1918 |
-
info="Generate synthetic FC matrices if processing fails"
|
| 1919 |
-
)
|
| 1920 |
|
| 1921 |
-
|
|
|
|
|
|
|
| 1922 |
|
| 1923 |
with gr.Row():
|
| 1924 |
-
fc_plot = gr.Plot(label="FC
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1925 |
|
| 1926 |
with gr.Row():
|
| 1927 |
with gr.Column(scale=1):
|
|
@@ -1930,13 +1930,18 @@ def create_interface():
|
|
| 1930 |
prediction_plot = gr.Plot(label="Prediction Performance")
|
| 1931 |
|
| 1932 |
with gr.Row():
|
| 1933 |
-
|
|
|
|
|
|
|
| 1934 |
|
| 1935 |
-
#
|
| 1936 |
-
with gr.Tab("
|
|
|
|
|
|
|
|
|
|
| 1937 |
with gr.Row():
|
| 1938 |
with gr.Column(scale=1):
|
| 1939 |
-
fmri_file = gr.File(label="Patient fMRI Data")
|
| 1940 |
with gr.Column(scale=1):
|
| 1941 |
with gr.Group("Patient Demographics"):
|
| 1942 |
age = gr.Number(label="Age at Stroke", value=60)
|
|
@@ -1952,21 +1957,23 @@ def create_interface():
|
|
| 1952 |
with gr.Row():
|
| 1953 |
trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
|
| 1954 |
|
| 1955 |
-
#
|
| 1956 |
-
|
| 1957 |
-
|
| 1958 |
-
|
| 1959 |
-
'
|
| 1960 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1961 |
}
|
| 1962 |
|
| 1963 |
-
#
|
| 1964 |
-
def
|
| 1965 |
-
|
| 1966 |
-
|
| 1967 |
-
#
|
| 1968 |
-
PREDICTION_CONFIG['prediction_type'] = prediction_type
|
| 1969 |
-
PREDICTION_CONFIG['default_outcome'] = outcome_variable
|
| 1970 |
PREDICTION_CONFIG['skip_behavioral_data'] = skip_behavioral
|
| 1971 |
PREDICTION_CONFIG['use_synthetic_nifti'] = use_synthetic_nifti
|
| 1972 |
PREDICTION_CONFIG['use_synthetic_fc'] = use_synthetic_fc
|
|
@@ -1978,36 +1985,271 @@ def create_interface():
|
|
| 1978 |
else:
|
| 1979 |
PREDICTION_CONFIG['local_nii_dir'] = None
|
| 1980 |
|
| 1981 |
-
# Log
|
| 1982 |
-
logger.info(f"
|
| 1983 |
-
logger.info(f"
|
| 1984 |
-
logger.info(f"Prediction type: {prediction_type}, target: {outcome_variable}")
|
| 1985 |
|
| 1986 |
-
|
| 1987 |
-
|
| 1988 |
-
|
| 1989 |
-
|
| 1990 |
-
|
| 1991 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1992 |
|
| 1993 |
-
|
| 1994 |
-
|
| 1995 |
-
|
| 1996 |
-
|
| 1997 |
-
|
| 1998 |
-
|
| 1999 |
-
|
| 2000 |
-
|
| 2001 |
-
|
| 2002 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2003 |
inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
|
| 2004 |
-
|
| 2005 |
-
|
| 2006 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2007 |
)
|
| 2008 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2009 |
predict_btn.click(
|
| 2010 |
-
fn=
|
| 2011 |
inputs=[fmri_file, age, sex, months, wab],
|
| 2012 |
outputs=[prediction_text, trajectory_plot]
|
| 2013 |
)
|
|
|
|
| 7 |
from visualization import plot_fc_matrices, plot_learning_curves
|
| 8 |
import os
|
| 9 |
import glob
|
| 10 |
+
from sklearn.metrics import mean_squared_error, r2_score
|
| 11 |
import json
|
| 12 |
import pickle
|
| 13 |
import pandas as pd
|
|
|
|
| 24 |
self.predictor = None
|
| 25 |
self.trained = False
|
| 26 |
self.latent_dim = MODEL_CONFIG['latent_dim']
|
| 27 |
+
self.last_treatment_file = None # Track the last treatment file used
|
| 28 |
|
| 29 |
def train_models(self, data_dir, latent_dim, nepochs, bsize):
|
| 30 |
"""
|
|
|
|
| 35 |
logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
|
| 36 |
|
| 37 |
# Default prediction parameters from our config
|
|
|
|
| 38 |
outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
|
| 39 |
+
logger.info(f"Prediction: type=regression, outcome={outcome_variable}")
|
| 40 |
|
| 41 |
figures = {}
|
| 42 |
|
|
|
|
| 323 |
try:
|
| 324 |
real_treatment_file = process_behavioral_data_to_outcomes(csv_path)
|
| 325 |
treatment_file = real_treatment_file # Use the real treatment file if processing succeeded
|
| 326 |
+
# Store the treatment file path for later use
|
| 327 |
+
self.last_treatment_file = treatment_file
|
| 328 |
logger.info(f"Using processed behavioral data for treatment outcomes")
|
| 329 |
except Exception as proc_err:
|
| 330 |
logger.warning(f"Couldn't process behavioral data: {proc_err}, using standard outcomes")
|
|
|
|
| 340 |
|
| 341 |
# Use the found file
|
| 342 |
treatment_file = real_treatment_file
|
| 343 |
+
# Store the treatment file path for later use
|
| 344 |
+
self.last_treatment_file = treatment_file
|
| 345 |
logger.info(f"Using real treatment outcomes file")
|
| 346 |
except Exception as find_err:
|
| 347 |
logger.warning(f"Couldn't find treatment outcomes file: {find_err}, using standard outcomes")
|
|
|
|
| 758 |
# Plot predicted vs actual values
|
| 759 |
ax1 = fig.add_subplot(gs[0, 0])
|
| 760 |
|
| 761 |
+
# Regression plots
|
| 762 |
+
# Scatter plot
|
| 763 |
+
ax1.scatter(y_true, y_pred, alpha=0.7)
|
| 764 |
+
|
| 765 |
+
# Add perfect prediction line
|
| 766 |
+
min_val = min(np.min(y_true), np.min(y_pred))
|
| 767 |
+
max_val = max(np.max(y_true), np.max(y_pred))
|
| 768 |
+
ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
|
| 769 |
+
|
| 770 |
+
ax1.set_xlabel('Actual Values')
|
| 771 |
+
ax1.set_ylabel('Predicted Values')
|
| 772 |
+
ax1.set_title('Predicted vs. Actual Values')
|
| 773 |
+
|
| 774 |
+
# Add R² to the plot
|
| 775 |
+
r2 = r2_score(y_true, y_pred)
|
| 776 |
+
ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes,
|
| 777 |
+
bbox=dict(facecolor='white', alpha=0.5))
|
| 778 |
+
|
| 779 |
+
# Plot residuals
|
| 780 |
+
ax2 = fig.add_subplot(gs[0, 1])
|
| 781 |
+
residuals = y_true - y_pred
|
| 782 |
+
ax2.scatter(y_pred, residuals, alpha=0.7)
|
| 783 |
+
ax2.axhline(y=0, color='r', linestyle='--')
|
| 784 |
+
ax2.set_xlabel('Predicted Values')
|
| 785 |
+
ax2.set_ylabel('Residuals')
|
| 786 |
+
ax2.set_title('Residual Plot')
|
| 787 |
+
|
| 788 |
+
# Plot prediction errors
|
| 789 |
+
ax3 = fig.add_subplot(gs[1, 0])
|
| 790 |
+
ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7,
|
| 791 |
+
label='Predicted ± 2σ')
|
| 792 |
+
ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual')
|
| 793 |
+
ax3.set_xlabel('Sample Index')
|
| 794 |
+
ax3.set_ylabel('Value')
|
| 795 |
+
ax3.set_title('Prediction with Error Bars')
|
| 796 |
+
ax3.legend()
|
| 797 |
+
|
| 798 |
+
# Plot error distribution
|
| 799 |
+
ax4 = fig.add_subplot(gs[1, 1])
|
| 800 |
+
ax4.hist(residuals, bins=20, alpha=0.7)
|
| 801 |
+
ax4.axvline(x=0, color='r', linestyle='--')
|
| 802 |
+
ax4.set_xlabel('Prediction Error')
|
| 803 |
+
ax4.set_ylabel('Frequency')
|
| 804 |
+
ax4.set_title('Error Distribution')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 805 |
# Create mapping of class labels to integers
|
| 806 |
classes = sorted(list(set(list(y_true) + list(y_pred))))
|
| 807 |
class_to_int = {c: i for i, c in enumerate(classes)}
|
|
|
|
| 911 |
"""Create learning curve plots from cross-validation results"""
|
| 912 |
fig = plt.figure(figsize=(12, 6))
|
| 913 |
|
| 914 |
+
# For regression, show R² and RMSE
|
| 915 |
+
ax1 = plt.subplot(1, 2, 1)
|
| 916 |
+
ax2 = plt.subplot(1, 2, 2)
|
| 917 |
+
|
| 918 |
+
# Plot R² for each fold
|
| 919 |
+
for i, metrics in enumerate(fold_metrics):
|
| 920 |
+
ax1.plot(i+1, metrics['r2'], 'bo')
|
| 921 |
+
|
| 922 |
+
# Plot average R²
|
| 923 |
+
avg_r2 = np.mean([m['r2'] for m in fold_metrics])
|
| 924 |
+
ax1.axhline(y=avg_r2, color='r', linestyle='--',
|
| 925 |
+
label=f'Average R² = {avg_r2:.4f}')
|
| 926 |
+
|
| 927 |
+
ax1.set_xlabel('Fold')
|
| 928 |
+
ax1.set_ylabel('R²')
|
| 929 |
+
ax1.set_title('R² by Fold')
|
| 930 |
+
ax1.set_xticks(range(1, len(fold_metrics)+1))
|
| 931 |
+
ax1.legend()
|
| 932 |
+
|
| 933 |
+
# Plot RMSE for each fold
|
| 934 |
+
for i, metrics in enumerate(fold_metrics):
|
| 935 |
+
ax2.plot(i+1, metrics['rmse'], 'go')
|
| 936 |
+
|
| 937 |
+
# Plot average RMSE
|
| 938 |
+
avg_rmse = np.mean([m['rmse'] for m in fold_metrics])
|
| 939 |
+
ax2.axhline(y=avg_rmse, color='r', linestyle='--',
|
| 940 |
+
label=f'Average RMSE = {avg_rmse:.4f}')
|
| 941 |
+
|
| 942 |
+
ax2.set_xlabel('Fold')
|
| 943 |
+
ax2.set_ylabel('RMSE')
|
| 944 |
+
ax2.set_title('RMSE by Fold')
|
| 945 |
+
ax2.set_xticks(range(1, len(fold_metrics)+1))
|
| 946 |
+
ax2.legend()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 947 |
|
| 948 |
plt.tight_layout()
|
| 949 |
return fig
|
|
|
|
| 1313 |
outcomes_df = pd.DataFrame(outcome_data)
|
| 1314 |
outcomes_df.to_csv(outcomes_file, index=False)
|
| 1315 |
logger.info(f"Created treatment outcomes file with {len(outcomes_df)} patients")
|
| 1316 |
+
# Store the treatment file path for later use
|
| 1317 |
+
self.last_treatment_file = outcomes_file
|
| 1318 |
return outcomes_file
|
| 1319 |
else:
|
| 1320 |
# If we couldn't extract outcomes per patient, try a simpler approach
|
|
|
|
| 1340 |
])
|
| 1341 |
outcomes_df.to_csv(outcomes_file, index=False)
|
| 1342 |
logger.warning(f"Created simplified treatment outcomes with group improvement: {improvement:.2f}")
|
| 1343 |
+
# Store the treatment file path for later use
|
| 1344 |
+
self.last_treatment_file = outcomes_file
|
| 1345 |
return outcomes_file
|
| 1346 |
except Exception as e:
|
| 1347 |
logger.error(f"Could not create even simplified outcomes: {e}")
|
|
|
|
| 1825 |
gr.Markdown("# Aphasia Treatment Trajectory Prediction")
|
| 1826 |
|
| 1827 |
with gr.Tabs():
|
| 1828 |
+
# Tab 1: VAE Training
|
| 1829 |
+
with gr.Tab("1. VAE Training"):
|
| 1830 |
with gr.Row():
|
| 1831 |
with gr.Column(scale=1):
|
| 1832 |
data_dir = gr.Textbox(
|
|
|
|
| 1856 |
use_hf_dataset = gr.Checkbox(
|
| 1857 |
label="Use HuggingFace Dataset", value=True
|
| 1858 |
)
|
| 1859 |
+
|
| 1860 |
+
with gr.Accordion("Advanced Data Options", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1861 |
skip_behavioral = gr.Checkbox(
|
| 1862 |
label="Skip Behavioral Data Processing",
|
| 1863 |
value=PREDICTION_CONFIG.get('skip_behavioral_data', True),
|
| 1864 |
info="Use pre-defined treatment outcomes instead of processing behavioral data"
|
| 1865 |
)
|
| 1866 |
+
use_synthetic_nifti = gr.Checkbox(
|
| 1867 |
+
label="Use Synthetic NIfTI Data",
|
| 1868 |
+
value=PREDICTION_CONFIG.get('use_synthetic_nifti', False),
|
| 1869 |
+
info="Generate synthetic NIfTI files if real ones aren't found"
|
| 1870 |
+
)
|
| 1871 |
+
use_synthetic_fc = gr.Checkbox(
|
| 1872 |
+
label="Use Synthetic FC Matrices",
|
| 1873 |
+
value=PREDICTION_CONFIG.get('use_synthetic_fc', False),
|
| 1874 |
+
info="Generate synthetic FC matrices if processing fails"
|
| 1875 |
+
)
|
|
|
|
|
|
|
| 1876 |
|
| 1877 |
+
train_vae_btn = gr.Button("Train VAE Model", variant="primary")
|
| 1878 |
+
|
| 1879 |
+
gr.Markdown("### VAE Training Results")
|
| 1880 |
|
| 1881 |
with gr.Row():
|
| 1882 |
+
fc_plot = gr.Plot(label="FC Matrices (Original/Reconstructed/Generated)")
|
| 1883 |
+
|
| 1884 |
+
with gr.Row():
|
| 1885 |
+
learning_plot = gr.Plot(label="VAE Learning Curves")
|
| 1886 |
+
|
| 1887 |
+
gr.Markdown("After VAE training completes, proceed to the 'Random Forest Prediction' tab →")
|
| 1888 |
+
|
| 1889 |
+
# Tab 2: Random Forest Prediction
|
| 1890 |
+
with gr.Tab("2. Random Forest Prediction"):
|
| 1891 |
+
gr.Markdown("### Random Forest Model Training")
|
| 1892 |
+
gr.Markdown("First complete the VAE training in tab 1, then configure and train the Random Forest model below:")
|
| 1893 |
+
|
| 1894 |
+
with gr.Row():
|
| 1895 |
+
with gr.Column(scale=1):
|
| 1896 |
+
prediction_type = gr.Radio(
|
| 1897 |
+
label="Prediction Type",
|
| 1898 |
+
choices=["regression", "classification"],
|
| 1899 |
+
value="regression"
|
| 1900 |
+
)
|
| 1901 |
+
outcome_variable = gr.Dropdown(
|
| 1902 |
+
label="Outcome Variable",
|
| 1903 |
+
choices=["wab_aq", "age", "mpo", "education"],
|
| 1904 |
+
value="wab_aq"
|
| 1905 |
+
)
|
| 1906 |
+
|
| 1907 |
+
with gr.Column(scale=1):
|
| 1908 |
+
rf_n_estimators = gr.Slider(
|
| 1909 |
+
minimum=10, maximum=500, step=10,
|
| 1910 |
+
label="Number of Trees", value=100
|
| 1911 |
+
)
|
| 1912 |
+
rf_max_depth = gr.Slider(
|
| 1913 |
+
minimum=3, maximum=50, step=1,
|
| 1914 |
+
label="Max Tree Depth", value=10,
|
| 1915 |
+
info="Set to 0 for unlimited depth"
|
| 1916 |
+
)
|
| 1917 |
+
rf_cv_folds = gr.Slider(
|
| 1918 |
+
minimum=2, maximum=10, step=1,
|
| 1919 |
+
label="Cross-validation Folds", value=5
|
| 1920 |
+
)
|
| 1921 |
+
|
| 1922 |
+
train_rf_btn = gr.Button("Train Random Forest Model", variant="primary")
|
| 1923 |
+
|
| 1924 |
+
gr.Markdown("### Random Forest Results")
|
| 1925 |
|
| 1926 |
with gr.Row():
|
| 1927 |
with gr.Column(scale=1):
|
|
|
|
| 1930 |
prediction_plot = gr.Plot(label="Prediction Performance")
|
| 1931 |
|
| 1932 |
with gr.Row():
|
| 1933 |
+
rf_metrics = gr.Textbox(label="Model Performance Metrics")
|
| 1934 |
+
|
| 1935 |
+
gr.Markdown("After Random Forest training completes, proceed to the 'Treatment Prediction' tab →")
|
| 1936 |
|
| 1937 |
+
# Tab 3: Predict Treatment
|
| 1938 |
+
with gr.Tab("3. Treatment Prediction"):
|
| 1939 |
+
gr.Markdown("### Predict Individual Treatment Outcomes")
|
| 1940 |
+
gr.Markdown("After completing VAE and Random Forest training, you can predict treatment outcomes for individual patients:")
|
| 1941 |
+
|
| 1942 |
with gr.Row():
|
| 1943 |
with gr.Column(scale=1):
|
| 1944 |
+
fmri_file = gr.File(label="Patient fMRI Data (NIfTI file)")
|
| 1945 |
with gr.Column(scale=1):
|
| 1946 |
with gr.Group("Patient Demographics"):
|
| 1947 |
age = gr.Number(label="Age at Stroke", value=60)
|
|
|
|
| 1957 |
with gr.Row():
|
| 1958 |
trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
|
| 1959 |
|
| 1960 |
+
# Define various handler functions for the different tabs
|
| 1961 |
+
|
| 1962 |
+
# Store shared state between tabs
|
| 1963 |
+
app_state = {
|
| 1964 |
+
'vae': None,
|
| 1965 |
+
'latents': None,
|
| 1966 |
+
'demographics': None,
|
| 1967 |
+
'predictor': None,
|
| 1968 |
+
'vae_trained': False,
|
| 1969 |
+
'rf_trained': False
|
| 1970 |
}
|
| 1971 |
|
| 1972 |
+
# Tab 1: VAE Training Handler
|
| 1973 |
+
def handle_vae_training(data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
|
| 1974 |
+
skip_behavioral, use_synthetic_nifti, use_synthetic_fc):
|
| 1975 |
+
"""Train the VAE model and display FC visualization and learning curves"""
|
| 1976 |
+
# Store config values
|
|
|
|
|
|
|
| 1977 |
PREDICTION_CONFIG['skip_behavioral_data'] = skip_behavioral
|
| 1978 |
PREDICTION_CONFIG['use_synthetic_nifti'] = use_synthetic_nifti
|
| 1979 |
PREDICTION_CONFIG['use_synthetic_fc'] = use_synthetic_fc
|
|
|
|
| 1985 |
else:
|
| 1986 |
PREDICTION_CONFIG['local_nii_dir'] = None
|
| 1987 |
|
| 1988 |
+
# Log info
|
| 1989 |
+
logger.info(f"Training VAE model with data from: {data_dir}")
|
| 1990 |
+
logger.info(f"VAE parameters: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
|
|
|
|
| 1991 |
|
| 1992 |
+
# Create a subset of app.train_models functionality that just trains the VAE
|
| 1993 |
+
try:
|
| 1994 |
+
# Start by setting up data for the VAE
|
| 1995 |
+
from vae_model import DemoVAE
|
| 1996 |
+
from data_preprocessing import load_and_preprocess_data
|
| 1997 |
+
from main import run_analysis
|
| 1998 |
+
import numpy as np
|
| 1999 |
+
import os
|
| 2000 |
+
|
| 2001 |
+
# Prepare VAE training parameters
|
| 2002 |
+
MODEL_CONFIG.update({
|
| 2003 |
+
'latent_dim': latent_dim,
|
| 2004 |
+
'nepochs': nepochs,
|
| 2005 |
+
'bsize': bsize
|
| 2006 |
+
})
|
| 2007 |
+
|
| 2008 |
+
# First, find and preprocess data
|
| 2009 |
+
logger.info("Looking for data in directory and preprocessing...")
|
| 2010 |
+
|
| 2011 |
+
# This part is similar to app.train_models but only focuses on VAE
|
| 2012 |
+
if data_dir == "SreekarB/OSFData":
|
| 2013 |
+
# Use dataset, similar to existing code in app.train_models
|
| 2014 |
+
# For brevity, we'll call the full train_models function but only
|
| 2015 |
+
# extract the VAE-related results
|
| 2016 |
+
results = app.train_models(
|
| 2017 |
+
data_dir=data_dir,
|
| 2018 |
+
latent_dim=latent_dim,
|
| 2019 |
+
nepochs=nepochs,
|
| 2020 |
+
bsize=bsize
|
| 2021 |
+
)
|
| 2022 |
+
|
| 2023 |
+
# Store results in app_state for the next tabs
|
| 2024 |
+
app_state['vae'] = results.get('vae', None)
|
| 2025 |
+
app_state['latents'] = results.get('latents', None)
|
| 2026 |
+
app_state['demographics'] = results.get('demographics', None)
|
| 2027 |
+
app_state['vae_trained'] = True
|
| 2028 |
+
|
| 2029 |
+
# Return just the VAE visualizations
|
| 2030 |
+
return [
|
| 2031 |
+
results.get('vae', None), # FC matrix visualization
|
| 2032 |
+
results.get('learning', None) # VAE learning curves
|
| 2033 |
+
]
|
| 2034 |
+
else:
|
| 2035 |
+
# Local directory case
|
| 2036 |
+
results = app.train_models(
|
| 2037 |
+
data_dir=data_dir,
|
| 2038 |
+
latent_dim=latent_dim,
|
| 2039 |
+
nepochs=nepochs,
|
| 2040 |
+
bsize=bsize
|
| 2041 |
+
)
|
| 2042 |
+
|
| 2043 |
+
# Store results in app_state
|
| 2044 |
+
app_state['vae'] = results.get('vae', None)
|
| 2045 |
+
app_state['latents'] = results.get('latents', None)
|
| 2046 |
+
app_state['demographics'] = results.get('demographics', None)
|
| 2047 |
+
app_state['vae_trained'] = True
|
| 2048 |
+
|
| 2049 |
+
# Return just the VAE visualizations
|
| 2050 |
+
return [
|
| 2051 |
+
results.get('vae', None), # FC matrix visualization
|
| 2052 |
+
results.get('learning', None) # VAE learning curves
|
| 2053 |
+
]
|
| 2054 |
+
except Exception as e:
|
| 2055 |
+
logger.error(f"Error in VAE training: {str(e)}", exc_info=True)
|
| 2056 |
+
error_fig = plt.figure(figsize=(10, 6))
|
| 2057 |
+
plt.text(0.5, 0.5, f"Error: {str(e)}",
|
| 2058 |
+
horizontalalignment='center', verticalalignment='center',
|
| 2059 |
+
fontsize=12, color='red', wrap=True)
|
| 2060 |
+
plt.axis('off')
|
| 2061 |
+
|
| 2062 |
+
# Return error figures for both outputs
|
| 2063 |
+
return [error_fig, error_fig]
|
| 2064 |
+
|
| 2065 |
+
# Tab 2: Random Forest Training Handler
|
| 2066 |
+
def handle_rf_training(outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
    """Train the Random Forest model using the VAE latent representations.

    Reads the latents and demographics that the VAE tab stored in
    ``app_state``, loads the ``outcome_score`` column from the CSV at
    ``app.last_treatment_file``, cross-validates and fits an
    ``AphasiaTreatmentPredictor``, and stores the fitted predictor back in
    ``app_state``.

    Args:
        outcome_variable: Outcome name; written to
            ``PREDICTION_CONFIG['default_outcome']``.
        rf_n_estimators: Number of trees passed to the predictor.
        rf_max_depth: Maximum tree depth; values ``<= 0`` are mapped to
            ``None`` before being passed on.
        rf_cv_folds: Number of cross-validation splits.

    Returns:
        A three-item list ``[importance_figure, performance_figure,
        metrics_text]``. On any failure both figure slots hold the same
        error figure and the text slot holds the error message.
    """

    def _error_figure(message, fontsize=14, wrap=False):
        """Build a blank figure showing *message* centred in red."""
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, message,
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=fontsize, color='red', wrap=wrap)
        plt.axis('off')
        return fig

    # Check if VAE has been trained
    if not app_state['vae_trained'] or app_state['latents'] is None:
        error_fig = _error_figure("Error: You must train the VAE model in Tab 1 first!")
        return [error_fig, error_fig, "Error: VAE not trained. Go to Tab 1 and train the VAE first."]

    try:
        max_depth = rf_max_depth if rf_max_depth > 0 else None

        # Update RF configuration
        PREDICTION_CONFIG['default_outcome'] = outcome_variable
        PREDICTION_CONFIG['n_estimators'] = rf_n_estimators
        PREDICTION_CONFIG['max_depth'] = max_depth
        PREDICTION_CONFIG['cv_folds'] = rf_cv_folds

        logger.info(f"Training Random Forest model: outcome={outcome_variable}")
        logger.info(f"RF parameters: n_estimators={rf_n_estimators}, max_depth={rf_max_depth}, cv_folds={rf_cv_folds}")

        # Get data from app_state
        latents = app_state['latents']
        demographics = app_state['demographics']

        # Train Random Forest predictor
        from rcf_prediction import AphasiaTreatmentPredictor
        import pandas as pd
        import numpy as np

        # The attribute starts out as None, and os.path.exists(None) raises
        # TypeError, so test for a usable path before touching the filesystem.
        treatment_file = getattr(app, 'last_treatment_file', None)
        if not treatment_file or not os.path.exists(treatment_file):
            # No treatment file available
            error_fig = _error_figure(
                "Error: Treatment outcomes file not found. Please retrain the VAE in Tab 1."
            )
            return [error_fig, error_fig, "Error: Treatment outcomes file not found."]

        treatment_df = pd.read_csv(treatment_file)
        treatment_outcomes = treatment_df['outcome_score'].values

        # Initialize predictor
        predictor = AphasiaTreatmentPredictor(
            n_estimators=rf_n_estimators,
            max_depth=max_depth
        )

        # Cross-validate
        cv_results = predictor.cross_validate(
            latents=latents,
            demographics=demographics,
            treatment_outcomes=treatment_outcomes,
            n_splits=rf_cv_folds
        )

        # Fit final model
        predictor.fit(latents, demographics, treatment_outcomes)

        # Store in app_state
        app_state['predictor'] = predictor
        app_state['rf_trained'] = True

        # Create feature importance plot
        importance_fig = predictor.plot_feature_importance()

        # Create prediction performance plot (created after the importance
        # plot so that it is pyplot's current figure for the calls below)
        predictions = cv_results['predictions']
        prediction_stds = cv_results['prediction_stds']

        performance_fig = plt.figure(figsize=(8, 6))

        # Only create the scatter plot if we have matching data
        if len(treatment_outcomes) > 0 and len(predictions) == len(treatment_outcomes):
            actual = np.asarray(treatment_outcomes)
            predicted = np.asarray(predictions)
            half_band = 2 * np.asarray(prediction_stds)

            plt.scatter(actual, predicted)

            # Reference line
            min_val = min(np.min(actual), np.min(predicted))
            max_val = max(np.max(actual), np.max(predicted))
            plt.plot([min_val, max_val], [min_val, max_val], 'r--')

            # Confidence band: fill_between joins points in the order given,
            # so sort by the x values to avoid a self-overlapping polygon.
            order = np.argsort(actual)
            plt.fill_between(actual[order],
                             (predicted - half_band)[order],
                             (predicted + half_band)[order],
                             alpha=0.2, color='gray')

            plt.xlabel('Actual Outcome')
            plt.ylabel('Predicted Outcome')

            # Get performance metrics
            mean_metrics = cv_results.get('mean_metrics', {})
            r2 = mean_metrics.get('r2', 0)
            rmse = mean_metrics.get('rmse', 0)
            plt.title(f'Treatment Outcome Prediction\nR² = {r2:.3f}, RMSE = {rmse:.3f}')
            metrics_text = f"Regression Model Performance:\nR² = {r2:.4f}\nRMSE = {rmse:.4f}"
        else:
            # Handle case with no data
            plt.text(0.5, 0.5, "No prediction data available",
                     ha='center', va='center', transform=plt.gca().transAxes)
            metrics_text = "No performance metrics available"

        plt.tight_layout()

        return [importance_fig, performance_fig, metrics_text]

    except Exception as e:
        # UI boundary: report any failure in the outputs instead of crashing
        # the Gradio callback.
        logger.error(f"Error in RF training: {str(e)}", exc_info=True)
        error_fig = _error_figure(f"Error: {str(e)}", fontsize=12, wrap=True)
        return [error_fig, error_fig, f"Error: {str(e)}"]
|
| 2193 |
+
|
| 2194 |
+
# Connect the tab handlers

# VAE Training tab
train_vae_btn.click(
    fn=handle_vae_training,
    inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
            skip_behavioral, use_synthetic_nifti, use_synthetic_fc],
    outputs=[fc_plot, learning_plot]
)

# Random Forest Training tab
# Gradio passes one positional argument per entry in `inputs`, so this list
# must line up with handle_rf_training(outcome_variable, rf_n_estimators,
# rf_max_depth, rf_cv_folds). `prediction_type` is not wired in: the handler
# has no such parameter, and passing it made every click fail with a
# TypeError for an unexpected extra argument.
train_rf_btn.click(
    fn=handle_rf_training,
    inputs=[outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds],
    outputs=[importance_plot, prediction_plot, rf_metrics]
)
|
| 2210 |
|
| 2211 |
+
# Tab 3: Treatment Prediction Handler
|
| 2212 |
+
def handle_treatment_prediction(fmri_file, age, sex, months, wab):
    """Predict the treatment outcome for one new patient.

    Returns a two-element list ``[text, figure]``. When either model has
    not been trained yet, the text is an error message and the figure
    shows that same message in red. On any other failure the text carries
    the error and the figure slot is ``None``.
    """
    # Both training stages must have completed before predicting.
    if not (app_state['vae_trained'] and app_state['rf_trained']):
        error_message = "Error: You must train both the VAE (Tab 1) and Random Forest (Tab 2) models first!"
        error_fig = plt.figure(figsize=(10, 6))
        axes = plt.gca()
        axes.text(0.5, 0.5, error_message,
                  horizontalalignment='center', verticalalignment='center',
                  fontsize=14, color='red')
        axes.axis('off')
        return [error_message, error_fig]

    try:
        trained_vae = app_state['vae']
        if trained_vae is None or app_state['predictor'] is None:
            return ["Error: Models not properly trained", None]

        # Wrap the already-trained models in a throwaway app object so the
        # regular predict_treatment() code path can be reused.
        runner = AphasiaPredictionApp()
        runner.vae = trained_vae
        runner.predictor = app_state['predictor']
        runner.trained = True
        # Fall back to 32 when the VAE object exposes no latent_dim.
        runner.latent_dim = getattr(app_state['vae'], 'latent_dim', 32)

        return runner.predict_treatment(
            fmri_file=fmri_file,
            age=age,
            sex=sex,
            months_post_stroke=months,
            wab_score=wab
        )
    except Exception as e:
        logger.error(f"Error in treatment prediction: {str(e)}", exc_info=True)
        return [f"Error in prediction: {str(e)}", None]
|
| 2249 |
+
|
| 2250 |
+
# Connect the treatment prediction handler
|
| 2251 |
predict_btn.click(
|
| 2252 |
+
fn=handle_treatment_prediction,
|
| 2253 |
inputs=[fmri_file, age, sex, months, wab],
|
| 2254 |
outputs=[prediction_text, trajectory_plot]
|
| 2255 |
)
|
config.py
CHANGED
|
@@ -27,7 +27,6 @@ PREDICTION_CONFIG = {
|
|
| 27 |
'n_estimators': 100,
|
| 28 |
'max_depth': None,
|
| 29 |
'cv_folds': 5,
|
| 30 |
-
'prediction_type': 'regression',
|
| 31 |
'default_outcome': 'wab_aq',
|
| 32 |
'save_path': 'results/treatment_predictor.joblib',
|
| 33 |
'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
|
|
|
|
| 27 |
'n_estimators': 100,
|
| 28 |
'max_depth': None,
|
| 29 |
'cv_folds': 5,
|
|
|
|
| 30 |
'default_outcome': 'wab_aq',
|
| 31 |
'save_path': 'results/treatment_predictor.joblib',
|
| 32 |
'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
|
demo_fc_visualization.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demo script to visualize FC matrices from real fMRI data using nilearn's built-in datasets.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from nilearn import datasets
|
| 8 |
+
from nilearn import input_data, connectome
|
| 9 |
+
from fc_visualization import FCVisualizer
|
| 10 |
+
|
| 11 |
+
def visualize_from_nilearn_dataset():
    """Fetch one ADHD subject via nilearn, build its FC matrix, save and show it.

    Side effects: writes ``adhd_fc_matrix.npy`` and ``adhd_fc_matrix.png``
    to the working directory and opens a matplotlib window.
    """
    print("Downloading a sample fMRI dataset (ADHD)...")
    adhd_dataset = datasets.fetch_adhd(n_subjects=1)

    # First (only) subject: functional image plus its confounds file.
    func_file = adhd_dataset.func[0]
    confound_file = adhd_dataset.confounds[0]
    print(f"Downloaded fMRI file: {func_file}")

    # ROI centre coordinates of the Power atlas, one row per ROI.
    rois = datasets.fetch_coords_power_2011().rois
    coords = np.column_stack((rois['x'], rois['y'], rois['z']))
    print(f"Using Power atlas with {len(coords)} ROIs")

    # Sphere masker that extracts one time series per ROI.
    masker_options = dict(
        radius=8,  # 8mm radius
        standardize=True,
        memory='nilearn_cache',
        memory_level=1,
        verbose=1,
        detrend=True,
        low_pass=0.08,
        high_pass=0.01,
        t_r=2.0,  # ADHD dataset TR
    )
    masker = input_data.NiftiSpheresMasker(coords, **masker_options)

    print("Extracting time series from ROIs...")
    time_series = masker.fit_transform(func_file, confounds=confound_file)
    print(f"Time series shape: {time_series.shape}")

    # Full (non-vectorised) correlation matrix, diagonal kept.
    correlation_measure = connectome.ConnectivityMeasure(
        kind='correlation',
        vectorize=False,
        discard_diagonal=False
    )
    fc_matrix = correlation_measure.fit_transform([time_series])[0]
    print(f"FC matrix shape: {fc_matrix.shape}")

    np.save('adhd_fc_matrix.npy', fc_matrix)
    print("Saved FC matrix to adhd_fc_matrix.npy")

    plotter = FCVisualizer(cmap='RdBu_r', vmin=-1, vmax=1)
    fig, _ = plotter.plot_single_matrix(fc_matrix, title="ADHD FC Matrix (Power Atlas)")

    fig.savefig('adhd_fc_matrix.png', dpi=300, bbox_inches='tight')
    print("Saved visualization to adhd_fc_matrix.png")

    plt.show()
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
visualize_from_nilearn_dataset()
|
demovae/model.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from sklearn.linear_model import Ridge
|
| 10 |
+
from sklearn.linear_model import LogisticRegression
|
| 11 |
+
|
| 12 |
+
def to_torch(x):
    """Wrap a NumPy array as a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.to(torch.float32)
|
| 14 |
+
|
| 15 |
+
def to_cuda(x, use_cuda):
    """Move ``x`` to the GPU when ``use_cuda`` is truthy, else return it unchanged.

    If ``x.cuda()`` raises ``RuntimeError`` or ``AssertionError`` a warning
    is printed and the original object is returned instead.
    """
    if not use_cuda:
        return x
    try:
        moved = x.cuda()
    except (RuntimeError, AssertionError) as e:
        print(f"Warning: CUDA error: {e}. Falling back to CPU.")
        return x
    return moved
|
| 24 |
+
|
| 25 |
+
def to_numpy(x):
    """Return the tensor's values as a NumPy array (detached, on the CPU)."""
    detached = x.detach()
    return detached.cpu().numpy()
|
| 27 |
+
|
| 28 |
+
class VAE(nn.Module):
    """Two-layer encoder / two-layer decoder conditioned on demographics.

    The encoder maps ``input_dim`` features through a 1000-unit ReLU layer
    to ``latent_dim`` values and returns them directly (no sampling step
    is performed here). The decoder receives the latent vector
    concatenated with ``demo_dim`` demographic features.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        hidden_units = 1000
        # Layers are created in encoder-then-decoder order.
        self.enc1 = to_cuda(nn.Linear(input_dim, hidden_units).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(hidden_units, latent_dim).float(), use_cuda)
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, hidden_units).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(hidden_units, input_dim).float(), use_cuda)

    def enc(self, x):
        """Encode a batch of inputs into latent vectors."""
        hidden = F.relu(self.enc1(x))
        return self.enc2(hidden)

    def gen(self, n):
        """Draw ``n`` latent vectors from a standard normal distribution."""
        noise = torch.randn(n, self.latent_dim).float()
        return to_cuda(noise, self.use_cuda)

    def dec(self, z, demo):
        """Decode latents ``z`` together with demographic features ``demo``."""
        conditioned = torch.cat([z, demo], dim=1)
        conditioned = to_cuda(conditioned, self.use_cuda)
        hidden = F.relu(self.dec1(conditioned))
        return self.dec2(hidden)
|
| 57 |
+
|
| 58 |
+
def rmse(a, b, mean=torch.mean):
    """Root of the (by default mean) squared difference between ``a`` and ``b``.

    ``mean`` is the reduction applied to the squared errors before the
    square root is taken.
    """
    squared_error = (a - b) ** 2
    return mean(squared_error) ** 0.5
|
| 60 |
+
|
| 61 |
+
def latent_loss(z, use_cuda=True):
    """Penalise latents whose second moment and mean drift from a unit Gaussian.

    Returns ``(loss_C, loss_mu, C, mu)`` where ``C = z.T @ z`` is compared
    against ``len(z) * I`` and ``mu`` (the per-dimension batch mean) is
    compared against zero, both via ``rmse``.
    """
    n_samples = len(z)
    n_latent = z.shape[-1]
    C = z.T @ z
    mu = torch.mean(z, dim=0)
    target_C = to_cuda(torch.eye(n_latent).float(), use_cuda) * n_samples
    target_mu = to_cuda(torch.zeros(n_latent).float(), use_cuda)
    return rmse(C, target_C), rmse(mu, target_mu), C, mu
|
| 69 |
+
|
| 70 |
+
def decor_loss(z, demo, use_cuda=True):
    """Measure how strongly each demographic column projects onto the latents.

    For every column of ``demo`` the centred column is projected onto each
    latent dimension, scaled by the column's standard deviation and by the
    latent's squared norm, and compared with zero via ``rmse``.

    Returns ``(losses, ps)``: a stacked tensor with one loss per
    demographic column, and the list of per-column projection vectors.

    NOTE(review): a column that is constant within the batch (or a batch
    of one sample) has zero / undefined standard deviation, so the division
    below would not yield a finite value - confirm callers avoid this.
    """
    projections = []
    column_losses = []
    for column in demo.unbind(dim=1):
        centred = column - torch.mean(column)
        proj = torch.einsum('n,nz->z', centred, z)
        proj = proj / torch.std(centred)
        proj = proj / torch.einsum('nz,nz->z', z, z)
        zero_target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
        column_losses.append(rmse(proj, zero_target))
        projections.append(proj)
    return torch.stack(column_losses), projections
|
| 85 |
+
|
| 86 |
+
def pretty(x):
    """Format a scalar as text, rounded to four decimal places."""
    rounded = round(float(x), 4)
    return str(rounded)
|
| 88 |
+
|
| 89 |
+
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Build the (samples x features) demographic tensor fed to the decoder.

    Continuous demographics contribute one column each. Categorical
    demographics contribute one 0/1 indicator column per category listed
    in the matching ``pred_stats`` entry, in that order. Entries with any
    other type string contribute nothing.

    Raises:
        Exception: a categorical value is absent from its ``pred_stats`` entry.
    """
    columns = []
    for demo_idx, (values, kind, stats) in enumerate(zip(demo, demo_types, pred_stats)):
        if kind == 'continuous':
            columns.append(to_cuda(to_torch(values), use_cuda))
        elif kind == 'categorical':
            # Reject categories the model never saw during training.
            for value in values:
                if value not in stats:
                    print(f'Model not trained with value {value} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            for category in stats:
                member = (values == category).astype('bool')
                indicator = torch.zeros(len(values))
                indicator[member] = 1
                columns.append(to_cuda(indicator, use_cuda))
    # Stacked as (features, samples); swap to (samples, features).
    return torch.stack(columns).permute(1, 0)
|
| 108 |
+
|
| 109 |
+
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train ``vae`` on ``x`` with demographic conditioning and guidance.

    Stage 1 fits one auxiliary linear model per demographic on the raw
    inputs (Ridge for 'continuous', LogisticRegression for 'categorical')
    and stores the per-demographic statistics on ``ret_obj.pred_stats``.
    Stage 2 runs Adam for ``nepochs`` epochs over mini-batches of ``bsize``.
    Each step combines reconstruction, latent covariance/mean,
    latent-demographic decorrelation, and a guidance loss computed on 100
    generated samples scored by the auxiliary models, weighted by the
    respective ``loss_*_mult`` factors. Progress is printed every
    ``pperiod`` epochs and on the final epoch.

    Raises:
        Exception: a demographic type is neither 'continuous' nor 'categorical'.
    """
    n_generated = 100

    # ---- Stage 1: auxiliary guidance models -------------------------------
    pred_w = []
    pred_i = []
    # mean/std for continuous demographics, sorted category list for categorical
    pred_stats = []
    for demo_no, (d, t) in enumerate(zip(demo, demo_types)):
        print(f'Fitting auxilliary guidance model for demographic {demo_no} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            pred_w.append(to_cuda(to_torch(reg.coef_), vae.use_cuda))
            pred_i.append(reg.intercept_)
        elif t == 'categorical':
            pred_stats.append(sorted(set(d)))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary case: a single weight vector; its negation serves
                # as the predictor for the other class.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.extend([-reg_w, reg_w])
                pred_i.extend([-reg_i, reg_i])
            else:
                for class_coef, class_intercept in zip(reg.coef_, reg.intercept_):
                    pred_w.append(to_cuda(to_torch(class_coef), vae.use_cuda))
                    pred_i.append(class_intercept)
        else:
            print(f'demographic type "{t}" not "continuous" or "categorical"')
            raise Exception('Bad demographic type')
        print(' done')
    ret_obj.pred_stats = pred_stats

    # ---- Stage 2: tensors, optimiser, training loop -----------------------
    print('Converting input to pytorch')
    x = to_cuda(to_torch(x), vae.use_cuda)
    print('Converting demographics to pytorch')
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    print('Beginning VAE training')
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        for start in range(0, len(x), bsize):
            stop = start + bsize
            xb = x[start:stop]
            db = demo_t[start:stop]
            optim.zero_grad()

            # Reconstruction pass on real data.
            z_real = vae.enc(xb)
            x_rec = vae.dec(z_real, db)
            loss_C, loss_mu, _, _ = latent_loss(z_real, vae.use_cuda)
            loss_decor, _ = decor_loss(z_real, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, x_rec)

            # Sample a demographic profile for the generated batch.
            demo_gen = []
            for s, t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu, std = s[0], s[1]
                    sampled = torch.randn(n_generated).float()
                    sampled = sampled * std + mu
                    demo_gen.append(to_cuda(sampled, vae.use_cuda))
                elif t == 'categorical':
                    # One random category is shared by all generated samples.
                    chosen = random.randint(0, len(s) - 1)
                    for cat_no in range(len(s)):
                        if cat_no == chosen:
                            indicator = torch.ones(n_generated).float()
                        else:
                            indicator = torch.zeros(n_generated).float()
                        demo_gen.append(to_cuda(indicator, vae.use_cuda))
            demo_gen = torch.stack(demo_gen).permute(1, 0)

            # Generation pass.
            z_gen = vae.gen(n_generated)
            x_gen = vae.dec(z_gen, demo_gen)

            # Guidance: the auxiliary models should recover the requested
            # demographics from the generated data.
            losses_pred = []
            idcs = []
            dg_idx = 0
            for s, t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = x_gen @ pred_w[dg_idx] + pred_i[dg_idx]
                    losses_pred.append(rmse(demo_gen[:, dg_idx], yy))
                    idcs.append(float(demo_gen[0, dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    loss = 0
                    for _ in range(len(s)):
                        yy = x_gen @ pred_w[dg_idx] + pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:, dg_idx].long())
                        idcs.append(int(demo_gen[0, dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)

            total_loss = (loss_C_mult * loss_C
                          + loss_mu_mult * loss_mu
                          + loss_rec_mult * loss_rec
                          + loss_decor_mult * loss_decor
                          + loss_pred_mult * sum(losses_pred))
            total_loss.backward()
            optim.step()

        # Report the values of the last batch of this epoch.
        if e % pperiod == 0 or e == nepochs - 1:
            losses_pred = [pretty(loss) for loss in losses_pred]
            report = (f'Epoch {e} '
                      f'ReconLoss {pretty(loss_rec)} '
                      f'CovarianceLoss {pretty(loss_C)} '
                      f'MeanLoss {pretty(loss_mu)} '
                      f'DecorLoss {pretty(loss_decor)} '
                      f'GuidanceTargets {idcs} GuidanceLosses {losses_pred} ')
            print(report)
    print('Training complete.')
|
| 224 |
+
|
| 225 |
+
|
demovae/sklearn.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from .model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from sklearn.base import BaseEstimator
|
| 6 |
+
|
| 7 |
+
# For saving
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
class DemoVAE(BaseEstimator):
    """scikit-learn style wrapper around the demographic-conditioned ``VAE``.

    Hyper-parameters are the keys of ``get_default_params()`` and are
    stored as instance attributes. ``fit`` additionally sets ``input_dim``,
    ``demo_dim``, ``vae`` and (through ``train_vae``) ``pred_stats``.

    NOTE(review): ``set_params`` resets every hyper-parameter that is not
    passed to its default value, unlike the usual scikit-learn behaviour of
    leaving unspecified parameters untouched. ``__init__`` and ``load``
    rely on that, so it is kept as is.
    """

    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Return the default value of every supported hyper-parameter."""
        return dict(latent_dim=60,          # Latent dimension
                    use_cuda=True,          # GPU acceleration
                    nepochs=3000,           # Training epochs
                    pperiod=100,            # Epochs between printing updates
                    bsize=1000,             # Batch size
                    loss_C_mult=1,          # Covariance loss (KL div)
                    loss_mu_mult=1,         # Mean loss (KL div)
                    loss_rec_mult=100,      # Reconstruction loss
                    loss_decor_mult=10,     # Latent-demographic decorrelation loss
                    loss_pred_mult=0.001,   # Classifier/regressor guidance loss
                    alpha=100,              # Regularization for continuous guidance models
                    LR_C=100,               # Regularization for categorical guidance models
                    lr=1e-4,                # Learning rate
                    weight_decay=0,         # L2 regularization for VAE model
                    )

    def get_params(self, **params):
        """Return the current hyper-parameters as a dict.

        The key set (and order) is taken from ``get_default_params`` so the
        two can never drift apart. Extra keyword arguments are accepted and
        ignored.
        """
        return {key: getattr(self, key) for key in DemoVAE.get_default_params()}

    def set_params(self, **params):
        """Set hyper-parameters; any key not supplied is reset to its default."""
        for key, default in DemoVAE.get_default_params().items():
            setattr(self, key, params.get(key, default))
        return self

    def fit(self, x, demo, demo_types, **kwargs):
        """Create and train the VAE.

        Args:
            x: 2D array of samples by features.
            demo: sequence with one array of values per demographic.
            demo_types: matching sequence of 'continuous' / 'categorical'.

        Raises:
            Exception: a categorical demographic has a single category, or
                a type string is not recognised.
        """
        # Decoder conditioning width: one column per continuous variable,
        # one column per category of each categorical variable.
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                n_categories = len(set(d))
                if n_categories == 1:
                    print('Only one type of category for categorical variable')
                    raise Exception('Bad categorical')
                demo_dim += n_categories
            else:
                print(f'demographic type "{t}" not "continuous" or "categorical"')
                raise Exception('Bad demographic type')
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(x.shape[1], self.latent_dim, demo_dim, self.use_cuda)
        # train_vae stores pred_stats on the object passed as its last argument.
        train_vae(self.vae, x, demo, demo_types,
                  self.nepochs, self.pperiod, self.bsize,
                  self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
                  self.loss_decor_mult, self.loss_pred_mult,
                  self.lr, self.weight_decay, self.alpha, self.LR_C,
                  self)
        return self

    def transform(self, x, demo, demo_types, **kwargs):
        """Decode data conditioned on ``demo``.

        If ``x`` is an integer (Python or NumPy), that many latent vectors
        are sampled and decoded; otherwise ``x`` is encoded first and then
        decoded. Returns a NumPy array.
        """
        if isinstance(x, (int, np.integer)):
            # Generate: x is a sample count, not data.
            z = self.vae.gen(int(x))
        else:
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def fit_transform(self, x, demo, demo_types, **kwargs):
        """Fit on ``x`` and return its reconstruction."""
        self.fit(x, demo, demo_types)
        return self.transform(x, demo, demo_types)

    def get_latents(self, x):
        """Return the encoder output for ``x`` as a NumPy array."""
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Write hyper-parameters, demographic statistics and weights to ``path``."""
        dct = dict(pred_stats=self.pred_stats,
                   params=self.get_params(),
                   input_dim=self.input_dim,
                   demo_dim=self.demo_dim,
                   model_state_dict=self.vae.state_dict())
        torch.save(dct, path)

    def load(self, path):
        """Restore a model written by ``save`` and return ``self``.

        NOTE(review): the checkpoint holds non-tensor objects
        (``pred_stats``); confirm whether the installed torch version needs
        ``weights_only=False`` to read it.
        """
        # map_location='cpu' lets a checkpoint saved from a GPU run load on
        # a CPU-only machine; load_state_dict then copies the weights onto
        # whatever device the freshly built VAE lives on.
        dct = torch.load(path, map_location='cpu')
        self.pred_stats = dct['pred_stats']
        self.set_params(**dct['params'])
        # Restore the dimensions too, so save() works on a loaded model.
        self.input_dim = dct['input_dim']
        self.demo_dim = dct['demo_dim']
        self.vae = VAE(dct['input_dim'],
                       dct['params']['latent_dim'],
                       dct['demo_dim'],
                       dct['params']['use_cuda'])
        self.vae.load_state_dict(dct['model_state_dict'])
        return self
|
| 123 |
+
|
| 124 |
+
|
fc_visualization.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FC Matrix Visualization Module.
|
| 3 |
+
|
| 4 |
+
This module provides functionality for visualizing Functional Connectivity matrices
|
| 5 |
+
independently from the prediction pipeline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import nibabel as nib
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from nilearn import input_data, connectome
|
| 17 |
+
from nilearn.image import load_img
|
| 18 |
+
from nilearn import datasets
|
| 19 |
+
NILEARN_AVAILABLE = True
|
| 20 |
+
except ImportError:
|
| 21 |
+
NILEARN_AVAILABLE = False
|
| 22 |
+
print("Warning: nilearn not available. Direct fMRI processing disabled.")
|
| 23 |
+
|
| 24 |
+
from config import PREPROCESS_CONFIG
|
| 25 |
+
|
| 26 |
+
class FCVisualizer:
    """Plot functional-connectivity (FC) matrices with one colormap and scale."""

    def __init__(self, cmap='RdBu_r', vmin=-1, vmax=1):
        """
        Initialize FCVisualizer with display parameters.

        Args:
            cmap: Colormap to use for FC matrices
            vmin: Minimum value for color scaling
            vmax: Maximum value for color scaling
        """
        self.cmap = cmap
        self.vmin = vmin
        self.vmax = vmax

    def plot_single_matrix(self, matrix, title="FC Matrix", ax=None, fig=None):
        """
        Plot a single FC matrix.

        Args:
            matrix: 2D numpy array containing FC matrix
            title: Title for the plot
            ax: Matplotlib axis to plot on (optional; a new 8x6 figure is
                created when omitted)
            fig: Matplotlib figure (optional; defaults to the figure that
                owns ``ax``)

        Returns:
            fig, ax: The figure and axis objects
        """
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 6))
        elif fig is None:
            # Caller supplied only the axes: report the figure it lives in
            # rather than returning None.
            fig = ax.figure

        im = ax.imshow(matrix, cmap=self.cmap, vmin=self.vmin, vmax=self.vmax)
        ax.set_title(title)
        # Attach the colorbar to the axes' own figure, which is not
        # necessarily pyplot's current figure.
        ax.figure.colorbar(im, ax=ax)

        return fig, ax

    def plot_matrix_comparison(self, matrices, titles=None, figsize=None):
        """
        Plot multiple FC matrices side by side for comparison.

        Args:
            matrices: List of 2D numpy arrays containing FC matrices
            titles: List of titles for each matrix (optional)
            figsize: Custom figure size (optional; 5 inches per matrix wide
                and 5 inches high by default)

        Returns:
            fig: The figure object

        Raises:
            ValueError: ``matrices`` is empty.
        """
        n_matrices = len(matrices)
        if n_matrices == 0:
            raise ValueError("At least one matrix is required for comparison")

        if figsize is None:
            figsize = (5 * n_matrices, 5)
        if titles is None:
            titles = [f"FC Matrix {i+1}" for i in range(n_matrices)]

        fig, axes = plt.subplots(1, n_matrices, figsize=figsize)
        # plt.subplots returns a bare Axes (not an array) for a single panel.
        if n_matrices == 1:
            axes = [axes]

        for ax, matrix, title in zip(axes, matrices, titles):
            im = ax.imshow(matrix, cmap=self.cmap, vmin=self.vmin, vmax=self.vmax)
            ax.set_title(title)
            fig.colorbar(im, ax=ax)

        fig.tight_layout()
        return fig

    def load_and_visualize_npy(self, file_path):
        """
        Load and visualize an FC matrix from a .npy file.

        A 1D array is treated as Fisher z-transformed upper-triangular
        values and expanded to a full matrix; any other array is plotted
        as stored.

        Args:
            file_path: Path to the .npy file

        Returns:
            fig: The figure object containing the visualization
        """
        data = np.load(file_path)
        matrix = self._triu_to_matrix(data) if data.ndim == 1 else data

        title = f"FC Matrix: {os.path.basename(file_path)}"
        fig, _ = self.plot_single_matrix(matrix, title=title)
        return fig

    def _triu_to_matrix(self, triu_values, fisher_z=True):
        """
        Convert upper triangular values to a full FC matrix.

        Args:
            triu_values: 1D array of strictly-upper-triangular values
            fisher_z: Whether values are Fisher z-transformed (they are
                mapped back with tanh when True)

        Returns:
            full_matrix: 2D symmetric matrix with ones on the diagonal

        Raises:
            ValueError: the number of values is not n*(n-1)/2 for any
                integer n.
        """
        n_values = len(triu_values)
        # Invert n_values = n * (n - 1) / 2 for n.
        n = int(np.sqrt(2 * n_values + 0.25) + 0.5)
        if n * (n - 1) // 2 != n_values:
            raise ValueError(
                f"Cannot build a square matrix from {n_values} upper-triangular values"
            )

        matrix = np.zeros((n, n))
        triu_indices = np.triu_indices_from(matrix, k=1)

        # Undo the Fisher z-transform if requested.
        matrix[triu_indices] = np.tanh(triu_values) if fisher_z else triu_values

        # Mirror onto the lower triangle (the diagonal is still zero here).
        matrix = matrix + matrix.T

        # Set diagonal to 1.0 (perfect correlation)
        np.fill_diagonal(matrix, 1.0)

        return matrix

    def process_and_visualize_fmri(self, fmri_file):
        """
        Process an fMRI file and visualize its FC matrix.

        Args:
            fmri_file: Path to the fMRI .nii or .nii.gz file

        Returns:
            fig: The figure object containing the visualization,
                 or None if nilearn is unavailable or processing fails
        """
        if not NILEARN_AVAILABLE:
            print("Error: nilearn is required for fMRI processing")
            return None

        try:
            fc_triu = self._process_single_fmri(fmri_file)
            fc_matrix = self._triu_to_matrix(fc_triu)

            title = f"FC Matrix: {os.path.basename(fmri_file)}"
            fig, _ = self.plot_single_matrix(fc_matrix, title=title)
            return fig

        except Exception as e:
            # Deliberately best-effort: report the problem and return None.
            print(f"Error processing fMRI file: {e}")
            return None

    def _process_single_fmri(self, fmri_file):
        """
        Process a single fMRI file to FC features.

        Args:
            fmri_file: Path to the fMRI .nii or .nii.gz file

        Returns:
            fc_triu: 1D array of upper triangular values (Fisher z-transformed)
        """
        print(f"Processing fMRI file: {fmri_file}")

        # ROI centre coordinates of the Power atlas, one row per ROI.
        power = datasets.fetch_coords_power_2011()
        coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T

        masker = input_data.NiftiSpheresMasker(
            coords,
            radius=PREPROCESS_CONFIG['radius'],
            standardize=True,
            memory='nilearn_cache',
            memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=PREPROCESS_CONFIG['low_pass'],
            high_pass=PREPROCESS_CONFIG['high_pass'],
            t_r=PREPROCESS_CONFIG['t_r']
        )

        print("Loading NIfTI file...")
        fmri_img = load_img(fmri_file)
        print(f"NIfTI file loaded, shape: {fmri_img.shape}")

        print("Extracting time series...")
        time_series = masker.fit_transform(fmri_img)
        print(f"Time series extracted, shape: {time_series.shape}")

        print("Computing FC matrix...")
        correlation_measure = connectome.ConnectivityMeasure(
            kind='correlation',
            vectorize=False,
            discard_diagonal=False
        )
        fc_matrix = correlation_measure.fit_transform([time_series])[0]
        print(f"FC matrix computed, shape: {fc_matrix.shape}")

        # Keep only the strictly upper triangle.
        triu_indices = np.triu_indices_from(fc_matrix, k=1)
        fc_triu = fc_matrix[triu_indices]

        # Fisher z-transform; clip first so arctanh stays finite.
        fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99))

        print(f"Processing complete. FC features shape: {fc_triu.shape}")
        return fc_triu
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def create_synthetic_fc_matrix(seed=None, n_rois=264, n_features=50):
    """
    Create a synthetic FC matrix for demonstration purposes.

    The matrix is the Pearson correlation matrix of ``n_rois`` random
    Gaussian vectors, so it is symmetric with ones on the diagonal.

    Args:
        seed: Random seed for reproducibility. When given, a private random
            generator is used so the process-wide NumPy random state is left
            untouched. When None, values are drawn from the global NumPy
            random state.
        n_rois: Number of regions of interest (rows/columns of the matrix).
            Defaults to 264.
        n_features: Number of random features per ROI used to build the
            correlations. Defaults to 50.

    Returns:
        matrix: 2D symmetric matrix of shape (n_rois, n_rois) representing FC

    Raises:
        ValueError: If ``n_rois`` < 1 or ``n_features`` < 2 (a correlation
            needs at least two observations per ROI).
    """
    if n_rois < 1:
        raise ValueError("n_rois must be at least 1")
    if n_features < 2:
        raise ValueError("n_features must be at least 2")

    # RandomState(seed) yields the same stream as np.random.seed(seed) followed
    # by np.random.randn, but without reseeding the global generator that other
    # code may be relying on.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Create random correlation matrix:
    # generate random normal vectors (one per ROI) and correlate them.
    random_vectors = rng.randn(n_rois, n_features)
    # atleast_2d keeps the single-ROI case a (1, 1) matrix instead of a scalar.
    matrix = np.atleast_2d(np.corrcoef(random_vectors))

    # Force exact 1s on the diagonal (guards against floating point drift).
    np.fill_diagonal(matrix, 1.0)

    return matrix
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def main():
    """
    Command-line interface for FC matrix visualization.

    Builds one figure from, in order of precedence: a synthetic matrix
    (``--synthetic``), an input file (``--input``; a ``.npy`` FC matrix or a
    ``.nii``/``.nii.gz`` fMRI volume), or a demo synthetic matrix when neither
    is given. The figure is saved to ``--output`` when provided, otherwise it
    is shown interactively.

    Returns:
        None. Errors are reported on stdout and end the function early.
    """
    parser = argparse.ArgumentParser(description='Visualize FC matrices')
    parser.add_argument('--input', type=str, help='Input file (fMRI .nii/.nii.gz or .npy FC matrix)')
    parser.add_argument('--output', type=str, help='Output image file (PNG/JPG/PDF)')
    parser.add_argument('--cmap', type=str, default='RdBu_r', help='Colormap (default: RdBu_r)')
    parser.add_argument('--vmin', type=float, default=-1, help='Minimum value for colormap')
    parser.add_argument('--vmax', type=float, default=1, help='Maximum value for colormap')
    parser.add_argument('--synthetic', action='store_true', help='Generate a synthetic FC matrix')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for synthetic data')

    args = parser.parse_args()

    # Create visualizer
    visualizer = FCVisualizer(cmap=args.cmap, vmin=args.vmin, vmax=args.vmax)

    # Determine figure to create
    fig = None

    if args.synthetic:
        # Create synthetic FC matrix
        matrix = create_synthetic_fc_matrix(seed=args.seed)
        fig, _ = visualizer.plot_single_matrix(matrix, title="Synthetic FC Matrix")

    elif args.input:
        input_path = Path(args.input)

        if not input_path.exists():
            print(f"Error: Input file not found: {args.input}")
            return

        # Match on the full file name: Path.suffix is only the last component,
        # so comparing it with '.gz' would accept any gzip file (e.g. .tar.gz)
        # as fMRI data. Lower-casing also accepts upper-case extensions.
        file_name = input_path.name.lower()

        if file_name.endswith('.npy'):
            # It's a numpy file with FC matrix
            fig = visualizer.load_and_visualize_npy(input_path)

        elif file_name.endswith(('.nii', '.nii.gz')):
            # It's an fMRI file
            if not NILEARN_AVAILABLE:
                print("Error: nilearn is required for processing fMRI files")
                return
            fig = visualizer.process_and_visualize_fmri(input_path)

        else:
            print(f"Error: Unsupported file format: {input_path.suffix}")
            print("Supported formats: .npy (FC matrix), .nii/.nii.gz (fMRI)")
            return

    else:
        # No input or synthetic flag - show demo
        print("No input file or --synthetic flag provided. Generating a demo matrix.")
        matrix = create_synthetic_fc_matrix(seed=args.seed)
        fig, _ = visualizer.plot_single_matrix(matrix, title="Demo FC Matrix")

    # Save or display the figure
    if fig is None:
        print("Error: Failed to create visualization")
        return

    if args.output:
        fig.savefig(args.output, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {args.output}")
    else:
        # plt.show() blocks until the window is closed on interactive
        # backends, so the hint has to be printed before it to be of any use.
        print("Visualization displayed. Close the window to exit.")
        plt.show()


if __name__ == "__main__":
    main()
|
huggingface_fc_visualization.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to visualize FC matrices from HuggingFace dataset, comparing original FC to VAE-generated FC.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from fc_visualization import FCVisualizer
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import tempfile
|
| 12 |
+
import requests
|
| 13 |
+
from config import DATASET_CONFIG, PREPROCESS_CONFIG, MODEL_CONFIG
|
| 14 |
+
from data_preprocessing import process_single_fmri
|
| 15 |
+
from vae_model import VariationalAutoencoder
|
| 16 |
+
|
| 17 |
+
def _download_file(url, destination, timeout):
    """Stream ``url`` into ``destination``; remove the partial file on failure."""
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail on 4xx/5xx instead of saving an HTML error page as NIfTI.
            response.raise_for_status()
            with open(destination, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    except Exception:
        # Do not leave a truncated file behind for later processing steps.
        if os.path.exists(destination):
            os.remove(destination)
        raise


def download_sample_fmri(dataset, temp_dir, max_samples=5, probe_samples=5, timeout=60):
    """
    Download sample fMRI files from HuggingFace dataset.

    The first ``probe_samples`` samples are scanned for string fields ending
    in ``.nii`` or ``.nii.gz``; those field names are then used to download
    files for the first ``max_samples`` samples. A failed download is reported
    and skipped, it does not abort the remaining downloads.

    Args:
        dataset: HuggingFace dataset object (iterable and indexable; each
            sample is a mapping from field name to value)
        temp_dir: Directory to save downloaded files
        max_samples: Maximum number of samples to download
        probe_samples: Number of leading samples inspected to discover the
            NIfTI fields
        timeout: Timeout in seconds passed to each HTTP request

    Returns:
        Tuple ``(nifti_files, demo_data, nifti_keys)``: paths of the
        downloaded files, one ``[age, sex, months_post_onset, wab_aq]`` entry
        per downloaded file (placeholders are used for missing values), and
        the names of the fields that hold NIfTI files.
    """
    nifti_suffixes = ('.nii', '.nii.gz')
    # Placeholder demographics used when a sample lacks the field or holds None.
    demo_defaults = (
        ('age', 65.0),
        ('sex', 'M'),
        ('months_post_onset', 12.0),
        ('wab_aq', 50.0),
    )

    # Look through the first few samples to find fields holding NIfTI files
    nifti_keys = []
    for i, sample in enumerate(dataset):
        if i >= probe_samples:
            break

        for key, value in sample.items():
            if (isinstance(value, str) and value.endswith(nifti_suffixes)
                    and key not in nifti_keys):
                nifti_keys.append(key)

    print(f"Found {len(nifti_keys)} NIfTI file types in the dataset: {nifti_keys}")

    if not nifti_keys:
        print("No NIfTI files found in the dataset")
        return [], [], []

    # Collect nifti files and demographics
    nifti_files = []
    demo_data = []

    # Process a limited number of samples
    num_samples = min(max_samples, len(dataset))

    for sample_idx in range(num_samples):
        sample = dataset[sample_idx]

        for key in nifti_keys:
            try:
                file_url = sample[key]
                if not file_url or not isinstance(file_url, str):
                    continue

                print(f"Processing sample {sample_idx+1}, file: {key}")

                # Keep the real extension: an uncompressed .nii saved under a
                # .nii.gz name cannot be opened as gzip by NIfTI readers.
                suffix = '.nii' if file_url.endswith('.nii') else '.nii.gz'
                local_file = os.path.join(temp_dir, f"sample_{sample_idx}_{key}{suffix}")
                print(f"Downloading {file_url} to {local_file}")

                _download_file(file_url, local_file, timeout)
                nifti_files.append(local_file)

                # Extract demo data if available (or use placeholders)
                demo_sample = []
                for field, default in demo_defaults:
                    value = sample.get(field)
                    demo_sample.append(default if value is None else value)
                demo_data.append(demo_sample)

            except Exception as e:
                # Best effort: report the failure and continue with the rest.
                print(f"Error processing sample {sample_idx}, {key}: {e}")

    return nifti_files, demo_data, nifti_keys
|
| 89 |
+
|
| 90 |
+
class VariationalAutoencoder:
    """
    Simplified VAE implementation for the visualization script.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent space. The decoder reconstructs the
    input from a latent vector concatenated with the encoded demographic
    features (standardized continuous values and one-hot categorical values).
    """

    def __init__(self, n_features, latent_dim, demo_data, demo_types, **kwargs):
        """
        Initialize the VAE.

        Args:
            n_features: Number of input features
            latent_dim: Dimension of latent space
            demo_data: Demographic data, one sequence of values per variable
            demo_types: Types of demographic variables ('continuous' or
                'categorical'), aligned with ``demo_data``
            **kwargs: Additional parameters: ``nepochs`` (default 100),
                ``bsize`` (batch size, default 8) and ``lr`` (learning rate,
                default 1e-3)
        """
        import torch
        import torch.nn as nn

        self.n_features = n_features
        self.latent_dim = latent_dim
        self.demo_dim = self._calculate_demo_dim(demo_data, demo_types)
        self.nepochs = kwargs.get('nepochs', 100)
        self.batch_size = kwargs.get('bsize', 8)
        self.learning_rate = kwargs.get('lr', 1e-3)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Build VAE model
        self.encoder = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, latent_dim * 2)  # mu and logvar
        ).to(self.device)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + self.demo_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, n_features)
        ).to(self.device)

        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=self.learning_rate
        )

        self.demo_stats = None  # Will be set on first demographic encoding
        self.train_losses = []  # Average loss per epoch, filled by fit()

    def _calculate_demo_dim(self, demo_data, demo_types):
        """Calculate dimension of demographic data after one-hot encoding"""
        demo_dim = 0
        for d, t in zip(demo_data, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                # One column per distinct category
                demo_dim += len(set(d))
        return demo_dim

    def _encode(self, x):
        """Encode input data to latent space; returns (mu, logvar)"""
        import torch

        # as_tensor avoids the copy (and the warning) that torch.tensor()
        # produces when x is already a tensor, as it is during training.
        x_tensor = torch.as_tensor(x, dtype=torch.float32).to(self.device)
        h = self.encoder(x_tensor)
        mu, logvar = h[:, :self.latent_dim], h[:, self.latent_dim:]
        return mu, logvar

    def _reparameterize(self, mu, logvar):
        """Reparameterization trick for sampling from latent space"""
        import torch

        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z, demo):
        """Decode latent representation back to input space"""
        import torch

        # Concatenate latent code with demographic data
        z_concat = torch.cat([z, demo], dim=1)
        return self.decoder(z_concat)

    def _prepare_demographics(self, demo_data, demo_types):
        """
        Convert demographics to tensor with one-hot encoding for categorical variables.

        The first call records the encoding statistics (mean/std for
        continuous variables, sorted category list for categorical ones) in
        ``self.demo_stats``; later calls reuse them.

        Returns:
            Tensor of shape (n_samples, encoded demographic width) on
            ``self.device``.
        """
        import torch
        import numpy as np

        if self.demo_stats is None:
            # First time - compute stats
            self.demo_stats = []
            for d, t in zip(demo_data, demo_types):
                if t == 'continuous':
                    # Coerce to float so numeric values that arrive as
                    # strings (e.g. after a NumPy transpose of mixed data)
                    # can still be standardized.
                    values = np.asarray(d, dtype=float)
                    self.demo_stats.append(('continuous', (np.mean(values), np.std(values))))
                elif t == 'categorical':
                    # Record unique values for one-hot encoding
                    self.demo_stats.append(('categorical', sorted(set(d))))

        # Process demographics based on saved stats
        demo_tensors = []
        for d, (dtype, stats) in zip(demo_data, self.demo_stats):
            if dtype == 'continuous':
                mean, std = stats
                # Standardize; the epsilon avoids division by zero for
                # constant variables
                standardized = (np.asarray(d, dtype=float) - mean) / (std + 1e-10)
                demo_tensors.append(torch.tensor(standardized, dtype=torch.float32).reshape(-1, 1))
            else:  # categorical
                unique_values = stats
                position = {value: idx for idx, value in enumerate(unique_values)}
                one_hot_vectors = []
                for val in d:
                    vec = [0.0] * len(unique_values)
                    idx = position.get(val)
                    # Unseen categories keep an all-zero vector
                    if idx is not None:
                        vec[idx] = 1.0
                    one_hot_vectors.append(vec)
                demo_tensors.append(
                    torch.tensor(one_hot_vectors, dtype=torch.float32)
                    .reshape(len(one_hot_vectors), len(unique_values))
                )

        # Concatenate all demographic features
        return torch.cat(demo_tensors, dim=1).to(self.device)

    def fit(self, X, demo_data, demo_types):
        """
        Train the VAE model.

        Args:
            X: Input data (FC matrices), one row per sample
            demo_data: List of demographic variables
            demo_types: Types of demographic variables

        Returns:
            List with the average training loss of each epoch.

        Raises:
            ValueError: If fewer than two samples or a batch size below two
                is used (BatchNorm cannot train on single-sample batches), or
                if the encoded demographics do not match the decoder width.
        """
        import torch
        import torch.nn.functional as F
        import numpy as np
        from torch.utils.data import DataLoader, TensorDataset

        print(f"Training VAE on {len(X)} samples for {self.nepochs} epochs...")

        # Convert input data to tensor
        X_tensor = torch.as_tensor(np.asarray(X), dtype=torch.float32).to(self.device)
        n_samples = X_tensor.shape[0]

        if n_samples < 2 or self.batch_size < 2:
            raise ValueError(
                "Training requires at least 2 samples and a batch size of at least 2 "
                "because of the BatchNorm layers"
            )

        # Prepare demographic data
        stats_were_unset = self.demo_stats is None
        demo_tensor = self._prepare_demographics(demo_data, demo_types)
        if demo_tensor.shape[1] != self.demo_dim:
            if stats_were_unset:
                # Do not keep stats derived from data the model cannot use
                self.demo_stats = None
            raise ValueError(
                f"Encoded demographics have {demo_tensor.shape[1]} columns but the "
                f"decoder was built for {self.demo_dim}"
            )

        # reconstruct()/generate() leave the networks in eval mode; training
        # must run with BatchNorm in train mode.
        self.encoder.train()
        self.decoder.train()

        # Create dataset and dataloader. A trailing batch holding a single
        # sample would make BatchNorm1d raise in train mode, so it is dropped.
        dataset = TensorDataset(X_tensor, demo_tensor)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=(n_samples % self.batch_size == 1),
        )

        # Training loop
        self.train_losses = []

        for epoch in range(self.nepochs):
            epoch_losses = []

            for batch_x, batch_demo in dataloader:
                # Forward pass
                mu, logvar = self._encode(batch_x)
                z = self._reparameterize(mu, logvar)
                x_recon = self._decode(z, batch_demo)

                # Compute loss
                recon_loss = F.mse_loss(x_recon, batch_x)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                kl_loss = kl_loss / batch_x.size(0)  # Normalize by batch size

                # Total loss
                loss = recon_loss + 0.1 * kl_loss

                # Backward and optimize
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_losses.append(loss.item())

            # Record average loss for this epoch
            avg_loss = np.mean(epoch_losses)
            self.train_losses.append(avg_loss)

            # Print progress every 10 epochs
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{self.nepochs}, Loss: {avg_loss:.6f}")

        print("VAE training complete!")
        return self.train_losses

    def reconstruct(self, X, demo_data=None, demo_types=None):
        """
        Reconstruct input data.

        The latent mean (no sampling) is decoded together with the given
        demographics.

        Args:
            X: Input data
            demo_data: Demographic data (required despite the default)
            demo_types: Types of demographic variables (required despite the
                default)

        Returns:
            Reconstructed data as a NumPy array

        Raises:
            ValueError: If ``demo_data`` or ``demo_types`` is missing.
        """
        import torch

        if demo_data is None or demo_types is None:
            raise ValueError("Demo data and types must be provided for reconstruction")

        # Set to evaluation mode
        self.encoder.eval()
        self.decoder.eval()

        with torch.no_grad():
            # Encode to latent space
            mu, _ = self._encode(X)
            demo_tensor = self._prepare_demographics(demo_data, demo_types)

            # Decode
            recon = self._decode(mu, demo_tensor)

        # Convert to numpy
        return recon.cpu().numpy()

    def generate(self, n_samples, demo_data, demo_types):
        """
        Generate new samples from the latent space.

        Args:
            n_samples: Number of samples to generate
            demo_data: Demographic data; when it does not contain exactly
                ``n_samples`` entries, the first entry is used for every
                generated sample
            demo_types: Types of demographic variables

        Returns:
            Generated samples as a NumPy array

        Raises:
            ValueError: If ``demo_data`` contains no samples.
        """
        import torch

        # Set to evaluation mode
        self.decoder.eval()

        with torch.no_grad():
            # Sample from standard normal
            z = torch.randn(n_samples, self.latent_dim).to(self.device)

            # Prepare demographic data
            demo_tensor = self._prepare_demographics(demo_data, demo_types)

            # Check dimensions
            if demo_tensor.shape[0] != n_samples:
                if demo_tensor.shape[0] < 1:
                    raise ValueError("Demo data must contain at least one sample")
                # Handle mismatch - repeat the first demographic sample
                demo_tensor = demo_tensor[0].unsqueeze(0).repeat(n_samples, 1)

            # Generate samples
            generated = self._decode(z, demo_tensor)

        # Convert to numpy
        return generated.cpu().numpy()
|
| 366 |
+
|
| 367 |
+
def generate_comparison():
    """
    Download, process and visualize FC matrices from the HuggingFace dataset,
    comparing original to VAE-generated matrices.

    Steps: load the dataset named in ``DATASET_CONFIG``, download a few fMRI
    files into a temporary directory, convert them to FC feature vectors,
    train a small VAE on them, then save comparison figures (PNG) and the
    matrices (NPY) into the current working directory.

    Returns:
        None. Problems are reported on stdout and end the function early.
    """
    print("Loading dataset from HuggingFace...")

    # Load the HuggingFace dataset using config
    dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData')
    dataset_split = DATASET_CONFIG.get('split', 'train')

    dataset = load_dataset(dataset_name, split=dataset_split)
    print(f"Dataset loaded: {dataset}")

    fc_matrices = []
    demo_rows = []

    # The downloads are only needed until they are converted to FC vectors,
    # so the directory is removed as soon as this block is left.
    with tempfile.TemporaryDirectory(prefix="hf_nifti_") as temp_dir:
        print(f"Created temp directory for NIfTI files: {temp_dir}")

        # Download and process fMRI files
        nifti_files, demo_samples, _nifti_keys = download_sample_fmri(dataset, temp_dir, max_samples=5)

        if not nifti_files:
            print("No valid fMRI files were found")
            return

        # Process all fMRI files to FC matrices
        for file_idx, (file_path, demo_sample) in enumerate(zip(nifti_files, demo_samples)):
            try:
                print(f"Processing file {file_idx+1}/{len(nifti_files)}: {file_path}")
                fc_triu = process_single_fmri(file_path)
                fc_matrices.append(fc_triu)
                demo_rows.append(demo_sample)
            except Exception as e:
                # Best effort: skip files that cannot be processed
                print(f"Error processing file {file_path}: {e}")

    if not fc_matrices:
        print("No valid FC matrices were generated")
        return

    if len(fc_matrices) < 2:
        # Per-feature standardization and VAE training are meaningless with
        # a single sample.
        print("At least two FC matrices are required to train the VAE")
        return

    # Convert to numpy arrays
    X = np.array(fc_matrices)

    # Normalize the data; constant features keep a divisor of 1 so they
    # become 0 instead of NaN.
    feature_std = np.std(X, axis=0)
    feature_std[feature_std == 0] = 1.0
    X = (X - np.mean(X, axis=0)) / feature_std

    # Prepare demographic data
    # Transpose to get [feature_type][sample] format. zip keeps each value's
    # Python type; a NumPy transpose would turn the mixed float/str rows into
    # strings and break the standardization of the continuous variables.
    demo_data = [list(column) for column in zip(*demo_rows)]
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    # Train a VAE on the FC matrices
    print("Training VAE on the FC matrices...")
    n_features = X.shape[1]

    # Configure a smaller/faster VAE for demonstration
    vae = VariationalAutoencoder(
        n_features=n_features,
        latent_dim=MODEL_CONFIG.get('latent_dim', 32),
        demo_data=demo_data,
        demo_types=demo_types,
        nepochs=100,  # Reduced for demo
        bsize=2,
        lr=1e-3
    )

    # Train the VAE
    vae.fit(X, demo_data, demo_types)

    # Generate reconstructed FC matrices
    print("Generating reconstructed FC matrices...")
    reconstructed = vae.reconstruct(X, demo_data, demo_types)

    # Generate a synthetic FC matrix
    print("Generating a synthetic FC matrix...")
    # For generating a new sample, we'll use demographics from first patient
    first_demo_data = [[d[0]] for d in demo_data]
    generated = vae.generate(1, first_demo_data, demo_types)

    # Visualize original, reconstructed, and generated FC matrices
    visualizer = FCVisualizer()
    gen_matrix = visualizer._triu_to_matrix(generated[0])

    # Process each sample to generate comparisons
    for i in range(min(3, len(X))):
        # Convert upper triangular vectors to full matrices for visualization
        original_matrix = visualizer._triu_to_matrix(X[i])
        recon_matrix = visualizer._triu_to_matrix(reconstructed[i])

        if i == 0:
            # Visualize all three - original, reconstructed, generated
            fig = visualizer.plot_matrix_comparison(
                [original_matrix, recon_matrix, gen_matrix],
                titles=["Original FC", "Reconstructed FC", "Generated FC"]
            )

            output_file = "fc_comparison_with_generated.png"
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close(fig)  # release the figure once it is on disk
            print(f"Saved full comparison to {output_file}")

        # Visualize original vs reconstructed for each sample
        fig = visualizer.plot_matrix_comparison(
            [original_matrix, recon_matrix],
            titles=[f"Original FC (Sample {i+1})", f"Reconstructed FC (Sample {i+1})"]
        )

        output_file = f"sample_{i}_original_vs_reconstructed.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)  # release the figure once it is on disk
        print(f"Saved comparison to {output_file}")

        # Save the matrices
        np.save(f"sample_{i}_original_fc.npy", original_matrix)
        np.save(f"sample_{i}_reconstructed_fc.npy", recon_matrix)

    # Save the generated matrix
    np.save("generated_fc.npy", gen_matrix)

    print("Processing complete")


if __name__ == "__main__":
    generate_comparison()
|
main.py
CHANGED
|
@@ -99,7 +99,6 @@ def run_analysis(data_dir="data",
|
|
| 99 |
# Initialize and train treatment predictor
|
| 100 |
print("Training treatment predictor...")
|
| 101 |
predictor = AphasiaTreatmentPredictor(
|
| 102 |
-
prediction_type=PREDICTION_CONFIG.get('prediction_type', 'regression'),
|
| 103 |
n_estimators=PREDICTION_CONFIG.get('n_estimators', 100),
|
| 104 |
max_depth=PREDICTION_CONFIG.get('max_depth', None)
|
| 105 |
)
|
|
@@ -129,18 +128,11 @@ def run_analysis(data_dir="data",
|
|
| 129 |
|
| 130 |
# For regression, get R2 metrics, otherwise use accuracy
|
| 131 |
try:
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
|
| 136 |
-
else:
|
| 137 |
-
cv_std = 0.0
|
| 138 |
else:
|
| 139 |
-
|
| 140 |
-
if fold_metrics and "accuracy" in fold_metrics[0]:
|
| 141 |
-
cv_std = np.std([fold.get("accuracy", 0.0) for fold in fold_metrics])
|
| 142 |
-
else:
|
| 143 |
-
cv_std = 0.0
|
| 144 |
except Exception as e:
|
| 145 |
print(f"Error calculating CV metrics: {e}")
|
| 146 |
cv_mean, cv_std = 0.0, 0.0
|
|
|
|
| 99 |
# Initialize and train treatment predictor
|
| 100 |
print("Training treatment predictor...")
|
| 101 |
predictor = AphasiaTreatmentPredictor(
|
|
|
|
| 102 |
n_estimators=PREDICTION_CONFIG.get('n_estimators', 100),
|
| 103 |
max_depth=PREDICTION_CONFIG.get('max_depth', None)
|
| 104 |
)
|
|
|
|
| 128 |
|
| 129 |
# For regression, get R2 metrics, otherwise use accuracy
|
| 130 |
try:
|
| 131 |
+
cv_mean = mean_metrics.get("r2", 0.0)
|
| 132 |
+
if fold_metrics and "r2" in fold_metrics[0]:
|
| 133 |
+
cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
|
|
|
|
|
|
|
|
|
|
| 134 |
else:
|
| 135 |
+
cv_std = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
except Exception as e:
|
| 137 |
print(f"Error calculating CV metrics: {e}")
|
| 138 |
cv_mean, cv_std = 0.0, 0.0
|
rcf_prediction.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import numpy as np
|
| 2 |
-
from sklearn.ensemble import RandomForestRegressor
|
| 3 |
from sklearn.model_selection import cross_val_score, KFold
|
| 4 |
import pandas as pd
|
| 5 |
-
from sklearn.metrics import mean_squared_error, r2_score
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import os
|
| 8 |
import joblib
|
|
@@ -12,35 +12,27 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
class AphasiaTreatmentPredictor:
|
| 15 |
-
def __init__(self,
|
| 16 |
"""
|
| 17 |
-
Initialize the Treatment Predictor with Random Forest
|
| 18 |
|
| 19 |
Args:
|
| 20 |
-
prediction_type (str): "classification" or "regression" depending on outcome variable type
|
| 21 |
n_estimators (int): Number of trees in the forest
|
| 22 |
max_depth (int): Maximum depth of trees (None for unlimited)
|
| 23 |
random_state (int): Random seed for reproducibility
|
| 24 |
"""
|
| 25 |
-
self.prediction_type =
|
| 26 |
self.n_estimators = n_estimators
|
| 27 |
self.max_depth = max_depth
|
| 28 |
self.random_state = random_state
|
| 29 |
self.feature_importance = None
|
| 30 |
self.feature_names = None
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
)
|
| 38 |
-
else: # regression
|
| 39 |
-
self.model = RandomForestRegressor(
|
| 40 |
-
n_estimators=n_estimators,
|
| 41 |
-
max_depth=max_depth,
|
| 42 |
-
random_state=random_state
|
| 43 |
-
)
|
| 44 |
|
| 45 |
def prepare_features(self, latents, demographics):
|
| 46 |
"""
|
|
@@ -115,34 +107,11 @@ class AphasiaTreatmentPredictor:
|
|
| 115 |
predictions = self.model.predict(X)
|
| 116 |
|
| 117 |
# Get prediction intervals using tree variance
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
prediction_std = np.std(tree_predictions, axis=0)
|
| 122 |
-
else: # classification
|
| 123 |
-
# For classification, use probability as a measure of confidence
|
| 124 |
-
proba = self.model.predict_proba(X)
|
| 125 |
-
# Use max probability as confidence measure
|
| 126 |
-
prediction_std = 1 - np.max(proba, axis=1)
|
| 127 |
|
| 128 |
return predictions, prediction_std
|
| 129 |
-
|
| 130 |
-
def predict_proba(self, latents, demographics):
|
| 131 |
-
"""
|
| 132 |
-
Get probability estimates for classification
|
| 133 |
-
|
| 134 |
-
Args:
|
| 135 |
-
latents (np.ndarray): Latent representations from VAE
|
| 136 |
-
demographics (dict or pd.DataFrame): Demographic information
|
| 137 |
-
|
| 138 |
-
Returns:
|
| 139 |
-
np.ndarray: Probability estimates for each class
|
| 140 |
-
"""
|
| 141 |
-
if self.prediction_type != "classification":
|
| 142 |
-
raise ValueError("Probability prediction only available for classification")
|
| 143 |
-
|
| 144 |
-
X, _ = self.prepare_features(latents, demographics)
|
| 145 |
-
return self.model.predict_proba(X)
|
| 146 |
|
| 147 |
def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
|
| 148 |
"""
|
|
@@ -174,18 +143,11 @@ class AphasiaTreatmentPredictor:
|
|
| 174 |
y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
|
| 175 |
|
| 176 |
# Clone the model for this fold
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
)
|
| 183 |
-
else:
|
| 184 |
-
fold_model = RandomForestRegressor(
|
| 185 |
-
n_estimators=self.n_estimators,
|
| 186 |
-
max_depth=self.max_depth,
|
| 187 |
-
random_state=self.random_state
|
| 188 |
-
)
|
| 189 |
|
| 190 |
# Train the model
|
| 191 |
fold_model.fit(X_train, y_train)
|
|
@@ -197,38 +159,19 @@ class AphasiaTreatmentPredictor:
|
|
| 197 |
predictions[test_idx] = pred
|
| 198 |
|
| 199 |
# Calculate metrics
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
prediction_stds[test_idx] = pred_std
|
| 214 |
-
|
| 215 |
-
else: # classification
|
| 216 |
-
acc = accuracy_score(y_test, pred)
|
| 217 |
-
prec = precision_score(y_test, pred, average='weighted', zero_division=0)
|
| 218 |
-
rec = recall_score(y_test, pred, average='weighted', zero_division=0)
|
| 219 |
-
f1 = f1_score(y_test, pred, average='weighted', zero_division=0)
|
| 220 |
-
metrics = {
|
| 221 |
-
"accuracy": acc,
|
| 222 |
-
"precision": prec,
|
| 223 |
-
"recall": rec,
|
| 224 |
-
"f1": f1
|
| 225 |
-
}
|
| 226 |
-
|
| 227 |
-
# Use probability as a measure of confidence
|
| 228 |
-
proba = fold_model.predict_proba(X_test)
|
| 229 |
-
# Use max probability as confidence measure
|
| 230 |
-
pred_std = 1 - np.max(proba, axis=1)
|
| 231 |
-
prediction_stds[test_idx] = pred_std
|
| 232 |
|
| 233 |
fold_metrics.append(metrics)
|
| 234 |
logger.info(f"Fold {fold+1} metrics: {metrics}")
|
|
@@ -335,7 +278,6 @@ class AphasiaTreatmentPredictor:
|
|
| 335 |
|
| 336 |
# Create new instance
|
| 337 |
predictor = cls(
|
| 338 |
-
prediction_type=data['prediction_type'],
|
| 339 |
n_estimators=data['n_estimators'],
|
| 340 |
max_depth=data['max_depth'],
|
| 341 |
random_state=data['random_state']
|
|
@@ -350,7 +292,7 @@ class AphasiaTreatmentPredictor:
|
|
| 350 |
return predictor
|
| 351 |
|
| 352 |
|
| 353 |
-
def train_predictor_from_latents(latents, outcomes, demographics=None,
|
| 354 |
"""
|
| 355 |
Train a treatment outcome predictor from VAE latent representations
|
| 356 |
|
|
@@ -358,17 +300,16 @@ def train_predictor_from_latents(latents, outcomes, demographics=None, predictio
|
|
| 358 |
latents (np.ndarray): Latent representations from VAE
|
| 359 |
outcomes (np.ndarray): Treatment outcome values
|
| 360 |
demographics (dict or pd.DataFrame, optional): Demographic information to include as features
|
| 361 |
-
prediction_type (str): "classification" or "regression"
|
| 362 |
cv (int): Number of folds for cross-validation
|
| 363 |
**kwargs: Additional parameters for the AphasiaTreatmentPredictor
|
| 364 |
|
| 365 |
Returns:
|
| 366 |
dict: Training results and trained model
|
| 367 |
"""
|
| 368 |
-
logger.info(f"Training
|
| 369 |
|
| 370 |
# Create predictor
|
| 371 |
-
predictor = AphasiaTreatmentPredictor(
|
| 372 |
|
| 373 |
# Run cross-validation
|
| 374 |
cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 3 |
from sklearn.model_selection import cross_val_score, KFold
|
| 4 |
import pandas as pd
|
| 5 |
+
from sklearn.metrics import mean_squared_error, r2_score
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import os
|
| 8 |
import joblib
|
|
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
class AphasiaTreatmentPredictor:
|
| 15 |
+
def __init__(self, n_estimators=100, max_depth=None, random_state=42):
|
| 16 |
"""
|
| 17 |
+
Initialize the Treatment Predictor with Random Forest Regressor
|
| 18 |
|
| 19 |
Args:
|
|
|
|
| 20 |
n_estimators (int): Number of trees in the forest
|
| 21 |
max_depth (int): Maximum depth of trees (None for unlimited)
|
| 22 |
random_state (int): Random seed for reproducibility
|
| 23 |
"""
|
| 24 |
+
self.prediction_type = "regression"
|
| 25 |
self.n_estimators = n_estimators
|
| 26 |
self.max_depth = max_depth
|
| 27 |
self.random_state = random_state
|
| 28 |
self.feature_importance = None
|
| 29 |
self.feature_names = None
|
| 30 |
|
| 31 |
+
self.model = RandomForestRegressor(
|
| 32 |
+
n_estimators=n_estimators,
|
| 33 |
+
max_depth=max_depth,
|
| 34 |
+
random_state=random_state
|
| 35 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def prepare_features(self, latents, demographics):
|
| 38 |
"""
|
|
|
|
| 107 |
predictions = self.model.predict(X)
|
| 108 |
|
| 109 |
# Get prediction intervals using tree variance
|
| 110 |
+
tree_predictions = np.array([tree.predict(X)
|
| 111 |
+
for tree in self.model.estimators_])
|
| 112 |
+
prediction_std = np.std(tree_predictions, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
return predictions, prediction_std
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
|
| 117 |
"""
|
|
|
|
| 143 |
y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
|
| 144 |
|
| 145 |
# Clone the model for this fold
|
| 146 |
+
fold_model = RandomForestRegressor(
|
| 147 |
+
n_estimators=self.n_estimators,
|
| 148 |
+
max_depth=self.max_depth,
|
| 149 |
+
random_state=self.random_state
|
| 150 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# Train the model
|
| 153 |
fold_model.fit(X_train, y_train)
|
|
|
|
| 159 |
predictions[test_idx] = pred
|
| 160 |
|
| 161 |
# Calculate metrics
|
| 162 |
+
rmse = np.sqrt(mean_squared_error(y_test, pred))
|
| 163 |
+
r2 = r2_score(y_test, pred)
|
| 164 |
+
metrics = {
|
| 165 |
+
"r2": r2,
|
| 166 |
+
"rmse": rmse,
|
| 167 |
+
"mse": rmse**2
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
# Get prediction intervals using tree variance
|
| 171 |
+
tree_predictions = np.array([tree.predict(X_test)
|
| 172 |
+
for tree in fold_model.estimators_])
|
| 173 |
+
pred_std = np.std(tree_predictions, axis=0)
|
| 174 |
+
prediction_stds[test_idx] = pred_std
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
fold_metrics.append(metrics)
|
| 177 |
logger.info(f"Fold {fold+1} metrics: {metrics}")
|
|
|
|
| 278 |
|
| 279 |
# Create new instance
|
| 280 |
predictor = cls(
|
|
|
|
| 281 |
n_estimators=data['n_estimators'],
|
| 282 |
max_depth=data['max_depth'],
|
| 283 |
random_state=data['random_state']
|
|
|
|
| 292 |
return predictor
|
| 293 |
|
| 294 |
|
| 295 |
+
def train_predictor_from_latents(latents, outcomes, demographics=None, cv=5, **kwargs):
|
| 296 |
"""
|
| 297 |
Train a treatment outcome predictor from VAE latent representations
|
| 298 |
|
|
|
|
| 300 |
latents (np.ndarray): Latent representations from VAE
|
| 301 |
outcomes (np.ndarray): Treatment outcome values
|
| 302 |
demographics (dict or pd.DataFrame, optional): Demographic information to include as features
|
|
|
|
| 303 |
cv (int): Number of folds for cross-validation
|
| 304 |
**kwargs: Additional parameters for the AphasiaTreatmentPredictor
|
| 305 |
|
| 306 |
Returns:
|
| 307 |
dict: Training results and trained model
|
| 308 |
"""
|
| 309 |
+
logger.info(f"Training regression model for treatment prediction")
|
| 310 |
|
| 311 |
# Create predictor
|
| 312 |
+
predictor = AphasiaTreatmentPredictor(**kwargs)
|
| 313 |
|
| 314 |
# Run cross-validation
|
| 315 |
cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
|
src/.DS_Store
CHANGED
|
Binary files a/src/.DS_Store and b/src/.DS_Store differ
|
|
|
vae_model.py
CHANGED
|
@@ -26,17 +26,33 @@ class VAE(nn.Module):
|
|
| 26 |
self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
|
| 27 |
|
| 28 |
def enc(self, x):
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
return z
|
| 32 |
|
| 33 |
def gen(self, n):
|
| 34 |
return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)
|
| 35 |
|
| 36 |
def dec(self, z, demo):
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
return x
|
| 41 |
|
| 42 |
class DemoVAE(BaseEstimator):
|
|
@@ -106,16 +122,77 @@ class DemoVAE(BaseEstimator):
|
|
| 106 |
return train_losses, val_losses
|
| 107 |
|
| 108 |
def transform(self, x, demo, demo_types):
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
return to_numpy(y)
|
| 116 |
|
| 117 |
def get_latents(self, x):
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
return to_numpy(z)
|
| 120 |
|
| 121 |
def save(self, path):
|
|
|
|
| 26 |
self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
|
| 27 |
|
| 28 |
def enc(self, x):
|
| 29 |
+
# First layer with activation
|
| 30 |
+
h = self.enc1(x)
|
| 31 |
+
h = F.relu(h)
|
| 32 |
+
|
| 33 |
+
# Apply batch norm - handle training vs eval mode automatically
|
| 34 |
+
h = self.bn1(h)
|
| 35 |
+
|
| 36 |
+
# Output layer
|
| 37 |
+
z = self.enc2(h)
|
| 38 |
return z
|
| 39 |
|
| 40 |
def gen(self, n):
|
| 41 |
return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)
|
| 42 |
|
| 43 |
def dec(self, z, demo):
|
| 44 |
+
# Concatenate latent code with demographic data
|
| 45 |
+
z_combined = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
|
| 46 |
+
|
| 47 |
+
# First decoder layer with activation
|
| 48 |
+
h = self.dec1(z_combined)
|
| 49 |
+
h = F.relu(h)
|
| 50 |
+
|
| 51 |
+
# Apply batch norm - handle training vs eval mode automatically
|
| 52 |
+
h = self.bn2(h)
|
| 53 |
+
|
| 54 |
+
# Output layer
|
| 55 |
+
x = self.dec2(h)
|
| 56 |
return x
|
| 57 |
|
| 58 |
class DemoVAE(BaseEstimator):
|
|
|
|
| 122 |
return train_losses, val_losses
|
| 123 |
|
| 124 |
def transform(self, x, demo, demo_types):
|
| 125 |
+
# Set model to evaluation mode to handle batch norm with batch size of 1
|
| 126 |
+
self.vae.eval()
|
| 127 |
+
|
| 128 |
+
# Use torch.no_grad to disable gradient calculation during inference
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
if isinstance(x, int):
|
| 131 |
+
z = self.vae.gen(x)
|
| 132 |
+
else:
|
| 133 |
+
z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
|
| 134 |
+
|
| 135 |
+
demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
|
| 136 |
+
|
| 137 |
+
# Handle batch size of 1 for batch normalization
|
| 138 |
+
if z.size(0) == 1:
|
| 139 |
+
# If batch size is 1, we need to be careful with batch norm
|
| 140 |
+
# Clone and repeat the input to create a fake batch if needed
|
| 141 |
+
if hasattr(self.vae, 'bn1') or hasattr(self.vae, 'bn2'):
|
| 142 |
+
try:
|
| 143 |
+
# Try normal decoding first
|
| 144 |
+
y = self.vae.dec(z, demo_t)
|
| 145 |
+
except Exception as e:
|
| 146 |
+
# If it fails, use a workaround for batch norm
|
| 147 |
+
print(f"Using batch norm workaround for inference: {e}")
|
| 148 |
+
# Create a batch by repeating the input
|
| 149 |
+
z_batch = z.repeat(2, 1)
|
| 150 |
+
demo_t_batch = demo_t.repeat(2, 1)
|
| 151 |
+
# Get the output and use only the first element
|
| 152 |
+
y_batch = self.vae.dec(z_batch, demo_t_batch)
|
| 153 |
+
y = y_batch[0:1]
|
| 154 |
+
else:
|
| 155 |
+
# No batch norm, proceed normally
|
| 156 |
+
y = self.vae.dec(z, demo_t)
|
| 157 |
+
else:
|
| 158 |
+
# Normal batch size, proceed as usual
|
| 159 |
+
y = self.vae.dec(z, demo_t)
|
| 160 |
+
|
| 161 |
return to_numpy(y)
|
| 162 |
|
| 163 |
def get_latents(self, x):
|
| 164 |
+
# Set model to evaluation mode
|
| 165 |
+
self.vae.eval()
|
| 166 |
+
|
| 167 |
+
# Use torch.no_grad for inference
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
try:
|
| 170 |
+
# Convert to torch tensor and move to CUDA if needed
|
| 171 |
+
x_tensor = to_cuda(to_torch(x), self.vae.use_cuda)
|
| 172 |
+
|
| 173 |
+
# Get latent representation
|
| 174 |
+
z = self.vae.enc(x_tensor)
|
| 175 |
+
except Exception as e:
|
| 176 |
+
print(f"Error in encoder: {e}")
|
| 177 |
+
# Try workaround for batch norm if needed
|
| 178 |
+
if x.shape[0] == 1 and (hasattr(self.vae, 'bn1') or hasattr(self.vae, 'bn2')):
|
| 179 |
+
print("Using batch normalization workaround for single sample")
|
| 180 |
+
# Repeat the input to create a batch of size 2
|
| 181 |
+
if len(x.shape) == 2:
|
| 182 |
+
x_batch = np.repeat(x, 2, axis=0)
|
| 183 |
+
else:
|
| 184 |
+
x_batch = np.array([x[0], x[0]])
|
| 185 |
+
|
| 186 |
+
# Process the batch
|
| 187 |
+
x_tensor = to_cuda(to_torch(x_batch), self.vae.use_cuda)
|
| 188 |
+
z_batch = self.vae.enc(x_tensor)
|
| 189 |
+
|
| 190 |
+
# Extract just the first sample's latent representation
|
| 191 |
+
z = z_batch[0:1]
|
| 192 |
+
else:
|
| 193 |
+
# Re-raise if we can't handle it
|
| 194 |
+
raise
|
| 195 |
+
|
| 196 |
return to_numpy(z)
|
| 197 |
|
| 198 |
def save(self, path):
|