SreekarB committed on
Commit c48ac75 · verified · 1 Parent(s): d526dee

Upload 20 files

README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
  colorFrom: blue
  colorTo: pink
  sdk: gradio
- sdk_version: 5.20.1
+ sdk_version: 3.36.1
  app_file: app.py
  pinned: false
  ---
app.py CHANGED
@@ -7,7 +7,7 @@ from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
7
  from visualization import plot_fc_matrices, plot_learning_curves
8
  import os
9
  import glob
10
- from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score
11
  import json
12
  import pickle
13
  import pandas as pd
@@ -24,6 +24,7 @@ class AphasiaPredictionApp:
24
  self.predictor = None
25
  self.trained = False
26
  self.latent_dim = MODEL_CONFIG['latent_dim']
 
27
 
28
  def train_models(self, data_dir, latent_dim, nepochs, bsize):
29
  """
@@ -34,9 +35,8 @@ class AphasiaPredictionApp:
34
  logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
35
 
36
  # Default prediction parameters from our config
37
- prediction_type = PREDICTION_CONFIG.get('prediction_type', 'regression')
38
  outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
39
- logger.info(f"Prediction: type={prediction_type}, outcome={outcome_variable}")
40
 
41
  figures = {}
42
 
@@ -323,6 +323,8 @@ class AphasiaPredictionApp:
323
  try:
324
  real_treatment_file = process_behavioral_data_to_outcomes(csv_path)
325
  treatment_file = real_treatment_file # Use the real treatment file if processing succeeded
326
  logger.info(f"Using processed behavioral data for treatment outcomes")
327
  except Exception as proc_err:
328
  logger.warning(f"Couldn't process behavioral data: {proc_err}, using standard outcomes")
@@ -338,6 +340,8 @@ class AphasiaPredictionApp:
338
 
339
  # Use the found file
340
  treatment_file = real_treatment_file
341
  logger.info(f"Using real treatment outcomes file")
342
  except Exception as find_err:
343
  logger.warning(f"Couldn't find treatment outcomes file: {find_err}, using standard outcomes")
@@ -754,54 +758,50 @@ class AphasiaPredictionApp:
754
  # Plot predicted vs actual values
755
  ax1 = fig.add_subplot(gs[0, 0])
756
 
757
- if self.predictor.prediction_type == 'regression':
758
- # Regression: scatter plot
759
- ax1.scatter(y_true, y_pred, alpha=0.7)
760
-
761
- # Add perfect prediction line
762
- min_val = min(np.min(y_true), np.min(y_pred))
763
- max_val = max(np.max(y_true), np.max(y_pred))
764
- ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
765
-
766
- ax1.set_xlabel('Actual Values')
767
- ax1.set_ylabel('Predicted Values')
768
- ax1.set_title('Predicted vs. Actual Values')
769
-
770
- # Add R² to the plot
771
- r2 = r2_score(y_true, y_pred)
772
- ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes,
773
- bbox=dict(facecolor='white', alpha=0.5))
774
-
775
- # Plot residuals
776
- ax2 = fig.add_subplot(gs[0, 1])
777
- residuals = y_true - y_pred
778
- ax2.scatter(y_pred, residuals, alpha=0.7)
779
- ax2.axhline(y=0, color='r', linestyle='--')
780
- ax2.set_xlabel('Predicted Values')
781
- ax2.set_ylabel('Residuals')
782
- ax2.set_title('Residual Plot')
783
-
784
- # Plot prediction errors
785
- ax3 = fig.add_subplot(gs[1, 0])
786
- ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7,
787
- label='Predicted ± 2σ')
788
- ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual')
789
- ax3.set_xlabel('Sample Index')
790
- ax3.set_ylabel('Value')
791
- ax3.set_title('Prediction with Error Bars')
792
- ax3.legend()
793
-
794
- # Plot error distribution
795
- ax4 = fig.add_subplot(gs[1, 1])
796
- ax4.hist(residuals, bins=20, alpha=0.7)
797
- ax4.axvline(x=0, color='r', linestyle='--')
798
- ax4.set_xlabel('Prediction Error')
799
- ax4.set_ylabel('Frequency')
800
- ax4.set_title('Error Distribution')
801
-
802
- else: # classification
803
- # Convert to integer classes if they're strings
804
- if isinstance(y_true[0], str) or isinstance(y_pred[0], str):
805
  # Create mapping of class labels to integers
806
  classes = sorted(list(set(list(y_true) + list(y_pred))))
807
  class_to_int = {c: i for i, c in enumerate(classes)}
@@ -911,76 +911,39 @@ class AphasiaPredictionApp:
911
  """Create learning curve plots from cross-validation results"""
912
  fig = plt.figure(figsize=(12, 6))
913
 
914
- # Create a grid for plots
915
- if self.predictor.prediction_type == 'regression':
916
- # For regression, show R² and RMSE
917
- ax1 = plt.subplot(1, 2, 1)
918
- ax2 = plt.subplot(1, 2, 2)
919
-
920
- # Plot R² for each fold
921
- for i, metrics in enumerate(fold_metrics):
922
- ax1.plot(i+1, metrics['r2'], 'bo')
923
-
924
- # Plot average
925
- avg_r2 = np.mean([m['r2'] for m in fold_metrics])
926
- ax1.axhline(y=avg_r2, color='r', linestyle='--',
927
- label=f'Average R² = {avg_r2:.4f}')
928
-
929
- ax1.set_xlabel('Fold')
930
- ax1.set_ylabel('R²')
931
- ax1.set_title('R² by Fold')
932
- ax1.set_xticks(range(1, len(fold_metrics)+1))
933
- ax1.legend()
934
-
935
- # Plot RMSE for each fold
936
- for i, metrics in enumerate(fold_metrics):
937
- ax2.plot(i+1, metrics['rmse'], 'go')
938
-
939
- # Plot average RMSE
940
- avg_rmse = np.mean([m['rmse'] for m in fold_metrics])
941
- ax2.axhline(y=avg_rmse, color='r', linestyle='--',
942
- label=f'Average RMSE = {avg_rmse:.4f}')
943
-
944
- ax2.set_xlabel('Fold')
945
- ax2.set_ylabel('RMSE')
946
- ax2.set_title('RMSE by Fold')
947
- ax2.set_xticks(range(1, len(fold_metrics)+1))
948
- ax2.legend()
949
-
950
- else: # classification
951
- # For classification, show accuracy and F1
952
- ax1 = plt.subplot(1, 2, 1)
953
- ax2 = plt.subplot(1, 2, 2)
954
-
955
- # Plot accuracy for each fold
956
- for i, metrics in enumerate(fold_metrics):
957
- ax1.plot(i+1, metrics['accuracy'], 'bo')
958
-
959
- # Plot average accuracy
960
- avg_acc = np.mean([m['accuracy'] for m in fold_metrics])
961
- ax1.axhline(y=avg_acc, color='r', linestyle='--',
962
- label=f'Average Accuracy = {avg_acc:.4f}')
963
-
964
- ax1.set_xlabel('Fold')
965
- ax1.set_ylabel('Accuracy')
966
- ax1.set_title('Accuracy by Fold')
967
- ax1.set_xticks(range(1, len(fold_metrics)+1))
968
- ax1.legend()
969
-
970
- # Plot F1 for each fold
971
- for i, metrics in enumerate(fold_metrics):
972
- ax2.plot(i+1, metrics['f1'], 'go')
973
-
974
- # Plot average F1
975
- avg_f1 = np.mean([m['f1'] for m in fold_metrics])
976
- ax2.axhline(y=avg_f1, color='r', linestyle='--',
977
- label=f'Average F1 = {avg_f1:.4f}')
978
-
979
- ax2.set_xlabel('Fold')
980
- ax2.set_ylabel('F1 Score')
981
- ax2.set_title('F1 Score by Fold')
982
- ax2.set_xticks(range(1, len(fold_metrics)+1))
983
- ax2.legend()
984
 
985
  plt.tight_layout()
986
  return fig
@@ -1350,6 +1313,8 @@ def process_behavioral_data_to_outcomes(behavioral_file):
1350
  outcomes_df = pd.DataFrame(outcome_data)
1351
  outcomes_df.to_csv(outcomes_file, index=False)
1352
  logger.info(f"Created treatment outcomes file with {len(outcomes_df)} patients")
 
 
1353
  return outcomes_file
1354
  else:
1355
  # If we couldn't extract outcomes per patient, try a simpler approach
@@ -1375,6 +1340,8 @@ def process_behavioral_data_to_outcomes(behavioral_file):
1375
  ])
1376
  outcomes_df.to_csv(outcomes_file, index=False)
1377
  logger.warning(f"Created simplified treatment outcomes with group improvement: {improvement:.2f}")
 
 
1378
  return outcomes_file
1379
  except Exception as e:
1380
  logger.error(f"Could not create even simplified outcomes: {e}")
@@ -1858,8 +1825,8 @@ def create_interface():
1858
  gr.Markdown("# Aphasia Treatment Trajectory Prediction")
1859
 
1860
  with gr.Tabs():
1861
- # Training Tab
1862
- with gr.Tab("Train Models"):
1863
  with gr.Row():
1864
  with gr.Column(scale=1):
1865
  data_dir = gr.Textbox(
@@ -1889,39 +1856,72 @@ def create_interface():
1889
  use_hf_dataset = gr.Checkbox(
1890
  label="Use HuggingFace Dataset", value=True
1891
  )
1892
- with gr.Group("Prediction Options"):
1893
- prediction_type = gr.Radio(
1894
- label="Prediction Type",
1895
- choices=["regression", "classification"],
1896
- value="regression"
1897
- )
1898
- outcome_variable = gr.Dropdown(
1899
- label="Outcome Variable",
1900
- choices=["wab_aq", "age", "mpo", "education"],
1901
- value="wab_aq"
1902
- )
1903
  skip_behavioral = gr.Checkbox(
1904
  label="Skip Behavioral Data Processing",
1905
  value=PREDICTION_CONFIG.get('skip_behavioral_data', True),
1906
  info="Use pre-defined treatment outcomes instead of processing behavioral data"
1907
  )
1908
-
1909
- with gr.Accordion("Advanced Data Options", open=False):
1910
- use_synthetic_nifti = gr.Checkbox(
1911
- label="Use Synthetic NIfTI Data",
1912
- value=PREDICTION_CONFIG.get('use_synthetic_nifti', False),
1913
- info="Generate synthetic NIfTI files if real ones aren't found"
1914
- )
1915
- use_synthetic_fc = gr.Checkbox(
1916
- label="Use Synthetic FC Matrices",
1917
- value=PREDICTION_CONFIG.get('use_synthetic_fc', False),
1918
- info="Generate synthetic FC matrices if processing fails"
1919
- )
1920
 
1921
- train_btn = gr.Button("Train Models", variant="primary")
1922
 
1923
  with gr.Row():
1924
- fc_plot = gr.Plot(label="FC Analysis")
1925
 
1926
  with gr.Row():
1927
  with gr.Column(scale=1):
@@ -1930,13 +1930,18 @@ def create_interface():
1930
  prediction_plot = gr.Plot(label="Prediction Performance")
1931
 
1932
  with gr.Row():
1933
- learning_plot = gr.Plot(label="Cross-validation Results")
1934
 
1935
- # Prediction Tab
1936
- with gr.Tab("Predict Treatment"):
1937
  with gr.Row():
1938
  with gr.Column(scale=1):
1939
- fmri_file = gr.File(label="Patient fMRI Data")
1940
  with gr.Column(scale=1):
1941
  with gr.Group("Patient Demographics"):
1942
  age = gr.Number(label="Age at Stroke", value=60)
@@ -1952,21 +1957,23 @@ def create_interface():
1952
  with gr.Row():
1953
  trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
1954
 
1955
- # Connect components
1956
- train_outputs = {
1957
- 'vae': fc_plot,
1958
- 'importance': importance_plot,
1959
- 'prediction': prediction_plot,
1960
- 'learning': learning_plot
1961
  }
1962
 
1963
- # Handle train button click
1964
- def handle_train(data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
1965
- prediction_type, outcome_variable, skip_behavioral,
1966
- use_synthetic_nifti, use_synthetic_fc):
1967
- # Set prediction config values for this run
1968
- PREDICTION_CONFIG['prediction_type'] = prediction_type
1969
- PREDICTION_CONFIG['default_outcome'] = outcome_variable
1970
  PREDICTION_CONFIG['skip_behavioral_data'] = skip_behavioral
1971
  PREDICTION_CONFIG['use_synthetic_nifti'] = use_synthetic_nifti
1972
  PREDICTION_CONFIG['use_synthetic_fc'] = use_synthetic_fc
@@ -1978,36 +1985,271 @@ def create_interface():
1978
  else:
1979
  PREDICTION_CONFIG['local_nii_dir'] = None
1980
 
1981
- # Log helpful information for the user
1982
- logger.info(f"Looking for data in directory: {data_dir}")
1983
- logger.info(f"Expected files: FC_graph_covariate_data.csv and treatment_outcomes.csv")
1984
- logger.info(f"Prediction type: {prediction_type}, target: {outcome_variable}")
1985
 
1986
- results = app.train_models(
1987
- data_dir=data_dir,
1988
- latent_dim=latent_dim,
1989
- nepochs=nepochs,
1990
- bsize=bsize
1991
- )
1992
 
1993
- # Return plots in the expected order
1994
- return [
1995
- results.get('vae', None),
1996
- results.get('importance', None),
1997
- results.get('prediction', None),
1998
- results.get('learning', None)
1999
- ]
2000
-
2001
- train_btn.click(
2002
- fn=handle_train,
2003
  inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
2004
- prediction_type, outcome_variable, skip_behavioral,
2005
- use_synthetic_nifti, use_synthetic_fc],
2006
- outputs=[fc_plot, importance_plot, prediction_plot, learning_plot]
2007
  )
2008
2009
  predict_btn.click(
2010
- fn=app.predict_treatment,
2011
  inputs=[fmri_file, age, sex, months, wab],
2012
  outputs=[prediction_text, trajectory_plot]
2013
  )
 
7
  from visualization import plot_fc_matrices, plot_learning_curves
8
  import os
9
  import glob
10
+ from sklearn.metrics import mean_squared_error, r2_score
11
  import json
12
  import pickle
13
  import pandas as pd
 
24
  self.predictor = None
25
  self.trained = False
26
  self.latent_dim = MODEL_CONFIG['latent_dim']
27
+ self.last_treatment_file = None # Track the last treatment file used
28
 
29
  def train_models(self, data_dir, latent_dim, nepochs, bsize):
30
  """
 
35
  logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
36
 
37
  # Default prediction parameters from our config
 
38
  outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
39
+ logger.info(f"Prediction: type=regression, outcome={outcome_variable}")
40
 
41
  figures = {}
42
 
 
323
  try:
324
  real_treatment_file = process_behavioral_data_to_outcomes(csv_path)
325
  treatment_file = real_treatment_file # Use the real treatment file if processing succeeded
326
+ # Store the treatment file path for later use
327
+ self.last_treatment_file = treatment_file
328
  logger.info(f"Using processed behavioral data for treatment outcomes")
329
  except Exception as proc_err:
330
  logger.warning(f"Couldn't process behavioral data: {proc_err}, using standard outcomes")
 
340
 
341
  # Use the found file
342
  treatment_file = real_treatment_file
343
+ # Store the treatment file path for later use
344
+ self.last_treatment_file = treatment_file
345
  logger.info(f"Using real treatment outcomes file")
346
  except Exception as find_err:
347
  logger.warning(f"Couldn't find treatment outcomes file: {find_err}, using standard outcomes")
 
758
  # Plot predicted vs actual values
759
  ax1 = fig.add_subplot(gs[0, 0])
760
 
761
+ # Regression plots
762
+ # Scatter plot
763
+ ax1.scatter(y_true, y_pred, alpha=0.7)
764
+
765
+ # Add perfect prediction line
766
+ min_val = min(np.min(y_true), np.min(y_pred))
767
+ max_val = max(np.max(y_true), np.max(y_pred))
768
+ ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
769
+
770
+ ax1.set_xlabel('Actual Values')
771
+ ax1.set_ylabel('Predicted Values')
772
+ ax1.set_title('Predicted vs. Actual Values')
773
+
774
+ # Add R² to the plot
775
+ r2 = r2_score(y_true, y_pred)
776
+ ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes,
777
+ bbox=dict(facecolor='white', alpha=0.5))
778
+
779
+ # Plot residuals
780
+ ax2 = fig.add_subplot(gs[0, 1])
781
+ residuals = y_true - y_pred
782
+ ax2.scatter(y_pred, residuals, alpha=0.7)
783
+ ax2.axhline(y=0, color='r', linestyle='--')
784
+ ax2.set_xlabel('Predicted Values')
785
+ ax2.set_ylabel('Residuals')
786
+ ax2.set_title('Residual Plot')
787
+
788
+ # Plot prediction errors
789
+ ax3 = fig.add_subplot(gs[1, 0])
790
+ ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7,
791
+ label='Predicted ± 2σ')
792
+ ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual')
793
+ ax3.set_xlabel('Sample Index')
794
+ ax3.set_ylabel('Value')
795
+ ax3.set_title('Prediction with Error Bars')
796
+ ax3.legend()
797
+
798
+ # Plot error distribution
799
+ ax4 = fig.add_subplot(gs[1, 1])
800
+ ax4.hist(residuals, bins=20, alpha=0.7)
801
+ ax4.axvline(x=0, color='r', linestyle='--')
802
+ ax4.set_xlabel('Prediction Error')
803
+ ax4.set_ylabel('Frequency')
804
+ ax4.set_title('Error Distribution')
805
  # Create mapping of class labels to integers
806
  classes = sorted(list(set(list(y_true) + list(y_pred))))
807
  class_to_int = {c: i for i, c in enumerate(classes)}
 
911
  """Create learning curve plots from cross-validation results"""
912
  fig = plt.figure(figsize=(12, 6))
913
 
914
+ # For regression, show R² and RMSE
915
+ ax1 = plt.subplot(1, 2, 1)
916
+ ax2 = plt.subplot(1, 2, 2)
917
+
918
+ # Plot R² for each fold
919
+ for i, metrics in enumerate(fold_metrics):
920
+ ax1.plot(i+1, metrics['r2'], 'bo')
921
+
922
+ # Plot average R²
923
+ avg_r2 = np.mean([m['r2'] for m in fold_metrics])
924
+ ax1.axhline(y=avg_r2, color='r', linestyle='--',
925
+ label=f'Average R² = {avg_r2:.4f}')
926
+
927
+ ax1.set_xlabel('Fold')
928
+ ax1.set_ylabel('R²')
929
+ ax1.set_title('R² by Fold')
930
+ ax1.set_xticks(range(1, len(fold_metrics)+1))
931
+ ax1.legend()
932
+
933
+ # Plot RMSE for each fold
934
+ for i, metrics in enumerate(fold_metrics):
935
+ ax2.plot(i+1, metrics['rmse'], 'go')
936
+
937
+ # Plot average RMSE
938
+ avg_rmse = np.mean([m['rmse'] for m in fold_metrics])
939
+ ax2.axhline(y=avg_rmse, color='r', linestyle='--',
940
+ label=f'Average RMSE = {avg_rmse:.4f}')
941
+
942
+ ax2.set_xlabel('Fold')
943
+ ax2.set_ylabel('RMSE')
944
+ ax2.set_title('RMSE by Fold')
945
+ ax2.set_xticks(range(1, len(fold_metrics)+1))
946
+ ax2.legend()
947
 
948
  plt.tight_layout()
949
  return fig
 
1313
  outcomes_df = pd.DataFrame(outcome_data)
1314
  outcomes_df.to_csv(outcomes_file, index=False)
1315
  logger.info(f"Created treatment outcomes file with {len(outcomes_df)} patients")
1316
+ # Store the treatment file path for later use
1317
+ app.last_treatment_file = outcomes_file  # module-level app instance; this function has no self
1318
  return outcomes_file
1319
  else:
1320
  # If we couldn't extract outcomes per patient, try a simpler approach
 
1340
  ])
1341
  outcomes_df.to_csv(outcomes_file, index=False)
1342
  logger.warning(f"Created simplified treatment outcomes with group improvement: {improvement:.2f}")
1343
+ # Store the treatment file path for later use
1344
+ app.last_treatment_file = outcomes_file  # module-level app instance; this function has no self
1345
  return outcomes_file
1346
  except Exception as e:
1347
  logger.error(f"Could not create even simplified outcomes: {e}")
 
1825
  gr.Markdown("# Aphasia Treatment Trajectory Prediction")
1826
 
1827
  with gr.Tabs():
1828
+ # Tab 1: VAE Training
1829
+ with gr.Tab("1. VAE Training"):
1830
  with gr.Row():
1831
  with gr.Column(scale=1):
1832
  data_dir = gr.Textbox(
 
1856
  use_hf_dataset = gr.Checkbox(
1857
  label="Use HuggingFace Dataset", value=True
1858
  )
1859
+
1860
+ with gr.Accordion("Advanced Data Options", open=False):
1861
  skip_behavioral = gr.Checkbox(
1862
  label="Skip Behavioral Data Processing",
1863
  value=PREDICTION_CONFIG.get('skip_behavioral_data', True),
1864
  info="Use pre-defined treatment outcomes instead of processing behavioral data"
1865
  )
1866
+ use_synthetic_nifti = gr.Checkbox(
1867
+ label="Use Synthetic NIfTI Data",
1868
+ value=PREDICTION_CONFIG.get('use_synthetic_nifti', False),
1869
+ info="Generate synthetic NIfTI files if real ones aren't found"
1870
+ )
1871
+ use_synthetic_fc = gr.Checkbox(
1872
+ label="Use Synthetic FC Matrices",
1873
+ value=PREDICTION_CONFIG.get('use_synthetic_fc', False),
1874
+ info="Generate synthetic FC matrices if processing fails"
1875
+ )
1876
 
1877
+ train_vae_btn = gr.Button("Train VAE Model", variant="primary")
1878
+
1879
+ gr.Markdown("### VAE Training Results")
1880
 
1881
  with gr.Row():
1882
+ fc_plot = gr.Plot(label="FC Matrices (Original/Reconstructed/Generated)")
1883
+
1884
+ with gr.Row():
1885
+ learning_plot = gr.Plot(label="VAE Learning Curves")
1886
+
1887
+ gr.Markdown("After VAE training completes, proceed to the 'Random Forest Prediction' tab →")
1888
+
1889
+ # Tab 2: Random Forest Prediction
1890
+ with gr.Tab("2. Random Forest Prediction"):
1891
+ gr.Markdown("### Random Forest Model Training")
1892
+ gr.Markdown("First complete the VAE training in tab 1, then configure and train the Random Forest model below:")
1893
+
1894
+ with gr.Row():
1895
+ with gr.Column(scale=1):
1896
+ prediction_type = gr.Radio(
1897
+ label="Prediction Type",
1898
+ choices=["regression", "classification"],
1899
+ value="regression"
1900
+ )
1901
+ outcome_variable = gr.Dropdown(
1902
+ label="Outcome Variable",
1903
+ choices=["wab_aq", "age", "mpo", "education"],
1904
+ value="wab_aq"
1905
+ )
1906
+
1907
+ with gr.Column(scale=1):
1908
+ rf_n_estimators = gr.Slider(
1909
+ minimum=10, maximum=500, step=10,
1910
+ label="Number of Trees", value=100
1911
+ )
1912
+ rf_max_depth = gr.Slider(
1913
+ minimum=0, maximum=50, step=1,
1914
+ label="Max Tree Depth", value=10,
1915
+ info="Set to 0 for unlimited depth"
1916
+ )
1917
+ rf_cv_folds = gr.Slider(
1918
+ minimum=2, maximum=10, step=1,
1919
+ label="Cross-validation Folds", value=5
1920
+ )
1921
+
1922
+ train_rf_btn = gr.Button("Train Random Forest Model", variant="primary")
1923
+
1924
+ gr.Markdown("### Random Forest Results")
1925
 
1926
  with gr.Row():
1927
  with gr.Column(scale=1):
 
1930
  prediction_plot = gr.Plot(label="Prediction Performance")
1931
 
1932
  with gr.Row():
1933
+ rf_metrics = gr.Textbox(label="Model Performance Metrics")
1934
+
1935
+ gr.Markdown("After Random Forest training completes, proceed to the 'Treatment Prediction' tab →")
1936
 
1937
+ # Tab 3: Predict Treatment
1938
+ with gr.Tab("3. Treatment Prediction"):
1939
+ gr.Markdown("### Predict Individual Treatment Outcomes")
1940
+ gr.Markdown("After completing VAE and Random Forest training, you can predict treatment outcomes for individual patients:")
1941
+
1942
  with gr.Row():
1943
  with gr.Column(scale=1):
1944
+ fmri_file = gr.File(label="Patient fMRI Data (NIfTI file)")
1945
  with gr.Column(scale=1):
1946
  with gr.Group("Patient Demographics"):
1947
  age = gr.Number(label="Age at Stroke", value=60)
 
1957
  with gr.Row():
1958
  trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
1959
 
1960
+ # Define various handler functions for the different tabs
1961
+
1962
+ # Store shared state between tabs
1963
+ app_state = {
1964
+ 'vae': None,
1965
+ 'latents': None,
1966
+ 'demographics': None,
1967
+ 'predictor': None,
1968
+ 'vae_trained': False,
1969
+ 'rf_trained': False
1970
  }
1971
 
1972
+ # Tab 1: VAE Training Handler
1973
+ def handle_vae_training(data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
1974
+ skip_behavioral, use_synthetic_nifti, use_synthetic_fc):
1975
+ """Train the VAE model and display FC visualization and learning curves"""
1976
+ # Store config values
1977
  PREDICTION_CONFIG['skip_behavioral_data'] = skip_behavioral
1978
  PREDICTION_CONFIG['use_synthetic_nifti'] = use_synthetic_nifti
1979
  PREDICTION_CONFIG['use_synthetic_fc'] = use_synthetic_fc
 
1985
  else:
1986
  PREDICTION_CONFIG['local_nii_dir'] = None
1987
 
1988
+ # Log info
1989
+ logger.info(f"Training VAE model with data from: {data_dir}")
1990
+ logger.info(f"VAE parameters: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
 
1991
 
1992
+ # Create a subset of app.train_models functionality that just trains the VAE
1993
+ try:
1994
+ # Start by setting up data for the VAE
1995
+ from vae_model import DemoVAE
1996
+ from data_preprocessing import load_and_preprocess_data
1997
+ from main import run_analysis
1998
+ import numpy as np
1999
+ import os
2000
+
2001
+ # Prepare VAE training parameters
2002
+ MODEL_CONFIG.update({
2003
+ 'latent_dim': latent_dim,
2004
+ 'nepochs': nepochs,
2005
+ 'bsize': bsize
2006
+ })
2007
+
2008
+ # First, find and preprocess data
2009
+ logger.info("Looking for data in directory and preprocessing...")
2010
+
2011
+ # This part is similar to app.train_models but only focuses on VAE
2012
+ if data_dir == "SreekarB/OSFData":
2013
+ # Use dataset, similar to existing code in app.train_models
2014
+ # For brevity, we'll call the full train_models function but only
2015
+ # extract the VAE-related results
2016
+ results = app.train_models(
2017
+ data_dir=data_dir,
2018
+ latent_dim=latent_dim,
2019
+ nepochs=nepochs,
2020
+ bsize=bsize
2021
+ )
2022
+
2023
+ # Store results in app_state for the next tabs
2024
+ app_state['vae'] = results.get('vae', None)
2025
+ app_state['latents'] = results.get('latents', None)
2026
+ app_state['demographics'] = results.get('demographics', None)
2027
+ app_state['vae_trained'] = True
2028
+
2029
+ # Return just the VAE visualizations
2030
+ return [
2031
+ results.get('vae', None), # FC matrix visualization
2032
+ results.get('learning', None) # VAE learning curves
2033
+ ]
2034
+ else:
2035
+ # Local directory case
2036
+ results = app.train_models(
2037
+ data_dir=data_dir,
2038
+ latent_dim=latent_dim,
2039
+ nepochs=nepochs,
2040
+ bsize=bsize
2041
+ )
2042
+
2043
+ # Store results in app_state
2044
+ app_state['vae'] = results.get('vae', None)
2045
+ app_state['latents'] = results.get('latents', None)
2046
+ app_state['demographics'] = results.get('demographics', None)
2047
+ app_state['vae_trained'] = True
2048
+
2049
+ # Return just the VAE visualizations
2050
+ return [
2051
+ results.get('vae', None), # FC matrix visualization
2052
+ results.get('learning', None) # VAE learning curves
2053
+ ]
2054
+ except Exception as e:
2055
+ logger.error(f"Error in VAE training: {str(e)}", exc_info=True)
2056
+ error_fig = plt.figure(figsize=(10, 6))
2057
+ plt.text(0.5, 0.5, f"Error: {str(e)}",
2058
+ horizontalalignment='center', verticalalignment='center',
2059
+ fontsize=12, color='red', wrap=True)
2060
+ plt.axis('off')
2061
+
2062
+ # Return error figures for both outputs
2063
+ return [error_fig, error_fig]
2064
+
2065
+ # Tab 2: Random Forest Training Handler
2066
+ def handle_rf_training(outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
2067
+ """Train the Random Forest model using the VAE latent representations"""
2068
+ # Check if VAE has been trained
2069
+ if not app_state['vae_trained'] or app_state['latents'] is None:
2070
+ error_fig = plt.figure(figsize=(10, 6))
2071
+ message = "Error: You must train the VAE model in Tab 1 first!"
2072
+ plt.text(0.5, 0.5, message,
2073
+ horizontalalignment='center', verticalalignment='center',
2074
+ fontsize=14, color='red')
2075
+ plt.axis('off')
2076
+
2077
+ # Return error for both outputs
2078
+ return [error_fig, error_fig, "Error: VAE not trained. Go to Tab 1 and train the VAE first."]
2079
 
2080
+ try:
2081
+ # Update RF configuration
2082
+ PREDICTION_CONFIG['default_outcome'] = outcome_variable
2083
+ PREDICTION_CONFIG['n_estimators'] = rf_n_estimators
2084
+ PREDICTION_CONFIG['max_depth'] = rf_max_depth if rf_max_depth > 0 else None
2085
+ PREDICTION_CONFIG['cv_folds'] = rf_cv_folds
2086
+
2087
+ logger.info(f"Training Random Forest model: outcome={outcome_variable}")
2088
+ logger.info(f"RF parameters: n_estimators={rf_n_estimators}, max_depth={rf_max_depth}, cv_folds={rf_cv_folds}")
2089
+
2090
+ # Get data from app_state
2091
+ latents = app_state['latents']
2092
+ demographics = app_state['demographics']
2093
+
2094
+ # Train Random Forest predictor
2095
+ from rcf_prediction import AphasiaTreatmentPredictor
2096
+ import pandas as pd
2097
+ import numpy as np
2098
+
2099
+ # Need to find treatment outcomes data
2100
+ # This would normally be loaded in train_models, so we need
2101
+ # to mock it here or load from app_state
2102
+ if getattr(app, 'last_treatment_file', None) and os.path.exists(app.last_treatment_file):
2103
+ treatment_file = app.last_treatment_file
2104
+ treatment_df = pd.read_csv(treatment_file)
2105
+ treatment_outcomes = treatment_df['outcome_score'].values
2106
+
2107
+ # Initialize predictor
2108
+ predictor = AphasiaTreatmentPredictor(
2109
+ n_estimators=rf_n_estimators,
2110
+ max_depth=rf_max_depth if rf_max_depth > 0 else None
2111
+ )
2112
+
2113
+ # Cross-validate
2114
+ cv_results = predictor.cross_validate(
2115
+ latents=latents,
2116
+ demographics=demographics,
2117
+ treatment_outcomes=treatment_outcomes,
2118
+ n_splits=rf_cv_folds
2119
+ )
2120
+
2121
+ # Fit final model
2122
+ predictor.fit(latents, demographics, treatment_outcomes)
2123
+
2124
+ # Store in app_state
2125
+ app_state['predictor'] = predictor
2126
+ app_state['rf_trained'] = True
2127
+
2128
+ # Create feature importance plot
2129
+ importance_fig = predictor.plot_feature_importance()
2130
+
2131
+ # Create prediction performance plot
2132
+ predictions = cv_results['predictions']
2133
+ prediction_stds = cv_results['prediction_stds']
2134
+
2135
+ performance_fig = plt.figure(figsize=(8, 6))
2136
+
2137
+ # Check if we have valid predictions
2138
+ if len(treatment_outcomes) > 0 and len(predictions) == len(treatment_outcomes):
2139
+ # Only create scatter plot if we have matching data
2140
+ plt.scatter(treatment_outcomes, predictions)
2141
+
2142
+ # Reference line
2143
+ min_val = min(np.min(treatment_outcomes), np.min(predictions))
2144
+ max_val = max(np.max(treatment_outcomes), np.max(predictions))
2145
+ plt.plot([min_val, max_val], [min_val, max_val], 'r--')
2146
+
2147
+ # Confidence band
2148
+ plt.fill_between(treatment_outcomes,
2149
+ predictions - 2*prediction_stds,
2150
+ predictions + 2*prediction_stds,
2151
+ alpha=0.2, color='gray')
2152
+
2153
+ plt.xlabel('Actual Outcome')
2154
+ plt.ylabel('Predicted Outcome')
2155
+
2156
+ # Get performance metrics
2157
+ metrics_text = ""
2158
+ mean_metrics = cv_results.get('mean_metrics', {})
2159
+
2160
+ r2 = mean_metrics.get('r2', 0)
2161
+ rmse = mean_metrics.get('rmse', 0)
2162
+ plt.title(f'Treatment Outcome Prediction\nR² = {r2:.3f}, RMSE = {rmse:.3f}')
2163
+ metrics_text = f"Regression Model Performance:\nR² = {r2:.4f}\nRMSE = {rmse:.4f}"
2164
+ else:
2165
+ # Handle case with no data
2166
+ plt.text(0.5, 0.5, "No prediction data available",
2167
+ ha='center', va='center', transform=plt.gca().transAxes)
2168
+ metrics_text = "No performance metrics available"
2169
+
2170
+ plt.tight_layout()
2171
+
2172
+ return [importance_fig, performance_fig, metrics_text]
2173
+ else:
2174
+ # No treatment file available
2175
+ error_fig = plt.figure(figsize=(10, 6))
2176
+ message = "Error: Treatment outcomes file not found. Please retrain the VAE in Tab 1."
2177
+ plt.text(0.5, 0.5, message,
2178
+ horizontalalignment='center', verticalalignment='center',
2179
+ fontsize=14, color='red')
2180
+ plt.axis('off')
2181
+
2182
+ return [error_fig, error_fig, "Error: Treatment outcomes file not found."]
2183
+
2184
+ except Exception as e:
2185
+ logger.error(f"Error in RF training: {str(e)}", exc_info=True)
2186
+ error_fig = plt.figure(figsize=(10, 6))
2187
+ plt.text(0.5, 0.5, f"Error: {str(e)}",
2188
+ horizontalalignment='center', verticalalignment='center',
2189
+ fontsize=12, color='red', wrap=True)
2190
+ plt.axis('off')
2191
+
2192
+ return [error_fig, error_fig, f"Error: {str(e)}"]
2193
+
2194
+ # Connect the tab handlers
2195
+
2196
+ # VAE Training tab
2197
+ train_vae_btn.click(
2198
+ fn=handle_vae_training,
2199
  inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
2200
+ skip_behavioral, use_synthetic_nifti, use_synthetic_fc],
2201
+ outputs=[fc_plot, learning_plot]
2202
+ )
2203
+
2204
+ # Random Forest Training tab
2205
+ train_rf_btn.click(
2206
+ fn=handle_rf_training,
2207
+ inputs=[outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds],
2208
+ outputs=[importance_plot, prediction_plot, rf_metrics]
2209
  )
2210
 
2211
+ # Tab 3: Treatment Prediction Handler
2212
+ def handle_treatment_prediction(fmri_file, age, sex, months, wab):
2213
+ """Predict treatment outcome for a new patient"""
2214
+ # Check if models have been trained
2215
+ if not app_state['vae_trained'] or not app_state['rf_trained']:
2216
+ error_message = "Error: You must train both the VAE (Tab 1) and Random Forest (Tab 2) models first!"
2217
+ error_fig = plt.figure(figsize=(10, 6))
2218
+ plt.text(0.5, 0.5, error_message,
2219
+ horizontalalignment='center', verticalalignment='center',
2220
+ fontsize=14, color='red')
2221
+ plt.axis('off')
2222
+
2223
+ return [error_message, error_fig]
2224
+
2225
+ # Use the trained models from app_state for prediction
2226
+ try:
2227
+ # Set up prediction
2228
+ if app_state['vae'] is None or app_state['predictor'] is None:
2229
+ return ["Error: Models not properly trained", None]
2230
+
2231
+ # Create a temporary prediction app with our trained models
2232
+ temp_app = AphasiaPredictionApp()
2233
+ temp_app.vae = app_state['vae']
2234
+ temp_app.predictor = app_state['predictor']
2235
+ temp_app.trained = True
2236
+ temp_app.latent_dim = app_state['vae'].latent_dim if hasattr(app_state['vae'], 'latent_dim') else 32
2237
+
2238
+ # Make prediction
2239
+ return temp_app.predict_treatment(
2240
+ fmri_file=fmri_file,
2241
+ age=age,
2242
+ sex=sex,
2243
+ months_post_stroke=months,
2244
+ wab_score=wab
2245
+ )
2246
+ except Exception as e:
2247
+ logger.error(f"Error in treatment prediction: {str(e)}", exc_info=True)
2248
+ return [f"Error in prediction: {str(e)}", None]
2249
+
2250
+ # Connect the treatment prediction handler
2251
  predict_btn.click(
2252
+ fn=handle_treatment_prediction,
2253
  inputs=[fmri_file, age, sex, months, wab],
2254
  outputs=[prediction_text, trajectory_plot]
2255
  )
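
The restructured `create_interface` above splits the workflow into three tabs that share results through a module-level `app_state` dict rather than one monolithic train handler. Below is a minimal, self-contained sketch of that wiring pattern; the handler bodies and component names are illustrative stand-ins, not the app's real logic.

```python
import gradio as gr

# Shared state that later tabs read; mirrors the app_state dict in app.py
app_state = {"vae_trained": False}

def train_vae_stub(latent_dim):
    # Stand-in for handle_vae_training: record state for the next tab
    app_state["vae_trained"] = True
    return f"VAE trained (latent_dim={int(latent_dim)})"

def train_rf_stub(n_estimators):
    # Stand-in for handle_rf_training: refuse to run until Tab 1 has finished
    if not app_state["vae_trained"]:
        return "Error: train the VAE in Tab 1 first"
    return f"Random Forest trained with {int(n_estimators)} trees"

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("1. VAE Training"):
            latent_dim = gr.Slider(2, 128, value=32, step=1, label="Latent Dim")
            vae_btn = gr.Button("Train VAE Model")
            vae_out = gr.Textbox(label="VAE status")
        with gr.Tab("2. Random Forest Prediction"):
            n_estimators = gr.Slider(10, 500, value=100, step=10, label="Number of Trees")
            rf_btn = gr.Button("Train Random Forest Model")
            rf_out = gr.Textbox(label="RF status")
    vae_btn.click(fn=train_vae_stub, inputs=[latent_dim], outputs=[vae_out])
    rf_btn.click(fn=train_rf_stub, inputs=[n_estimators], outputs=[rf_out])

# demo.launch()  # uncomment to run locally
```
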
config.py CHANGED
@@ -27,7 +27,6 @@ PREDICTION_CONFIG = {
  'n_estimators': 100,
  'max_depth': None,
  'cv_folds': 5,
- 'prediction_type': 'regression',
  'default_outcome': 'wab_aq',
  'save_path': 'results/treatment_predictor.joblib',
  'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
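
With `'prediction_type'` removed from `PREDICTION_CONFIG`, the app now hard-codes regression; any remaining caller that still expects the key should read it with a fallback. A minimal sketch, illustrative only:

```python
from config import PREDICTION_CONFIG

# Falls back to 'regression' now that the key no longer exists in config.py
prediction_type = PREDICTION_CONFIG.get('prediction_type', 'regression')
print(prediction_type)  # 'regression'
```
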
demo_fc_visualization.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ Demo script to visualize FC matrices from real fMRI data using nilearn's built-in datasets.
3
+ """
4
+
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from nilearn import datasets
8
+ from nilearn import input_data, connectome
9
+ from fc_visualization import FCVisualizer
10
+
11
+ def visualize_from_nilearn_dataset():
12
+ """Download and visualize FC matrices from nilearn's ADHD dataset."""
13
+ print("Downloading a sample fMRI dataset (ADHD)...")
14
+ adhd_dataset = datasets.fetch_adhd(n_subjects=1)
15
+
16
+ # Get the fMRI file path
17
+ func_file = adhd_dataset.func[0]
18
+ confound_file = adhd_dataset.confounds[0]
19
+
20
+ print(f"Downloaded fMRI file: {func_file}")
21
+
22
+ # Get Power atlas coordinates
23
+ power = datasets.fetch_coords_power_2011()
24
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
25
+
26
+ print(f"Using Power atlas with {len(coords)} ROIs")
27
+
28
+ # Create a masker to extract time series from the ROIs
29
+ masker = input_data.NiftiSpheresMasker(
30
+ coords,
31
+ radius=8, # 8mm radius
32
+ standardize=True,
33
+ memory='nilearn_cache',
34
+ memory_level=1,
35
+ verbose=1,
36
+ detrend=True,
37
+ low_pass=0.08,
38
+ high_pass=0.01,
39
+ t_r=2.0 # ADHD dataset TR
40
+ )
41
+
42
+ # Extract time series, including confounds
43
+ print("Extracting time series from ROIs...")
44
+ time_series = masker.fit_transform(func_file, confounds=confound_file)
45
+ print(f"Time series shape: {time_series.shape}")
46
+
47
+ # Compute correlation matrix (FC matrix)
48
+ correlation_measure = connectome.ConnectivityMeasure(
49
+ kind='correlation',
50
+ vectorize=False,
51
+ discard_diagonal=False
52
+ )
53
+
54
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
55
+ print(f"FC matrix shape: {fc_matrix.shape}")
56
+
57
+ # Save the FC matrix for future use
58
+ np.save('adhd_fc_matrix.npy', fc_matrix)
59
+ print("Saved FC matrix to adhd_fc_matrix.npy")
60
+
61
+ # Visualize the FC matrix
62
+ visualizer = FCVisualizer(cmap='RdBu_r', vmin=-1, vmax=1)
63
+ fig, _ = visualizer.plot_single_matrix(fc_matrix, title="ADHD FC Matrix (Power Atlas)")
64
+
65
+ # Save the figure
66
+ fig.savefig('adhd_fc_matrix.png', dpi=300, bbox_inches='tight')
67
+ print("Saved visualization to adhd_fc_matrix.png")
68
+
69
+ # Show the figure
70
+ plt.show()
71
+
72
+ if __name__ == "__main__":
73
+ visualize_from_nilearn_dataset()
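
With the Power atlas used above (264 ROIs), the correlation matrix is 264×264 and its upper triangle flattens to 264·263/2 = 34,716 edge features per subject; the `_triu_to_matrix` helper added in `fc_visualization.py` below inverts that flattening. A quick check of the arithmetic:

```python
import numpy as np

n_rois = 264                              # Power et al. (2011) atlas
n_edges = n_rois * (n_rois - 1) // 2
print(n_edges)                            # 34716 upper-triangle FC features

# Recover the matrix size from the flattened length, as _triu_to_matrix does
n = int(np.sqrt(2 * n_edges + 0.25) + 0.5)
assert n == n_rois
```
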
demovae/model.py ADDED
@@ -0,0 +1,225 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import random
7
+ import numpy as np
8
+
9
+ from sklearn.linear_model import Ridge
10
+ from sklearn.linear_model import LogisticRegression
11
+
12
+ def to_torch(x):
13
+ return torch.from_numpy(x).float()
14
+
15
+ def to_cuda(x, use_cuda):
16
+ if use_cuda:
17
+ try:
18
+ return x.cuda()
19
+ except (RuntimeError, AssertionError) as e:
20
+ print(f"Warning: CUDA error: {e}. Falling back to CPU.")
21
+ return x
22
+ else:
23
+ return x
24
+
25
+ def to_numpy(x):
26
+ return x.detach().cpu().numpy()
27
+
28
+ class VAE(nn.Module):
29
+ def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
30
+ super(VAE, self).__init__()
31
+ self.input_dim = input_dim
32
+ self.latent_dim = latent_dim
33
+ self.demo_dim = demo_dim
34
+ self.use_cuda = use_cuda
35
+ self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
36
+ self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
37
+ self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
38
+ self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)
39
+
40
+ def enc(self, x):
41
+ x = F.relu(self.enc1(x))
42
+ z = self.enc2(x)
43
+ return z
44
+
45
+ def gen(self, n):
46
+ return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)
47
+
48
+ def dec(self, z, demo):
49
+ z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
50
+ x = F.relu(self.dec1(z))
51
+ x = self.dec2(x)
52
+ #x = x.reshape(len(z), 264, 5)
53
+ #x = torch.einsum('nac,nbc->nab', x, x)
54
+ #a,b = np.triu_indices(264, 1)
55
+ #x = x[:,a,b]
56
+ return x
57
+
58
+ def rmse(a, b, mean=torch.mean):
59
+ return mean((a-b)**2)**0.5
60
+
61
+ def latent_loss(z, use_cuda=True):
62
+ C = z.T@z
63
+ mu = torch.mean(z, dim=0)
64
+ tgt1 = to_cuda(torch.eye(z.shape[-1]).float(), use_cuda)*len(z)
65
+ tgt2 = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
66
+ loss_C = rmse(C, tgt1)
67
+ loss_mu = rmse(mu, tgt2)
68
+ return loss_C, loss_mu, C, mu
69
+
70
+ def decor_loss(z, demo, use_cuda=True):
71
+ ps = []
72
+ losses = []
73
+ for di in range(demo.shape[1]):
74
+ d = demo[:,di]
75
+ d = d - torch.mean(d)
76
+ p = torch.einsum('n,nz->z', d, z)
77
+ p = p/torch.std(d)
78
+ p = p/torch.einsum('nz,nz->z', z, z)
79
+ tgt = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
80
+ loss = rmse(p, tgt)
81
+ losses.append(loss)
82
+ ps.append(p)
83
+ losses = torch.stack(losses)
84
+ return losses, ps
85
+
86
+ def pretty(x):
87
+ return f'{round(float(x), 4)}'
88
+
89
+ def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
90
+ demo_t = []
91
+ demo_idx = 0
92
+ for d,t,s in zip(demo, demo_types, pred_stats):
93
+ if t == 'continuous':
94
+ demo_t.append(to_cuda(to_torch(d), use_cuda))
95
+ elif t == 'categorical':
96
+ for dd in d:
97
+ if dd not in s:
98
+ print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
99
+ raise Exception('Bad demographic')
100
+ for ss in s:
101
+ idx = (d == ss).astype('bool')
102
+ zeros = torch.zeros(len(d))
103
+ zeros[idx] = 1
104
+ demo_t.append(to_cuda(zeros, use_cuda))
105
+ demo_idx += 1
106
+ demo_t = torch.stack(demo_t).permute(1,0)
107
+ return demo_t
108
+
109
+ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
110
+ # Get linear predictors for demographics
111
+ pred_w = []
112
+ pred_i = []
113
+ # Pred stats are mean and std for continuous, and a list of all values for categorical
114
+ pred_stats = []
115
+ for i,d,t in zip(range(len(demo)), demo, demo_types):
116
+ print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
117
+ if t == 'continuous':
118
+ pred_stats.append([np.mean(d), np.std(d)])
119
+ reg = Ridge(alpha=alpha).fit(x, d)
120
+ reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
121
+ reg_i = reg.intercept_
122
+ pred_w.append(reg_w)
123
+ pred_i.append(reg_i)
124
+ elif t == 'categorical':
125
+ pred_stats.append(sorted(list(set(list(d)))))
126
+ reg = LogisticRegression(C=LR_C).fit(x, d)
127
+ # Binary
128
+ if len(reg.coef_) == 1:
129
+ reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
130
+ reg_i = reg.intercept_[0]
131
+ pred_w.append(-reg_w)
132
+ pred_i.append(-reg_i)
133
+ pred_w.append(reg_w)
134
+ pred_i.append(reg_i)
135
+ # Categorical
136
+ else:
137
+ for i in range(len(reg.coef_)):
138
+ reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
139
+ reg_i = reg.intercept_[i]
140
+ pred_w.append(reg_w)
141
+ pred_i.append(reg_i)
142
+ else:
143
+ print(f'demographic type "{t}" not "continuous" or "categorical"')
144
+ raise Exception('Bad demographic type')
145
+ print(' done')
146
+ ret_obj.pred_stats = pred_stats
147
+ # Convert input to pytorch
148
+ print('Converting input to pytorch')
149
+ x = to_cuda(to_torch(x), vae.use_cuda)
150
+ # Convert demographics to pytorch
151
+ print('Converting demographics to pytorch')
152
+ demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
153
+ # Training loop
154
+ print('Beginning VAE training')
155
+ ce = nn.CrossEntropyLoss()
156
+ optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
157
+ for e in range(nepochs):
158
+ for bs in range(0,len(x),bsize):
159
+ xb = x[bs:(bs+bsize)]
160
+ db = demo_t[bs:(bs+bsize)]
161
+ optim.zero_grad()
162
+ # Reconstruct
163
+ z = vae.enc(xb)
164
+ y = vae.dec(z, db)
165
+ loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
166
+ loss_decor, _ = decor_loss(z, db, vae.use_cuda)
167
+ loss_decor = sum(loss_decor)
168
+ loss_rec = rmse(xb, y)
169
+ # Sample demographics
170
+ demo_gen = []
171
+ for s,t in zip(pred_stats, demo_types):
172
+ if t == 'continuous':
173
+ mu = s[0]
174
+ std = s[1]
175
+ dd = torch.randn(100).float()
176
+ dd = dd*std+mu
177
+ dd = to_cuda(dd, vae.use_cuda)
178
+ demo_gen.append(dd)
179
+ elif t == 'categorical':
180
+ idx = random.randint(0, len(s)-1)
181
+ for i in range(len(s)):
182
+ if idx == i:
183
+ dd = torch.ones(100).float()
184
+ else:
185
+ dd = torch.zeros(100).float()
186
+ dd = to_cuda(dd, vae.use_cuda)
187
+ demo_gen.append(dd)
188
+ demo_gen = torch.stack(demo_gen).permute(1,0)
189
+ # Generate
190
+ z = vae.gen(100)
191
+ y = vae.dec(z, demo_gen)
192
+ # Regressor/classifier guidance loss
193
+ losses_pred = []
194
+ idcs = []
195
+ dg_idx = 0
196
+ for s,t in zip(pred_stats, demo_types):
197
+ if t == 'continuous':
198
+ yy = y@pred_w[dg_idx]+pred_i[dg_idx]
199
+ loss = rmse(demo_gen[:,dg_idx], yy)
200
+ losses_pred.append(loss)
201
+ idcs.append(float(demo_gen[0,dg_idx]))
202
+ dg_idx += 1
203
+ elif t == 'categorical':
204
+ loss = 0
205
+ for i in range(len(s)):
206
+ yy = y@pred_w[dg_idx]+pred_i[dg_idx]
207
+ loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
208
+ idcs.append(int(demo_gen[0,dg_idx]))
209
+ dg_idx += 1
210
+ losses_pred.append(loss)
211
+ total_loss = loss_C_mult*loss_C + loss_mu_mult*loss_mu + loss_rec_mult*loss_rec + loss_decor_mult*loss_decor + loss_pred_mult*sum(losses_pred)
212
+ total_loss.backward()
213
+ optim.step()
214
+ if e%pperiod == 0 or e == nepochs-1:
215
+ print(f'Epoch {e} ', end='')
216
+ print(f'ReconLoss {pretty(loss_rec)} ', end='')
217
+ print(f'CovarianceLoss {pretty(loss_C)} ', end='')
218
+ print(f'MeanLoss {pretty(loss_mu)} ', end='')
219
+ print(f'DecorLoss {pretty(loss_decor)} ', end='')
220
+ losses_pred = [pretty(loss) for loss in losses_pred]
221
+ print(f'GuidanceTargets {idcs} GuidanceLosses {losses_pred} ', end='')
222
+ print()
223
+ print('Training complete.')
224
+
225
+
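
`latent_loss` above replaces the usual KL term by pushing the batch latents toward zero mean and covariance `z.T @ z ≈ N·I`, while `decor_loss` penalizes correlation between latents and demographics. A small numpy sanity check of those targets for standard-normal latents, illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
N, latent_dim = 1000, 8
z = rng.standard_normal((N, latent_dim))

C = z.T @ z                 # latent_loss target: N * identity
mu = z.mean(axis=0)         # latent_loss target: zero vector

print(np.abs(C / N - np.eye(latent_dim)).max())  # small: z.T@z ≈ N·I
print(np.abs(mu).max())                          # close to 0
```
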
demovae/sklearn.py ADDED
@@ -0,0 +1,124 @@
1
+
2
+ from .model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
3
+ import numpy as np
4
+
5
+ from sklearn.base import BaseEstimator
6
+
7
+ # For saving
8
+ import torch
9
+
10
+ class DemoVAE(BaseEstimator):
11
+ def __init__(self, **params):
12
+ self.set_params(**params)
13
+
14
+ @staticmethod
15
+ def get_default_params():
16
+ return dict(latent_dim=60, # Latent dimension
17
+ use_cuda=True, # GPU acceleration
18
+ nepochs=3000, # Training epochs
19
+ pperiod=100, # Epochs between printing updates
20
+ bsize=1000, # Batch size
21
+ loss_C_mult=1, # Covariance loss (KL div)
22
+ loss_mu_mult=1, # Mean loss (KL div)
23
+ loss_rec_mult=100, # Reconstruction loss
24
+ loss_decor_mult=10, # Latent-demographic decorrelation loss
25
+ loss_pred_mult=0.001, # Classifier/regressor guidance loss
26
+ alpha=100, # Regularization for continuous guidance models
27
+ LR_C=100, # Regularization for categorical guidance models
28
+ lr=1e-4, # Learning rate
29
+ weight_decay=0, # L2 regularization for VAE model
30
+ )
31
+
32
+ def get_params(self, **params):
33
+ return dict(latent_dim=self.latent_dim,
34
+ use_cuda=self.use_cuda,
35
+ nepochs=self.nepochs,
36
+ pperiod=self.pperiod,
37
+ bsize=self.bsize,
38
+ loss_C_mult=self.loss_C_mult,
39
+ loss_mu_mult=self.loss_mu_mult,
40
+ loss_rec_mult=self.loss_rec_mult,
41
+ loss_decor_mult=self.loss_decor_mult,
42
+ loss_pred_mult=self.loss_pred_mult,
43
+ alpha=self.alpha,
44
+ LR_C=self.LR_C,
45
+ lr=self.lr,
46
+ weight_decay=self.weight_decay,
47
+ )
48
+
49
+ def set_params(self, **params):
50
+ dft = DemoVAE.get_default_params()
51
+ for key in dft:
52
+ if key in params:
53
+ setattr(self, key, params[key])
54
+ else:
55
+ setattr(self, key, dft[key])
56
+ return self
57
+
58
+ def fit(self, x, demo, demo_types, **kwargs):
59
+ # Get demo_dim
60
+ demo_dim = 0
61
+ for d,t in zip(demo, demo_types):
62
+ if t == 'continuous':
63
+ demo_dim += 1
64
+ elif t == 'categorical':
65
+ ll = len(set(list(d)))
66
+ if ll == 1:
67
+ print('Only one distinct category for categorical variable')
68
+ raise Exception('Bad categorical')
69
+ demo_dim += ll
70
+ else:
71
+ print(f'demographic type "{t}" not "continuous" or "categorical"')
72
+ raise Exception('Bad demographic type')
73
+ # Save parameters
74
+ self.input_dim = x.shape[1]
75
+ self.demo_dim = demo_dim
76
+ # Create model
77
+ self.vae = VAE(x.shape[1], self.latent_dim, demo_dim, self.use_cuda)
78
+ # Train model
79
+ train_vae(self.vae, x, demo, demo_types,
80
+ self.nepochs, self.pperiod, self.bsize,
81
+ self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult, self.loss_decor_mult, self.loss_pred_mult,
82
+ self.lr, self.weight_decay, self.alpha, self.LR_C,
83
+ self)
84
+ return self
85
+
86
+ def transform(self, x, demo, demo_types, **kwargs):
87
+ if isinstance(x, int):
88
+ # Generate
89
+ z = self.vae.gen(x)
90
+ else:
91
+ # Get latents for real data
92
+ z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
93
+ demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
94
+ y = self.vae.dec(z, demo_t)
95
+ return to_numpy(y)
96
+
97
+ def fit_transform(self, x, demo, demo_types, **kwargs):
98
+ self.fit(x, demo, demo_types)
99
+ return self.transform(x, demo, demo_types)
100
+
101
+ def get_latents(self, x):
102
+ z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
103
+ return to_numpy(z)
104
+
105
+ def save(self, path):
106
+ params = self.get_params()
107
+ dct = dict(pred_stats=self.pred_stats,
108
+ params=params,
109
+ input_dim=self.input_dim,
110
+ demo_dim=self.demo_dim,
111
+ model_state_dict=self.vae.state_dict())
112
+ torch.save(dct, path)
113
+
114
+ def load(self, path):
115
+ dct = torch.load(path)
116
+ self.pred_stats = dct['pred_stats']
117
+ self.set_params(**dct['params'])
118
+ self.vae = VAE(dct['input_dim'],
119
+ dct['params']['latent_dim'],
120
+ dct['demo_dim'],
121
+ dct['params']['use_cuda'])
122
+ self.vae.load_state_dict(dct['model_state_dict'])
123
+
124
+
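
A hedged usage sketch of the `DemoVAE` wrapper added above, on synthetic data; array sizes and hyperparameters are illustrative, and it assumes `demovae` is importable as a package:

```python
import numpy as np
from demovae.sklearn import DemoVAE

# Synthetic FC-like features: 50 subjects x 100 edge features
x = np.random.randn(50, 100).astype('float32')

# One continuous and one categorical demographic, as expected by fit()
demo = [np.random.uniform(20, 80, 50), np.random.choice(['F', 'M'], 50)]
demo_types = ['continuous', 'categorical']

vae = DemoVAE(latent_dim=8, nepochs=50, bsize=25, pperiod=10, use_cuda=False)
vae.fit(x, demo, demo_types)

z = vae.get_latents(x)                      # (50, 8) latent representations
recon = vae.transform(x, demo, demo_types)  # demographic-conditioned reconstruction

vae.save('demovae_checkpoint.pt')           # hypothetical output path
```
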
fc_visualization.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ FC Matrix Visualization Module.
3
+
4
+ This module provides functionality for visualizing Functional Connectivity matrices
5
+ independently from the prediction pipeline.
6
+ """
7
+
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from pathlib import Path
11
+ import argparse
12
+ import os
13
+ import nibabel as nib
14
+
15
+ try:
16
+ from nilearn import input_data, connectome
17
+ from nilearn.image import load_img
18
+ from nilearn import datasets
19
+ NILEARN_AVAILABLE = True
20
+ except ImportError:
21
+ NILEARN_AVAILABLE = False
22
+ print("Warning: nilearn not available. Direct fMRI processing disabled.")
23
+
24
+ from config import PREPROCESS_CONFIG
25
+
26
+ class FCVisualizer:
27
+ """Class for visualizing FC matrices."""
28
+
29
+ def __init__(self, cmap='RdBu_r', vmin=-1, vmax=1):
30
+ """
31
+ Initialize FCVisualizer with display parameters.
32
+
33
+ Args:
34
+ cmap: Colormap to use for FC matrices
35
+ vmin: Minimum value for color scaling
36
+ vmax: Maximum value for color scaling
37
+ """
38
+ self.cmap = cmap
39
+ self.vmin = vmin
40
+ self.vmax = vmax
41
+
42
+ def plot_single_matrix(self, matrix, title="FC Matrix", ax=None, fig=None):
43
+ """
44
+ Plot a single FC matrix.
45
+
46
+ Args:
47
+ matrix: 2D numpy array containing FC matrix
48
+ title: Title for the plot
49
+ ax: Matplotlib axis to plot on (optional)
50
+ fig: Matplotlib figure (optional)
51
+
52
+ Returns:
53
+ fig, ax: The figure and axis objects
54
+ """
55
+ if ax is None:
56
+ fig, ax = plt.subplots(figsize=(8, 6))
57
+
58
+ im = ax.imshow(matrix, cmap=self.cmap, vmin=self.vmin, vmax=self.vmax)
59
+ ax.set_title(title)
60
+ plt.colorbar(im, ax=ax)
61
+
62
+ return fig, ax
63
+
64
+ def plot_matrix_comparison(self, matrices, titles=None, figsize=None):
65
+ """
66
+ Plot multiple FC matrices for comparison.
67
+
68
+ Args:
69
+ matrices: List of 2D numpy arrays containing FC matrices
70
+ titles: List of titles for each matrix (optional)
71
+ figsize: Custom figure size (optional)
72
+
73
+ Returns:
74
+ fig: The figure object
75
+ """
76
+ n_matrices = len(matrices)
77
+
78
+ if figsize is None:
79
+ figsize = (5*n_matrices, 5)
80
+
81
+ if titles is None:
82
+ titles = [f"FC Matrix {i+1}" for i in range(n_matrices)]
83
+
84
+ fig, axes = plt.subplots(1, n_matrices, figsize=figsize)
85
+
86
+ # Handle single matrix case
87
+ if n_matrices == 1:
88
+ axes = [axes]
89
+
90
+ for i, (matrix, title) in enumerate(zip(matrices, titles)):
91
+ im = axes[i].imshow(matrix, cmap=self.cmap, vmin=self.vmin, vmax=self.vmax)
92
+ axes[i].set_title(title)
93
+ plt.colorbar(im, ax=axes[i])
94
+
95
+ plt.tight_layout()
96
+ return fig
97
+
98
+ def load_and_visualize_npy(self, file_path):
99
+ """
100
+ Load and visualize an FC matrix from a .npy file.
101
+
102
+ Args:
103
+ file_path: Path to the .npy file
104
+
105
+ Returns:
106
+ fig: The figure object containing the visualization
107
+ """
108
+ # Load the matrix
109
+ data = np.load(file_path)
110
+
111
+ # Check if it's an upper triangle or full matrix
112
+ if len(data.shape) == 1:
113
+ # Convert upper triangular to full matrix
114
+ matrix = self._triu_to_matrix(data)
115
+ else:
116
+ matrix = data
117
+
118
+ # Plot the matrix
119
+ filename = os.path.basename(file_path)
120
+ title = f"FC Matrix: {filename}"
121
+ fig, _ = self.plot_single_matrix(matrix, title=title)
122
+ return fig
123
+
124
+ def _triu_to_matrix(self, triu_values, fisher_z=True):
125
+ """
126
+ Convert upper triangular values to a full FC matrix.
127
+
128
+ Args:
129
+ triu_values: 1D array of upper triangular values
130
+ fisher_z: Whether values are Fisher z-transformed
131
+
132
+ Returns:
133
+ full_matrix: 2D symmetric matrix
134
+ """
135
+ # Calculate matrix size from triu length
136
+ n = int(np.sqrt(2 * len(triu_values) + 0.25) + 0.5)
137
+
138
+ # Initialize empty matrix
139
+ matrix = np.zeros((n, n))
140
+
141
+ # Get indices for upper triangle
142
+ triu_indices = np.triu_indices_from(matrix, k=1)
143
+
144
+ # If Fisher z-transformed, convert back
145
+ if fisher_z:
146
+ values_to_set = np.tanh(triu_values)
147
+ else:
148
+ values_to_set = triu_values
149
+
150
+ # Set upper triangle values
151
+ matrix[triu_indices] = values_to_set
152
+
153
+ # Make symmetric
154
+ matrix = matrix + matrix.T
155
+
156
+ # Set diagonal to 1.0 (perfect correlation)
157
+ np.fill_diagonal(matrix, 1.0)
158
+
159
+ return matrix
160
+
161
+ def process_and_visualize_fmri(self, fmri_file):
162
+ """
163
+ Process an fMRI file and visualize its FC matrix.
164
+
165
+ Args:
166
+ fmri_file: Path to the fMRI .nii or .nii.gz file
167
+
168
+ Returns:
169
+ fig: The figure object containing the visualization,
170
+ or None if processing fails
171
+ """
172
+ if not NILEARN_AVAILABLE:
173
+ print("Error: nilearn is required for fMRI processing")
174
+ return None
175
+
176
+ try:
177
+ # Extract FC matrix (upper triangular values)
178
+ fc_triu = self._process_single_fmri(fmri_file)
179
+
180
+ # Convert to full matrix
181
+ fc_matrix = self._triu_to_matrix(fc_triu)
182
+
183
+ # Plot the matrix
184
+ filename = os.path.basename(fmri_file)
185
+ title = f"FC Matrix: {filename}"
186
+ fig, _ = self.plot_single_matrix(fc_matrix, title=title)
187
+ return fig
188
+
189
+ except Exception as e:
190
+ print(f"Error processing fMRI file: {e}")
191
+ return None
192
+
193
+ def _process_single_fmri(self, fmri_file):
194
+ """
195
+ Process a single fMRI file to FC matrix.
196
+
197
+ Args:
198
+ fmri_file: Path to the fMRI .nii or .nii.gz file
199
+
200
+ Returns:
201
+ fc_triu: 1D array of upper triangular values (Fisher z-transformed)
202
+ """
203
+ print(f"Processing fMRI file: {fmri_file}")
204
+
205
+ # Use Power 264 atlas
206
+ power = datasets.fetch_coords_power_2011()
207
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
208
+
209
+ # Create masker
210
+ masker = input_data.NiftiSpheresMasker(
211
+ coords,
212
+ radius=PREPROCESS_CONFIG['radius'],
213
+ standardize=True,
214
+ memory='nilearn_cache',
215
+ memory_level=1,
216
+ verbose=0,
217
+ detrend=True,
218
+ low_pass=PREPROCESS_CONFIG['low_pass'],
219
+ high_pass=PREPROCESS_CONFIG['high_pass'],
220
+ t_r=PREPROCESS_CONFIG['t_r']
221
+ )
222
+
223
+ # Load and process fMRI
224
+ print(f"Loading NIfTI file...")
225
+ fmri_img = load_img(fmri_file)
226
+ print(f"NIfTI file loaded, shape: {fmri_img.shape}")
227
+
228
+ # Transform to time series
229
+ print(f"Extracting time series...")
230
+ time_series = masker.fit_transform(fmri_img)
231
+ print(f"Time series extracted, shape: {time_series.shape}")
232
+
233
+ # Compute FC matrix
234
+ print(f"Computing FC matrix...")
235
+ correlation_measure = connectome.ConnectivityMeasure(
236
+ kind='correlation',
237
+ vectorize=False,
238
+ discard_diagonal=False
239
+ )
240
+
241
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
242
+ print(f"FC matrix computed, shape: {fc_matrix.shape}")
243
+
244
+ # Get upper triangular part
245
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
246
+ fc_triu = fc_matrix[triu_indices]
247
+
248
+ # Fisher z-transform
249
+ fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99)) # Clip to avoid infinite values
250
+
251
+ print(f"Processing complete. FC features shape: {fc_triu.shape}")
252
+ return fc_triu
253
+
254
+
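The clipping before np.arctanh above matters because arctanh diverges at correlations of exactly ±1. A tiny illustration with arbitrary values:

import numpy as np

r = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])        # raw correlations
z = np.arctanh(np.clip(r, -0.99, 0.99))           # finite everywhere after clipping
print(z)                                          # arctanh(0.99) ~= 2.647; arctanh(1.0) would be inf
print(np.tanh(z))                                 # maps back to the clipped correlations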
255
+ def create_synthetic_fc_matrix(seed=None):
256
+ """
257
+ Create a synthetic FC matrix for demonstration purposes.
258
+
259
+ Args:
260
+ seed: Random seed for reproducibility
261
+
262
+ Returns:
263
+ matrix: 2D symmetric matrix representing FC
264
+ """
265
+ if seed is not None:
266
+ np.random.seed(seed)
267
+
268
+ # Number of ROIs (Power atlas has 264)
269
+ n_rois = 264
270
+
271
+ # Create random correlation matrix
272
+ # Method: generate random normal values, create outer product, normalize
273
+ random_vectors = np.random.randn(n_rois, 50) # 50 random features
274
+ matrix = np.corrcoef(random_vectors)
275
+
276
+ # np.corrcoef already returns values in [-1, 1]; just set the diagonal explicitly to 1.0
277
+ np.fill_diagonal(matrix, 1.0)
278
+
279
+ return matrix
280
+
281
+
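Because the synthetic matrix comes from np.corrcoef of random feature vectors, it is automatically a valid correlation matrix: unit diagonal, values in [-1, 1], and positive semi-definite up to numerical noise. A quick check of those properties (a sketch, independent of the function above):

import numpy as np

np.random.seed(0)
m = np.corrcoef(np.random.randn(264, 50))

assert m.shape == (264, 264)
assert np.allclose(np.diag(m), 1.0)
assert np.all(np.abs(m) <= 1.0 + 1e-12)
assert np.min(np.linalg.eigvalsh(m)) > -1e-8   # PSD up to numerical noise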
282
+ def main():
283
+ """Command-line interface for FC matrix visualization."""
284
+ parser = argparse.ArgumentParser(description='Visualize FC matrices')
285
+ parser.add_argument('--input', type=str, help='Input file (fMRI .nii/.nii.gz or .npy FC matrix)')
286
+ parser.add_argument('--output', type=str, help='Output image file (PNG/JPG/PDF)')
287
+ parser.add_argument('--cmap', type=str, default='RdBu_r', help='Colormap (default: RdBu_r)')
288
+ parser.add_argument('--vmin', type=float, default=-1, help='Minimum value for colormap')
289
+ parser.add_argument('--vmax', type=float, default=1, help='Maximum value for colormap')
290
+ parser.add_argument('--synthetic', action='store_true', help='Generate a synthetic FC matrix')
291
+ parser.add_argument('--seed', type=int, default=42, help='Random seed for synthetic data')
292
+
293
+ args = parser.parse_args()
294
+
295
+ # Create visualizer
296
+ visualizer = FCVisualizer(cmap=args.cmap, vmin=args.vmin, vmax=args.vmax)
297
+
298
+ # Determine figure to create
299
+ fig = None
300
+
301
+ if args.synthetic:
302
+ # Create synthetic FC matrix
303
+ matrix = create_synthetic_fc_matrix(seed=args.seed)
304
+ fig, _ = visualizer.plot_single_matrix(matrix, title="Synthetic FC Matrix")
305
+
306
+ elif args.input:
307
+ input_path = Path(args.input)
308
+
309
+ if not input_path.exists():
310
+ print(f"Error: Input file not found: {args.input}")
311
+ return
312
+
313
+ # Process based on file type
314
+ if input_path.suffix == '.npy':
315
+ # It's a numpy file with FC matrix
316
+ fig = visualizer.load_and_visualize_npy(input_path)
317
+
318
+ elif input_path.name.endswith('.nii') or input_path.name.endswith('.nii.gz'):
319
+ # It's an fMRI file
320
+ if not NILEARN_AVAILABLE:
321
+ print("Error: nilearn is required for processing fMRI files")
322
+ return
323
+ fig = visualizer.process_and_visualize_fmri(input_path)
324
+
325
+ else:
326
+ print(f"Error: Unsupported file format: {input_path.suffix}")
327
+ print("Supported formats: .npy (FC matrix), .nii/.nii.gz (fMRI)")
328
+ return
329
+
330
+ else:
331
+ # No input or synthetic flag - show demo
332
+ print("No input file or --synthetic flag provided. Generating a demo matrix.")
333
+ matrix = create_synthetic_fc_matrix(seed=args.seed)
334
+ fig, _ = visualizer.plot_single_matrix(matrix, title="Demo FC Matrix")
335
+
336
+ # Save or display the figure
337
+ if fig is not None:
338
+ if args.output:
339
+ fig.savefig(args.output, dpi=300, bbox_inches='tight')
340
+ print(f"Visualization saved to {args.output}")
341
+ else:
342
+ plt.show()
343
+ print("Visualization displayed. Close the window to exit.")
344
+ else:
345
+ print("Error: Failed to create visualization")
346
+
347
+
348
+ if __name__ == "__main__":
349
+ main()
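For reference, the --synthetic / --output path of the CLI above can also be driven programmatically. A sketch, assuming fc_visualization.py is importable from the working directory; the output file name is an arbitrary choice:

from fc_visualization import FCVisualizer, create_synthetic_fc_matrix

viz = FCVisualizer(cmap='RdBu_r', vmin=-1, vmax=1)
matrix = create_synthetic_fc_matrix(seed=42)
fig, _ = viz.plot_single_matrix(matrix, title="Synthetic FC Matrix")
fig.savefig("demo_fc.png", dpi=300, bbox_inches='tight')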
huggingface_fc_visualization.py ADDED
@@ -0,0 +1,489 @@
1
+ """
2
+ Script to visualize FC matrices from the HuggingFace dataset, comparing original FC matrices to VAE-generated ones.
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from datasets import load_dataset
9
+ from fc_visualization import FCVisualizer
10
+ from pathlib import Path
11
+ import tempfile
12
+ import requests
13
+ from config import DATASET_CONFIG, PREPROCESS_CONFIG, MODEL_CONFIG
14
+ from data_preprocessing import process_single_fmri
15
+ from vae_model import VariationalAutoencoder  # note: shadowed by the simplified VAE class defined below
16
+
17
+ def download_sample_fmri(dataset, temp_dir, max_samples=5):
18
+ """
19
+ Download sample fMRI files from HuggingFace dataset.
20
+
21
+ Args:
22
+ dataset: HuggingFace dataset object
23
+ temp_dir: Directory to save downloaded files
24
+ max_samples: Maximum number of samples to download
25
+
26
+ Returns:
27
+ list of paths to downloaded files, demographic data, and file keys
28
+ """
29
+ # Get first few samples to search for NIfTI files
30
+ nifti_keys = []
31
+
32
+ # Look through dataset features to find NIfTI files
33
+ for i, sample in enumerate(dataset):
34
+ if i >= 5: # Check first 5 samples
35
+ break
36
+
37
+ for key, value in sample.items():
38
+ if isinstance(value, str) and (value.endswith('.nii') or value.endswith('.nii.gz')):
39
+ if key not in nifti_keys:
40
+ nifti_keys.append(key)
41
+
42
+ print(f"Found {len(nifti_keys)} NIfTI file types in the dataset: {nifti_keys}")
43
+
44
+ if not nifti_keys:
45
+ print("No NIfTI files found in the dataset")
46
+ return [], [], []
47
+
48
+ # Collect nifti files and demographics
49
+ nifti_files = []
50
+ demo_data = []
51
+
52
+ # Process a limited number of samples
53
+ num_samples = min(max_samples, len(dataset))
54
+
55
+ for sample_idx in range(num_samples):
56
+ sample = dataset[sample_idx]
57
+
58
+ for key in nifti_keys:
59
+ try:
60
+ file_url = sample[key]
61
+ if not file_url or not isinstance(file_url, str):
62
+ continue
63
+
64
+ print(f"Processing sample {sample_idx+1}, file: {key}")
65
+
66
+ # Download and save the file
67
+ local_file = os.path.join(temp_dir, f"sample_{sample_idx}_{key}.nii.gz")
68
+ print(f"Downloading {file_url} to {local_file}")
69
+
70
+ response = requests.get(file_url)
71
+ with open(local_file, 'wb') as f:
72
+ f.write(response.content)
73
+
74
+ nifti_files.append(local_file)
75
+
76
+ # Extract demo data if available (or use placeholders)
77
+ age = sample.get('age', 65.0)
78
+ sex = sample.get('sex', 'M')
79
+ mpo = sample.get('months_post_onset', 12.0)
80
+ wab = sample.get('wab_aq', 50.0)
81
+
82
+ demo_sample = [age, sex, mpo, wab]
83
+ demo_data.append(demo_sample)
84
+
85
+ except Exception as e:
86
+ print(f"Error processing sample {sample_idx}, {key}: {e}")
87
+
88
+ return nifti_files, demo_data, nifti_keys
89
+
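The downloader above writes response.content in one pass and does not check the HTTP status. For large NIfTI files, a streamed download with raise_for_status() is slightly more robust; a hedged sketch (the helper name, timeout, and chunk size are arbitrary choices, not part of the script):

import requests

def download_file(url, local_path, chunk_bytes=4 * 1024 * 1024):
    """Stream a remote file to disk, failing loudly on HTTP errors."""
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_bytes):
                f.write(chunk)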
90
+ class VariationalAutoencoder:
91
+ """
92
+ Simplified VAE implementation for the visualization script.
93
+ """
94
+ def __init__(self, n_features, latent_dim, demo_data, demo_types, **kwargs):
95
+ """
96
+ Initialize the VAE.
97
+
98
+ Args:
99
+ n_features: Number of input features
100
+ latent_dim: Dimension of latent space
101
+ demo_data: Demographic data
102
+ demo_types: Types of demographic variables
103
+ **kwargs: Additional parameters
104
+ """
105
+ import torch
106
+ import torch.nn as nn
107
+
108
+ self.n_features = n_features
109
+ self.latent_dim = latent_dim
110
+ self.demo_dim = self._calculate_demo_dim(demo_data, demo_types)
111
+ self.nepochs = kwargs.get('nepochs', 100)
112
+ self.batch_size = kwargs.get('bsize', 8)
113
+ self.learning_rate = kwargs.get('lr', 1e-3)
114
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
115
+
116
+ # Build VAE model
117
+ self.encoder = nn.Sequential(
118
+ nn.Linear(n_features, 512),
119
+ nn.ReLU(),
120
+ nn.BatchNorm1d(512),
121
+ nn.Linear(512, 256),
122
+ nn.ReLU(),
123
+ nn.BatchNorm1d(256),
124
+ nn.Linear(256, latent_dim * 2) # mu and logvar
125
+ ).to(self.device)
126
+
127
+ self.decoder = nn.Sequential(
128
+ nn.Linear(latent_dim + self.demo_dim, 256),
129
+ nn.ReLU(),
130
+ nn.BatchNorm1d(256),
131
+ nn.Linear(256, 512),
132
+ nn.ReLU(),
133
+ nn.BatchNorm1d(512),
134
+ nn.Linear(512, n_features)
135
+ ).to(self.device)
136
+
137
+ self.optimizer = torch.optim.Adam(
138
+ list(self.encoder.parameters()) + list(self.decoder.parameters()),
139
+ lr=self.learning_rate
140
+ )
141
+
142
+ self.demo_stats = None # Will be set during training
143
+
144
+ def _calculate_demo_dim(self, demo_data, demo_types):
145
+ """Calculate dimension of demographic data after one-hot encoding"""
146
+ demo_dim = 0
147
+ for d, t in zip(demo_data, demo_types):
148
+ if t == 'continuous':
149
+ demo_dim += 1
150
+ elif t == 'categorical':
151
+ if isinstance(d[0], str):
152
+ # Get unique categories
153
+ unique_values = list(set(d))
154
+ demo_dim += len(unique_values)
155
+ else:
156
+ demo_dim += len(set(d))
157
+ return demo_dim
158
+
159
+ def _encode(self, x):
160
+ """Encode input data to latent space"""
161
+ import torch
162
+
163
+ x_tensor = torch.tensor(x, dtype=torch.float32).to(self.device)
164
+ h = self.encoder(x_tensor)
165
+ mu, logvar = h[:, :self.latent_dim], h[:, self.latent_dim:]
166
+ return mu, logvar
167
+
168
+ def _reparameterize(self, mu, logvar):
169
+ """Reparameterization trick for sampling from latent space"""
170
+ import torch
171
+
172
+ std = torch.exp(0.5 * logvar)
173
+ eps = torch.randn_like(std)
174
+ z = mu + eps * std
175
+ return z
176
+
177
+ def _decode(self, z, demo):
178
+ """Decode latent representation back to input space"""
179
+ import torch
180
+
181
+ # Concatenate latent code with demographic data
182
+ z_concat = torch.cat([z, demo], dim=1)
183
+ return self.decoder(z_concat)
184
+
185
+ def _prepare_demographics(self, demo_data, demo_types):
186
+ """Convert demographics to tensor with one-hot encoding for categorical variables"""
187
+ import torch
188
+ import numpy as np
189
+
190
+ if self.demo_stats is None:
191
+ # First time - compute stats
192
+ self.demo_stats = []
193
+ for d, t in zip(demo_data, demo_types):
194
+ if t == 'continuous':
195
+ # Standardize continuous features
196
+ self.demo_stats.append(('continuous', (np.mean(d), np.std(d))))
197
+ elif t == 'categorical':
198
+ # Record unique values for one-hot encoding
199
+ if isinstance(d[0], str):
200
+ unique_values = sorted(list(set(d)))
201
+ else:
202
+ unique_values = sorted(list(set(d)))
203
+ self.demo_stats.append(('categorical', unique_values))
204
+
205
+ # Process demographics based on saved stats
206
+ demo_tensors = []
207
+ for (d, (dtype, stats)) in zip(demo_data, self.demo_stats):
208
+ if dtype == 'continuous':
209
+ mean, std = stats
210
+ # Standardize
211
+ standardized = (np.array(d) - mean) / (std + 1e-10)
212
+ demo_tensors.append(torch.tensor(standardized, dtype=torch.float32).reshape(-1, 1))
213
+ else: # categorical
214
+ unique_values = stats
215
+ # One-hot encode
216
+ one_hot_vectors = []
217
+ for val in d:
218
+ try:
219
+ idx = unique_values.index(val)
220
+ vec = [0.0] * len(unique_values)
221
+ vec[idx] = 1.0
222
+ one_hot_vectors.append(vec)
223
+ except ValueError:
224
+ # Handle unseen categories - use all zeros
225
+ vec = [0.0] * len(unique_values)
226
+ one_hot_vectors.append(vec)
227
+ demo_tensors.append(torch.tensor(one_hot_vectors, dtype=torch.float32))
228
+
229
+ # Concatenate all demographic features
230
+ return torch.cat(demo_tensors, dim=1).to(self.device)
231
+
232
+ def fit(self, X, demo_data, demo_types):
233
+ """
234
+ Train the VAE model.
235
+
236
+ Args:
237
+ X: Input data (FC matrices)
238
+ demo_data: List of demographic variables
239
+ demo_types: Types of demographic variables
240
+ """
241
+ import torch
242
+ import torch.nn.functional as F
243
+ import numpy as np
244
+ from torch.utils.data import DataLoader, TensorDataset
245
+
246
+ print(f"Training VAE on {len(X)} samples for {self.nepochs} epochs...")
247
+
248
+ # Prepare demographic data
249
+ demo_tensor = self._prepare_demographics(demo_data, demo_types)
250
+
251
+ # Convert input data to tensor
252
+ X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
253
+
254
+ # Create dataset and dataloader
255
+ dataset = TensorDataset(X_tensor, demo_tensor)
256
+ dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
257
+
258
+ # Training loop
259
+ self.train_losses = []
260
+
261
+ for epoch in range(self.nepochs):
262
+ epoch_losses = []
263
+
264
+ for batch_x, batch_demo in dataloader:
265
+ # Forward pass
266
+ mu, logvar = self._encode(batch_x)
267
+ z = self._reparameterize(mu, logvar)
268
+ x_recon = self._decode(z, batch_demo)
269
+
270
+ # Compute loss
271
+ recon_loss = F.mse_loss(x_recon, batch_x)
272
+ kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
273
+ kl_loss = kl_loss / batch_x.size(0) # Normalize by batch size
274
+
275
+ # Total loss
276
+ loss = recon_loss + 0.1 * kl_loss
277
+
278
+ # Backward and optimize
279
+ self.optimizer.zero_grad()
280
+ loss.backward()
281
+ self.optimizer.step()
282
+
283
+ epoch_losses.append(loss.item())
284
+
285
+ # Record average loss for this epoch
286
+ avg_loss = np.mean(epoch_losses)
287
+ self.train_losses.append(avg_loss)
288
+
289
+ # Print progress every 10 epochs
290
+ if (epoch + 1) % 10 == 0:
291
+ print(f"Epoch {epoch+1}/{self.nepochs}, Loss: {avg_loss:.6f}")
292
+
293
+ print("VAE training complete!")
294
+ return self.train_losses
295
+
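A small smoke test of the simplified VAE class defined above, using random features and made-up demographics purely to exercise the shapes (not real FC data; all numbers are placeholders, and PyTorch must be installed):

import numpy as np

n_samples, n_features = 8, 200
X_toy = np.random.randn(n_samples, n_features).astype(np.float32)
demo_toy = [
    [60.0, 62.0, 70.0, 55.0, 65.0, 58.0, 72.0, 61.0],   # age
    ['M', 'F', 'M', 'F', 'M', 'F', 'M', 'F'],            # sex
    [6.0, 12.0, 18.0, 9.0, 24.0, 15.0, 7.0, 11.0],       # months post onset
    [40.0, 55.0, 62.0, 48.0, 70.0, 52.0, 45.0, 66.0],    # WAB-AQ
]
demo_types_toy = ['continuous', 'categorical', 'continuous', 'continuous']

vae = VariationalAutoencoder(n_features=n_features, latent_dim=8,
                             demo_data=demo_toy, demo_types=demo_types_toy,
                             nepochs=5, bsize=4)
vae.fit(X_toy, demo_toy, demo_types_toy)
recon = vae.reconstruct(X_toy, demo_toy, demo_types_toy)
assert recon.shape == X_toy.shape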
296
+ def reconstruct(self, X, demo_data=None, demo_types=None):
297
+ """
298
+ Reconstruct input data.
299
+
300
+ Args:
301
+ X: Input data
302
+ demo_data: Demographic data (required; the decoder is conditioned on demographics)
303
+ demo_types: Types of demographic variables (required)
304
+
305
+ Returns:
306
+ Reconstructed data
307
+ """
308
+ import torch
309
+
310
+ # Set to evaluation mode
311
+ self.encoder.eval()
312
+ self.decoder.eval()
313
+
314
+ with torch.no_grad():
315
+ # Encode to latent space
316
+ mu, _ = self._encode(X)
317
+
318
+ # Demographics must be supplied; the conditional decoder has nothing stored to fall back on
319
+ if demo_data is not None and demo_types is not None:
320
+ demo_tensor = self._prepare_demographics(demo_data, demo_types)
321
+ else:
322
+ # No demographics are saved at training time, so they must be passed in explicitly
323
+ raise ValueError("Demo data and types must be provided for reconstruction")
324
+
325
+ # Decode
326
+ recon = self._decode(mu, demo_tensor)
327
+
328
+ # Convert to numpy
329
+ return recon.cpu().numpy()
330
+
331
+ def generate(self, n_samples, demo_data, demo_types):
332
+ """
333
+ Generate new samples from the latent space.
334
+
335
+ Args:
336
+ n_samples: Number of samples to generate
337
+ demo_data: Demographic data
338
+ demo_types: Types of demographic variables
339
+
340
+ Returns:
341
+ Generated samples
342
+ """
343
+ import torch
344
+
345
+ # Set to evaluation mode
346
+ self.decoder.eval()
347
+
348
+ with torch.no_grad():
349
+ # Sample from standard normal
350
+ z = torch.randn(n_samples, self.latent_dim).to(self.device)
351
+
352
+ # Prepare demographic data
353
+ demo_tensor = self._prepare_demographics(demo_data, demo_types)
354
+
355
+ # Check dimensions
356
+ if demo_tensor.shape[0] != n_samples:
357
+ # Handle mismatch - repeat the first demographic sample
358
+ if demo_tensor.shape[0] >= 1:
359
+ demo_tensor = demo_tensor[0].unsqueeze(0).repeat(n_samples, 1)
360
+
361
+ # Generate samples
362
+ generated = self._decode(z, demo_tensor)
363
+
364
+ # Convert to numpy
365
+ return generated.cpu().numpy()
366
+
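Continuing the smoke test sketch above (it reuses vae, demo_toy, and demo_types_toy), generate only needs a sample count plus demographics to condition on; when fewer demographic rows are given than samples requested, the method repeats the first row:

first_subject_demo = [[col[0]] for col in demo_toy]
synthetic = vae.generate(5, first_subject_demo, demo_types_toy)
print(synthetic.shape)   # (5, 200)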
367
+ def generate_comparison():
368
+ """Download, process and visualize FC matrices from the HuggingFace dataset,
369
+ comparing original to VAE-generated matrices."""
370
+ print("Loading dataset from HuggingFace...")
371
+
372
+ # Load the HuggingFace dataset using config
373
+ dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData')
374
+ dataset_split = DATASET_CONFIG.get('split', 'train')
375
+
376
+ dataset = load_dataset(dataset_name, split=dataset_split)
377
+ print(f"Dataset loaded: {dataset}")
378
+
379
+ # Create temporary directory for downloaded NIfTI files
380
+ temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
381
+ print(f"Created temp directory for NIfTI files: {temp_dir}")
382
+
383
+ # Download and process fMRI files
384
+ nifti_files, demo_samples, nifti_keys = download_sample_fmri(dataset, temp_dir, max_samples=5)
385
+
386
+ if not nifti_files:
387
+ print("No valid fMRI files were found")
388
+ return
389
+
390
+ # Process all fMRI files to FC matrices
391
+ fc_matrices = []
392
+ demo_data = []
393
+
394
+ for file_idx, (file_path, demo_sample) in enumerate(zip(nifti_files, demo_samples)):
395
+ try:
396
+ print(f"Processing file {file_idx+1}/{len(nifti_files)}: {file_path}")
397
+ fc_triu = process_single_fmri(file_path)
398
+ fc_matrices.append(fc_triu)
399
+ demo_data.append(demo_sample)
400
+ except Exception as e:
401
+ print(f"Error processing file {file_path}: {e}")
402
+
403
+ if not fc_matrices:
404
+ print("No valid FC matrices were generated")
405
+ return
406
+
407
+ # Convert to numpy arrays
408
+ X = np.array(fc_matrices)
409
+
410
+ # Normalize the data
411
+ X = (X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-10)  # epsilon guards zero-variance features
412
+
413
+ # Prepare demographic data
414
+ # Transpose to get [feature_type][sample] format
415
+ demo_data = [list(col) for col in zip(*demo_data)]  # preserves mixed types; np.array would coerce everything to strings
416
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
417
+
418
+ # Train a VAE on the FC matrices
419
+ print("Training VAE on the FC matrices...")
420
+ n_features = X.shape[1]
421
+
422
+ # Configure a smaller/faster VAE for demonstration
423
+ vae = VariationalAutoencoder(
424
+ n_features=n_features,
425
+ latent_dim=MODEL_CONFIG.get('latent_dim', 32),
426
+ demo_data=demo_data,
427
+ demo_types=demo_types,
428
+ nepochs=100, # Reduced for demo
429
+ bsize=2,
430
+ lr=1e-3
431
+ )
432
+
433
+ # Train the VAE
434
+ vae.fit(X, demo_data, demo_types)
435
+
436
+ # Generate reconstructed FC matrices
437
+ print("Generating reconstructed FC matrices...")
438
+ reconstructed = vae.reconstruct(X, demo_data, demo_types)
439
+
440
+ # Generate a synthetic FC matrix
441
+ print("Generating a synthetic FC matrix...")
442
+ # For generating a new sample, we'll use demographics from first patient
443
+ first_demo_data = [[d[0]] for d in demo_data]
444
+ generated = vae.generate(1, first_demo_data, demo_types)
445
+
446
+ # Visualize original, reconstructed, and generated FC matrices
447
+ visualizer = FCVisualizer()
448
+
449
+ # Process each sample to generate comparisons
450
+ for i in range(min(3, len(X))):
451
+ # Convert upper triangular vectors to full matrices for visualization
452
+ original_matrix = visualizer._triu_to_matrix(X[i])
453
+ recon_matrix = visualizer._triu_to_matrix(reconstructed[i])
454
+
455
+ # Use the generate method for a single synthetic sample
456
+ if i == 0:
457
+ gen_matrix = visualizer._triu_to_matrix(generated[0])
458
+
459
+ # Visualize all three - original, reconstructed, generated
460
+ fig = visualizer.plot_matrix_comparison(
461
+ [original_matrix, recon_matrix, gen_matrix],
462
+ titles=["Original FC", "Reconstructed FC", "Generated FC"]
463
+ )
464
+
465
+ output_file = f"fc_comparison_with_generated.png"
466
+ fig.savefig(output_file, dpi=300, bbox_inches='tight')
467
+ print(f"Saved full comparison to {output_file}")
468
+
469
+ # Visualize original vs reconstructed for each sample
470
+ fig = visualizer.plot_matrix_comparison(
471
+ [original_matrix, recon_matrix],
472
+ titles=[f"Original FC (Sample {i+1})", f"Reconstructed FC (Sample {i+1})"]
473
+ )
474
+
475
+ output_file = f"sample_{i}_original_vs_reconstructed.png"
476
+ fig.savefig(output_file, dpi=300, bbox_inches='tight')
477
+ print(f"Saved comparison to {output_file}")
478
+
479
+ # Save the matrices
480
+ np.save(f"sample_{i}_original_fc.npy", original_matrix)
481
+ np.save(f"sample_{i}_reconstructed_fc.npy", recon_matrix)
482
+
483
+ # Save the generated matrix
484
+ np.save("generated_fc.npy", gen_matrix)
485
+
486
+ print("Processing complete")
487
+
488
+ if __name__ == "__main__":
489
+ generate_comparison()
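Once generate_comparison() has run, the saved .npy matrices can be reloaded and re-plotted with the same visualizer. A short sketch; the files only exist after a successful run:

import numpy as np
from fc_visualization import FCVisualizer

matrix = np.load("sample_0_original_fc.npy")
fig, _ = FCVisualizer().plot_single_matrix(matrix, title="Sample 0 original FC")
fig.savefig("sample_0_original_fc.png", dpi=300, bbox_inches='tight')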
main.py CHANGED
@@ -99,7 +99,6 @@ def run_analysis(data_dir="data",
99
  # Initialize and train treatment predictor
100
  print("Training treatment predictor...")
101
  predictor = AphasiaTreatmentPredictor(
102
- prediction_type=PREDICTION_CONFIG.get('prediction_type', 'regression'),
103
  n_estimators=PREDICTION_CONFIG.get('n_estimators', 100),
104
  max_depth=PREDICTION_CONFIG.get('max_depth', None)
105
  )
@@ -129,18 +128,11 @@ def run_analysis(data_dir="data",
129
 
130
  # For regression, get R2 metrics, otherwise use accuracy
131
  try:
132
- if predictor.prediction_type == "regression":
133
- cv_mean = mean_metrics.get("r2", 0.0)
134
- if fold_metrics and "r2" in fold_metrics[0]:
135
- cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
136
- else:
137
- cv_std = 0.0
138
  else:
139
- cv_mean = mean_metrics.get("accuracy", 0.0)
140
- if fold_metrics and "accuracy" in fold_metrics[0]:
141
- cv_std = np.std([fold.get("accuracy", 0.0) for fold in fold_metrics])
142
- else:
143
- cv_std = 0.0
144
  except Exception as e:
145
  print(f"Error calculating CV metrics: {e}")
146
  cv_mean, cv_std = 0.0, 0.0
 
99
  # Initialize and train treatment predictor
100
  print("Training treatment predictor...")
101
  predictor = AphasiaTreatmentPredictor(
 
102
  n_estimators=PREDICTION_CONFIG.get('n_estimators', 100),
103
  max_depth=PREDICTION_CONFIG.get('max_depth', None)
104
  )
 
128
 
129
  # For regression, get R2 metrics, otherwise use accuracy
130
  try:
131
+ cv_mean = mean_metrics.get("r2", 0.0)
132
+ if fold_metrics and "r2" in fold_metrics[0]:
133
+ cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
 
 
 
134
  else:
135
+ cv_std = 0.0
 
 
 
 
136
  except Exception as e:
137
  print(f"Error calculating CV metrics: {e}")
138
  cv_mean, cv_std = 0.0, 0.0
rcf_prediction.py CHANGED
@@ -1,8 +1,8 @@
1
  import numpy as np
2
- from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
3
  from sklearn.model_selection import cross_val_score, KFold
4
  import pandas as pd
5
- from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score
6
  import matplotlib.pyplot as plt
7
  import os
8
  import joblib
@@ -12,35 +12,27 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
12
  logger = logging.getLogger(__name__)
13
 
14
  class AphasiaTreatmentPredictor:
15
- def __init__(self, prediction_type="regression", n_estimators=100, max_depth=None, random_state=42):
16
  """
17
- Initialize the Treatment Predictor with Random Forest
18
 
19
  Args:
20
- prediction_type (str): "classification" or "regression" depending on outcome variable type
21
  n_estimators (int): Number of trees in the forest
22
  max_depth (int): Maximum depth of trees (None for unlimited)
23
  random_state (int): Random seed for reproducibility
24
  """
25
- self.prediction_type = prediction_type
26
  self.n_estimators = n_estimators
27
  self.max_depth = max_depth
28
  self.random_state = random_state
29
  self.feature_importance = None
30
  self.feature_names = None
31
 
32
- if prediction_type == "classification":
33
- self.model = RandomForestClassifier(
34
- n_estimators=n_estimators,
35
- max_depth=max_depth,
36
- random_state=random_state
37
- )
38
- else: # regression
39
- self.model = RandomForestRegressor(
40
- n_estimators=n_estimators,
41
- max_depth=max_depth,
42
- random_state=random_state
43
- )
44
 
45
  def prepare_features(self, latents, demographics):
46
  """
@@ -115,34 +107,11 @@ class AphasiaTreatmentPredictor:
115
  predictions = self.model.predict(X)
116
 
117
  # Get prediction intervals using tree variance
118
- if self.prediction_type == "regression":
119
- tree_predictions = np.array([tree.predict(X)
120
- for tree in self.model.estimators_])
121
- prediction_std = np.std(tree_predictions, axis=0)
122
- else: # classification
123
- # For classification, use probability as a measure of confidence
124
- proba = self.model.predict_proba(X)
125
- # Use max probability as confidence measure
126
- prediction_std = 1 - np.max(proba, axis=1)
127
 
128
  return predictions, prediction_std
129
-
130
- def predict_proba(self, latents, demographics):
131
- """
132
- Get probability estimates for classification
133
-
134
- Args:
135
- latents (np.ndarray): Latent representations from VAE
136
- demographics (dict or pd.DataFrame): Demographic information
137
-
138
- Returns:
139
- np.ndarray: Probability estimates for each class
140
- """
141
- if self.prediction_type != "classification":
142
- raise ValueError("Probability prediction only available for classification")
143
-
144
- X, _ = self.prepare_features(latents, demographics)
145
- return self.model.predict_proba(X)
146
 
147
  def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
148
  """
@@ -174,18 +143,11 @@ class AphasiaTreatmentPredictor:
174
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
175
 
176
  # Clone the model for this fold
177
- if self.prediction_type == "classification":
178
- fold_model = RandomForestClassifier(
179
- n_estimators=self.n_estimators,
180
- max_depth=self.max_depth,
181
- random_state=self.random_state
182
- )
183
- else:
184
- fold_model = RandomForestRegressor(
185
- n_estimators=self.n_estimators,
186
- max_depth=self.max_depth,
187
- random_state=self.random_state
188
- )
189
 
190
  # Train the model
191
  fold_model.fit(X_train, y_train)
@@ -197,38 +159,19 @@ class AphasiaTreatmentPredictor:
197
  predictions[test_idx] = pred
198
 
199
  # Calculate metrics
200
- if self.prediction_type == "regression":
201
- rmse = np.sqrt(mean_squared_error(y_test, pred))
202
- r2 = r2_score(y_test, pred)
203
- metrics = {
204
- "r2": r2,
205
- "rmse": rmse,
206
- "mse": rmse**2
207
- }
208
-
209
- # Get prediction intervals using tree variance
210
- tree_predictions = np.array([tree.predict(X_test)
211
- for tree in fold_model.estimators_])
212
- pred_std = np.std(tree_predictions, axis=0)
213
- prediction_stds[test_idx] = pred_std
214
-
215
- else: # classification
216
- acc = accuracy_score(y_test, pred)
217
- prec = precision_score(y_test, pred, average='weighted', zero_division=0)
218
- rec = recall_score(y_test, pred, average='weighted', zero_division=0)
219
- f1 = f1_score(y_test, pred, average='weighted', zero_division=0)
220
- metrics = {
221
- "accuracy": acc,
222
- "precision": prec,
223
- "recall": rec,
224
- "f1": f1
225
- }
226
-
227
- # Use probability as a measure of confidence
228
- proba = fold_model.predict_proba(X_test)
229
- # Use max probability as confidence measure
230
- pred_std = 1 - np.max(proba, axis=1)
231
- prediction_stds[test_idx] = pred_std
232
 
233
  fold_metrics.append(metrics)
234
  logger.info(f"Fold {fold+1} metrics: {metrics}")
@@ -335,7 +278,6 @@ class AphasiaTreatmentPredictor:
335
 
336
  # Create new instance
337
  predictor = cls(
338
- prediction_type=data['prediction_type'],
339
  n_estimators=data['n_estimators'],
340
  max_depth=data['max_depth'],
341
  random_state=data['random_state']
@@ -350,7 +292,7 @@ class AphasiaTreatmentPredictor:
350
  return predictor
351
 
352
 
353
- def train_predictor_from_latents(latents, outcomes, demographics=None, prediction_type="regression", cv=5, **kwargs):
354
  """
355
  Train a treatment outcome predictor from VAE latent representations
356
 
@@ -358,17 +300,16 @@ def train_predictor_from_latents(latents, outcomes, demographics=None, predictio
358
  latents (np.ndarray): Latent representations from VAE
359
  outcomes (np.ndarray): Treatment outcome values
360
  demographics (dict or pd.DataFrame, optional): Demographic information to include as features
361
- prediction_type (str): "classification" or "regression"
362
  cv (int): Number of folds for cross-validation
363
  **kwargs: Additional parameters for the AphasiaTreatmentPredictor
364
 
365
  Returns:
366
  dict: Training results and trained model
367
  """
368
- logger.info(f"Training {prediction_type} model for treatment prediction")
369
 
370
  # Create predictor
371
- predictor = AphasiaTreatmentPredictor(prediction_type=prediction_type, **kwargs)
372
 
373
  # Run cross-validation
374
  cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
 
1
  import numpy as np
2
+ from sklearn.ensemble import RandomForestRegressor
3
  from sklearn.model_selection import cross_val_score, KFold
4
  import pandas as pd
5
+ from sklearn.metrics import mean_squared_error, r2_score
6
  import matplotlib.pyplot as plt
7
  import os
8
  import joblib
 
12
  logger = logging.getLogger(__name__)
13
 
14
  class AphasiaTreatmentPredictor:
15
+ def __init__(self, n_estimators=100, max_depth=None, random_state=42):
16
  """
17
+ Initialize the Treatment Predictor with Random Forest Regressor
18
 
19
  Args:
 
20
  n_estimators (int): Number of trees in the forest
21
  max_depth (int): Maximum depth of trees (None for unlimited)
22
  random_state (int): Random seed for reproducibility
23
  """
24
+ self.prediction_type = "regression"
25
  self.n_estimators = n_estimators
26
  self.max_depth = max_depth
27
  self.random_state = random_state
28
  self.feature_importance = None
29
  self.feature_names = None
30
 
31
+ self.model = RandomForestRegressor(
32
+ n_estimators=n_estimators,
33
+ max_depth=max_depth,
34
+ random_state=random_state
35
+ )
 
 
 
 
 
 
 
36
 
37
  def prepare_features(self, latents, demographics):
38
  """
 
107
  predictions = self.model.predict(X)
108
 
109
  # Get prediction intervals using tree variance
110
+ tree_predictions = np.array([tree.predict(X)
111
+ for tree in self.model.estimators_])
112
+ prediction_std = np.std(tree_predictions, axis=0)
 
 
 
 
 
 
113
 
114
  return predictions, prediction_std
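The per-sample uncertainty above is the spread of the individual trees' predictions, which the ensemble averages into the point prediction. A standalone sketch on synthetic data (shapes and noise level are arbitrary):

import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))
y = 2.0 * X[:, 0] + rng.normal(scale=0.5, size=200)

model = RandomForestRegressor(n_estimators=100, random_state=42).fit(X, y)

X_new = rng.normal(size=(5, 16))
per_tree = np.array([tree.predict(X_new) for tree in model.estimators_])
assert np.allclose(per_tree.mean(axis=0), model.predict(X_new))   # ensemble prediction = mean over trees
pred_std = per_tree.std(axis=0)                                    # disagreement across trees as uncertainty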
115
 
116
  def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
117
  """
 
143
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
144
 
145
  # Clone the model for this fold
146
+ fold_model = RandomForestRegressor(
147
+ n_estimators=self.n_estimators,
148
+ max_depth=self.max_depth,
149
+ random_state=self.random_state
150
+ )
 
 
 
 
 
 
 
151
 
152
  # Train the model
153
  fold_model.fit(X_train, y_train)
 
159
  predictions[test_idx] = pred
160
 
161
  # Calculate metrics
162
+ rmse = np.sqrt(mean_squared_error(y_test, pred))
163
+ r2 = r2_score(y_test, pred)
164
+ metrics = {
165
+ "r2": r2,
166
+ "rmse": rmse,
167
+ "mse": rmse**2
168
+ }
169
+
170
+ # Get prediction intervals using tree variance
171
+ tree_predictions = np.array([tree.predict(X_test)
172
+ for tree in fold_model.estimators_])
173
+ pred_std = np.std(tree_predictions, axis=0)
174
+ prediction_stds[test_idx] = pred_std

175
 
176
  fold_metrics.append(metrics)
177
  logger.info(f"Fold {fold+1} metrics: {metrics}")
 
278
 
279
  # Create new instance
280
  predictor = cls(
 
281
  n_estimators=data['n_estimators'],
282
  max_depth=data['max_depth'],
283
  random_state=data['random_state']
 
292
  return predictor
293
 
294
 
295
+ def train_predictor_from_latents(latents, outcomes, demographics=None, cv=5, **kwargs):
296
  """
297
  Train a treatment outcome predictor from VAE latent representations
298
 
 
300
  latents (np.ndarray): Latent representations from VAE
301
  outcomes (np.ndarray): Treatment outcome values
302
  demographics (dict or pd.DataFrame, optional): Demographic information to include as features
 
303
  cv (int): Number of folds for cross-validation
304
  **kwargs: Additional parameters for the AphasiaTreatmentPredictor
305
 
306
  Returns:
307
  dict: Training results and trained model
308
  """
309
+ logger.info(f"Training regression model for treatment prediction")
310
 
311
  # Create predictor
312
+ predictor = AphasiaTreatmentPredictor(**kwargs)
313
 
314
  # Run cross-validation
315
  cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
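A hypothetical call to the regression-only helper above, with synthetic latents and outcomes. This assumes prepare_features accepts a DataFrame, as its docstring states; the column names and the outcome definition are placeholders:

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
latents = rng.normal(size=(40, 32))                                      # VAE latent vectors
outcomes = 50.0 + 5.0 * latents[:, 0] + rng.normal(scale=2.0, size=40)   # e.g. post-treatment WAB-AQ
demographics = pd.DataFrame({
    'age': rng.integers(40, 80, size=40),
    'months_post_onset': rng.integers(3, 60, size=40),
})

results = train_predictor_from_latents(latents, outcomes,
                                        demographics=demographics,
                                        cv=5, n_estimators=100)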
src/.DS_Store CHANGED
Binary files a/src/.DS_Store and b/src/.DS_Store differ
 
vae_model.py CHANGED
@@ -26,17 +26,33 @@ class VAE(nn.Module):
26
  self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
27
 
28
  def enc(self, x):
29
- x = self.bn1(F.relu(self.enc1(x)))
30
- z = self.enc2(x)
 
 
 
 
 
 
 
31
  return z
32
 
33
  def gen(self, n):
34
  return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)
35
 
36
  def dec(self, z, demo):
37
- z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
38
- x = self.bn2(F.relu(self.dec1(z)))
39
- x = self.dec2(x)
 
 
 
 
 
 
 
 
 
40
  return x
41
 
42
  class DemoVAE(BaseEstimator):
@@ -106,16 +122,77 @@ class DemoVAE(BaseEstimator):
106
  return train_losses, val_losses
107
 
108
  def transform(self, x, demo, demo_types):
109
- if isinstance(x, int):
110
- z = self.vae.gen(x)
111
- else:
112
- z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
113
- demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
114
- y = self.vae.dec(z, demo_t)

  return to_numpy(y)
116
 
117
  def get_latents(self, x):
118
- z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))

119
  return to_numpy(z)
120
 
121
  def save(self, path):
 
26
  self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
27
 
28
  def enc(self, x):
29
+ # First layer with activation
30
+ h = self.enc1(x)
31
+ h = F.relu(h)
32
+
33
+ # Apply batch norm - handle training vs eval mode automatically
34
+ h = self.bn1(h)
35
+
36
+ # Output layer
37
+ z = self.enc2(h)
38
  return z
39
 
40
  def gen(self, n):
41
  return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)
42
 
43
  def dec(self, z, demo):
44
+ # Concatenate latent code with demographic data
45
+ z_combined = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
46
+
47
+ # First decoder layer with activation
48
+ h = self.dec1(z_combined)
49
+ h = F.relu(h)
50
+
51
+ # Apply batch norm - handle training vs eval mode automatically
52
+ h = self.bn2(h)
53
+
54
+ # Output layer
55
+ x = self.dec2(h)
56
  return x
57
 
58
  class DemoVAE(BaseEstimator):
 
122
  return train_losses, val_losses
123
 
124
  def transform(self, x, demo, demo_types):
125
+ # Set model to evaluation mode to handle batch norm with batch size of 1
126
+ self.vae.eval()
127
+
128
+ # Use torch.no_grad to disable gradient calculation during inference
129
+ with torch.no_grad():
130
+ if isinstance(x, int):
131
+ z = self.vae.gen(x)
132
+ else:
133
+ z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
134
+
135
+ demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
136
+
137
+ # Handle batch size of 1 for batch normalization
138
+ if z.size(0) == 1:
139
+ # If batch size is 1, we need to be careful with batch norm
140
+ # Clone and repeat the input to create a fake batch if needed
141
+ if hasattr(self.vae, 'bn1') or hasattr(self.vae, 'bn2'):
142
+ try:
143
+ # Try normal decoding first
144
+ y = self.vae.dec(z, demo_t)
145
+ except Exception as e:
146
+ # If it fails, use a workaround for batch norm
147
+ print(f"Using batch norm workaround for inference: {e}")
148
+ # Create a batch by repeating the input
149
+ z_batch = z.repeat(2, 1)
150
+ demo_t_batch = demo_t.repeat(2, 1)
151
+ # Get the output and use only the first element
152
+ y_batch = self.vae.dec(z_batch, demo_t_batch)
153
+ y = y_batch[0:1]
154
+ else:
155
+ # No batch norm, proceed normally
156
+ y = self.vae.dec(z, demo_t)
157
+ else:
158
+ # Normal batch size, proceed as usual
159
+ y = self.vae.dec(z, demo_t)
160
+
161
  return to_numpy(y)
162
 
163
  def get_latents(self, x):
164
+ # Set model to evaluation mode
165
+ self.vae.eval()
166
+
167
+ # Use torch.no_grad for inference
168
+ with torch.no_grad():
169
+ try:
170
+ # Convert to torch tensor and move to CUDA if needed
171
+ x_tensor = to_cuda(to_torch(x), self.vae.use_cuda)
172
+
173
+ # Get latent representation
174
+ z = self.vae.enc(x_tensor)
175
+ except Exception as e:
176
+ print(f"Error in encoder: {e}")
177
+ # Try workaround for batch norm if needed
178
+ if x.shape[0] == 1 and (hasattr(self.vae, 'bn1') or hasattr(self.vae, 'bn2')):
179
+ print("Using batch normalization workaround for single sample")
180
+ # Repeat the input to create a batch of size 2
181
+ if len(x.shape) == 2:
182
+ x_batch = np.repeat(x, 2, axis=0)
183
+ else:
184
+ x_batch = np.array([x[0], x[0]])
185
+
186
+ # Process the batch
187
+ x_tensor = to_cuda(to_torch(x_batch), self.vae.use_cuda)
188
+ z_batch = self.vae.enc(x_tensor)
189
+
190
+ # Extract just the first sample's latent representation
191
+ z = z_batch[0:1]
192
+ else:
193
+ # Re-raise if we can't handle it
194
+ raise
195
+
196
  return to_numpy(z)
197
 
198
  def save(self, path):
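The eval() calls and single-sample workarounds above exist because nn.BatchNorm1d cannot compute batch statistics from one sample in training mode, while evaluation mode falls back to running statistics. A minimal demonstration:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(1, 4)

bn.train()
try:
    bn(x)                      # training mode needs per-batch statistics
except ValueError as err:
    print("train() failed:", err)

bn.eval()
print(bn(x))                   # eval mode uses running statistics, so batch size 1 works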