SreekarB committed on
Commit
37a1b01
·
verified ·
1 Parent(s): 90774b7

Upload 13 files

Browse files
Files changed (3) hide show
  1. app.py +50 -4
  2. main.py +8 -1
  3. rcf_prediction.py +143 -11
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  from sklearn.metrics import mean_squared_error, r2_score
6
  import json
7
  import pickle
 
8
 
9
  def calculate_fc_accuracy(original_fc, reconstructed_fc):
10
  """
@@ -89,6 +90,7 @@ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
89
  demographics = results.get('demographics')
90
  reconstructed_fc = results.get('reconstructed_fc')
91
  generated_fc = results.get('generated_fc')
 
92
 
93
  # Calculate accuracy metrics
94
  accuracy_metrics = {}
@@ -108,6 +110,37 @@ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
108
  if latents is not None and demographics is not None:
109
  latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
110
  print(f"Saved latents to {latents_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # Prepare status message with accuracy metrics
113
  if accuracy_metrics:
@@ -118,8 +151,20 @@ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
118
  f"• RMSE: {avg['RMSE']:.6f}\n"
119
  f"• R²: {avg['R²']:.6f}\n"
120
  f"• Correlation: {avg['Correlation']:.6f}\n"
121
- f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n"
122
- f"Latent representations saved to results/latents_dim{latent_dim}.pkl")
 
 
 
 
 
 
 
 
 
 
 
 
123
  else:
124
  status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
125
  else:
@@ -216,8 +261,9 @@ def create_interface():
216
  1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
217
  2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
218
  3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
219
- 4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
220
- 5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
 
221
 
222
  Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
223
  """)
 
5
  from sklearn.metrics import mean_squared_error, r2_score
6
  import json
7
  import pickle
8
+ from rcf_prediction import train_predictor_from_latents
9
 
10
  def calculate_fc_accuracy(original_fc, reconstructed_fc):
11
  """
 
90
  demographics = results.get('demographics')
91
  reconstructed_fc = results.get('reconstructed_fc')
92
  generated_fc = results.get('generated_fc')
93
+ outcome_measures = results.get('outcome_measures', None)
94
 
95
  # Calculate accuracy metrics
96
  accuracy_metrics = {}
 
110
  if latents is not None and demographics is not None:
111
  latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
112
  print(f"Saved latents to {latents_path}")
113
+
114
+ # Train a predictor model if we have outcome measures
115
+ predictor_results = None
116
+ if outcome_measures is not None and 'wab_aq' in outcome_measures:
117
+ try:
118
+ print("Training WAB-AQ prediction model from latent representations...")
119
+ wab_scores = np.array(outcome_measures['wab_aq'])
120
+ # Filter out any NaN values
121
+ valid_indices = ~np.isnan(wab_scores)
122
+ if np.sum(valid_indices) > 5: # Only train with sufficient data
123
+ filtered_latents = latents[valid_indices]
124
+ filtered_wab = wab_scores[valid_indices]
125
+
126
+ # Extract demographic features for the model
127
+ filtered_demographics = {}
128
+ for key, values in demographics.items():
129
+ if isinstance(values, (list, np.ndarray)) and len(values) >= len(valid_indices):
130
+ filtered_demographics[key] = np.array(values)[valid_indices]
131
+
132
+ # Train the prediction model with cross-validation
133
+ predictor_results = train_predictor_from_latents(
134
+ filtered_latents,
135
+ filtered_wab,
136
+ filtered_demographics,
137
+ cv=5, # 5-fold cross-validation
138
+ n_estimators=100, # Number of trees in Random Forest
139
+ prediction_type="regression"
140
+ )
141
+ print("WAB-AQ prediction model training complete!")
142
+ except Exception as e:
143
+ print(f"Error training prediction model: {str(e)}")
144
 
145
  # Prepare status message with accuracy metrics
146
  if accuracy_metrics:
 
151
  f"• RMSE: {avg['RMSE']:.6f}\n"
152
  f"• R²: {avg['R²']:.6f}\n"
153
  f"• Correlation: {avg['Correlation']:.6f}\n"
154
+ f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n")
155
+
156
+ # Add prediction model results if available
157
+ if predictor_results is not None:
158
+ cv_results = predictor_results.get('cv_results', {})
159
+ mean_metrics = cv_results.get('mean_metrics', {})
160
+ if mean_metrics and 'r2' in mean_metrics:
161
+ prediction_r2 = mean_metrics.get('r2', 0)
162
+ prediction_rmse = mean_metrics.get('rmse', 0)
163
+ status += (f"WAB-AQ Prediction Model Performance:\n"
164
+ f"• R²: {prediction_r2:.4f}\n"
165
+ f"• RMSE: {prediction_rmse:.4f}\n\n")
166
+
167
+ status += f"Latent representations saved to results/latents_dim{latent_dim}.pkl"
168
  else:
169
  status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
170
  else:
 
261
  1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
262
  2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
263
  3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
264
+ 4. **Predictive Modeling**: The system trains a Random Forest regressor on latent features to predict WAB-AQ scores (aphasia severity)
265
+ 5. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
266
+ 6. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
267
 
268
  Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
269
  """)
main.py CHANGED
@@ -246,6 +246,12 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
246
 
247
  # If requested, return additional data for accuracy calculations
248
  if return_data:
 
 
 
 
 
 
249
  results = {
250
  'vae': vae,
251
  'X': X,
@@ -253,7 +259,8 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
253
  'demographics': demographics,
254
  'reconstructed_fc': reconstructed_fc,
255
  'generated_fc': generated_fc,
256
- 'analysis_results': analysis_results
 
257
  }
258
  return fig, results
259
 
 
246
 
247
  # If requested, return additional data for accuracy calculations
248
  if return_data:
249
+ # Create a structured outcome measures dictionary
250
+ outcome_measures = {
251
+ 'wab_aq': demo_data[3], # WAB-AQ scores
252
+ # Could add other outcome measures here
253
+ }
254
+
255
  results = {
256
  'vae': vae,
257
  'X': X,
 
259
  'demographics': demographics,
260
  'reconstructed_fc': reconstructed_fc,
261
  'generated_fc': generated_fc,
262
+ 'analysis_results': analysis_results,
263
+ 'outcome_measures': outcome_measures
264
  }
265
  return fig, results
266
 
rcf_prediction.py CHANGED
@@ -54,10 +54,54 @@ class AphasiaTreatmentPredictor:
54
  tuple: Combined features array and feature names
55
  """
56
  if isinstance(demographics, dict):
57
- demo_df = pd.DataFrame(demographics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
  demo_df = demographics.copy()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Get categorical columns
62
  cat_columns = demo_df.select_dtypes(include=['object']).columns.tolist()
63
 
@@ -71,7 +115,16 @@ class AphasiaTreatmentPredictor:
71
  feature_names = latent_names + demo_names
72
 
73
  # Combine latents with demographics
74
- features = np.hstack([latents, demo_df.values])
 
 
 
 
 
 
 
 
 
75
  return features, feature_names
76
 
77
  def fit(self, latents, demographics, treatment_outcomes):
@@ -90,6 +143,11 @@ class AphasiaTreatmentPredictor:
90
  self.feature_names = feature_names
91
 
92
  logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
 
 
 
 
 
93
  self.model.fit(X, treatment_outcomes)
94
 
95
  # Calculate feature importance
@@ -98,6 +156,7 @@ class AphasiaTreatmentPredictor:
98
  'importance': self.model.feature_importances_
99
  }).sort_values('importance', ascending=False)
100
 
 
101
  return self
102
 
103
  def predict(self, latents, demographics):
@@ -160,31 +219,54 @@ class AphasiaTreatmentPredictor:
160
  X, feature_names = self.prepare_features(latents, demographics)
161
  self.feature_names = feature_names
162
 
163
- logger.info(f"Running {n_splits}-fold cross-validation")
164
-
165
- kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  cv_scores = []
168
  predictions = np.zeros_like(treatment_outcomes)
169
  prediction_stds = np.zeros_like(treatment_outcomes)
170
  fold_metrics = []
171
 
172
- for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
173
  X_train, X_test = X[train_idx], X[test_idx]
174
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
175
 
 
 
176
  # Clone the model for this fold
177
  if self.prediction_type == "classification":
178
  fold_model = RandomForestClassifier(
179
  n_estimators=self.n_estimators,
180
  max_depth=self.max_depth,
181
- random_state=self.random_state
 
182
  )
183
  else:
184
  fold_model = RandomForestRegressor(
185
  n_estimators=self.n_estimators,
186
  max_depth=self.max_depth,
187
- random_state=self.random_state
 
188
  )
189
 
190
  # Train the model
@@ -199,13 +281,34 @@ class AphasiaTreatmentPredictor:
199
  # Calculate metrics
200
  if self.prediction_type == "regression":
201
  rmse = np.sqrt(mean_squared_error(y_test, pred))
202
- r2 = r2_score(y_test, pred)
 
 
 
 
 
 
 
 
 
 
 
203
  metrics = {
204
  "r2": r2,
205
  "rmse": rmse,
206
- "mse": rmse**2
207
  }
208
 
 
 
 
 
 
 
 
 
 
 
209
  # Get prediction intervals using tree variance
210
  tree_predictions = np.array([tree.predict(X_test)
211
  for tree in fold_model.estimators_])
@@ -233,14 +336,43 @@ class AphasiaTreatmentPredictor:
233
  fold_metrics.append(metrics)
234
  logger.info(f"Fold {fold+1} metrics: {metrics}")
235
 
 
 
 
 
 
 
 
 
 
 
 
236
  # Calculate average metrics
237
  avg_metrics = {}
238
  for key in fold_metrics[0].keys():
239
- avg_metrics[key] = np.mean([fold[key] for fold in fold_metrics])
 
 
 
 
 
240
 
241
  logger.info(f"Average CV metrics: {avg_metrics}")
242
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Train final model on all data
 
 
244
  self.model.fit(X, treatment_outcomes)
245
 
246
  # Calculate feature importance
 
54
  tuple: Combined features array and feature names
55
  """
56
  if isinstance(demographics, dict):
57
+ # For dictionary input, ensure all arrays are same length as latents
58
+ n_samples = latents.shape[0]
59
+ aligned_demos = {}
60
+
61
+ for key, values in demographics.items():
62
+ if len(values) != n_samples:
63
+ print(f"WARNING: Demographics '{key}' length ({len(values)}) doesn't match latents ({n_samples})")
64
+ # Truncate or pad to match latent samples
65
+ if len(values) > n_samples:
66
+ aligned_demos[key] = values[:n_samples] # Truncate
67
+ print(f" Truncated '{key}' to {n_samples} samples")
68
+ else:
69
+ # Pad with repeated values or zeros depending on type
70
+ if len(values) > 0:
71
+ # Use mean for numerical, mode for categorical
72
+ if isinstance(values[0], (int, float, np.number)):
73
+ filler = np.mean(values)
74
+ else:
75
+ # Use most common value
76
+ from collections import Counter
77
+ filler = Counter(values).most_common(1)[0][0]
78
+
79
+ padding = [filler] * (n_samples - len(values))
80
+ aligned_demos[key] = list(values) + padding
81
+ print(f" Padded '{key}' with {filler} to {n_samples} samples")
82
+ else:
83
+ # Empty array, fill with zeros
84
+ aligned_demos[key] = [0] * n_samples
85
+ print(f" Filled empty '{key}' with zeros to {n_samples} samples")
86
+ else:
87
+ aligned_demos[key] = values
88
+
89
+ demo_df = pd.DataFrame(aligned_demos)
90
  else:
91
  demo_df = demographics.copy()
92
 
93
+ # Ensure DataFrame has same number of rows as latents
94
+ if len(demo_df) != latents.shape[0]:
95
+ print(f"WARNING: Demographics DataFrame size ({len(demo_df)}) doesn't match latents ({latents.shape[0]})")
96
+ if len(demo_df) > latents.shape[0]:
97
+ demo_df = demo_df.iloc[:latents.shape[0]] # Truncate
98
+ print(f" Truncated demographics to {latents.shape[0]} samples")
99
+ else:
100
+ # Cannot easily pad a DataFrame; substitute a zero-filled placeholder with the same columns
101
+ print(f" ERROR: Cannot pad demographics DataFrame - using latents only")
102
+ # Create a DataFrame with the same columns but zeros
103
+ demo_df = pd.DataFrame(0, index=range(latents.shape[0]), columns=demo_df.columns)
104
+
105
  # Get categorical columns
106
  cat_columns = demo_df.select_dtypes(include=['object']).columns.tolist()
107
 
 
115
  feature_names = latent_names + demo_names
116
 
117
  # Combine latents with demographics
118
+ try:
119
+ features = np.hstack([latents, demo_df.values])
120
+ except ValueError as e:
121
+ print(f"ERROR combining features: {e}")
122
+ print(f"Latents shape: {latents.shape}, Demographics shape: {demo_df.values.shape}")
123
+ # Fall back to using just latents
124
+ print("Falling back to using only latent features")
125
+ features = latents
126
+ feature_names = latent_names
127
+
128
  return features, feature_names
129
 
130
  def fit(self, latents, demographics, treatment_outcomes):
 
143
  self.feature_names = feature_names
144
 
145
  logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
146
+ print(f"Random Forest: Building {self.n_estimators} trees...")
147
+
148
+ # Track progress during fit with verbose
149
+ # Set verbose to 2 for detailed per-tree progress
150
+ self.model.verbose = 1
151
  self.model.fit(X, treatment_outcomes)
152
 
153
  # Calculate feature importance
 
156
  'importance': self.model.feature_importances_
157
  }).sort_values('importance', ascending=False)
158
 
159
+ print(f"Random Forest: Training complete. Top features: {', '.join(self.feature_importance['feature'].head(3).tolist())}")
160
  return self
161
 
162
  def predict(self, latents, demographics):
 
219
  X, feature_names = self.prepare_features(latents, demographics)
220
  self.feature_names = feature_names
221
 
222
+ # Adjust n_splits if we have too few samples
223
+ sample_count = len(treatment_outcomes)
224
+ if sample_count < n_splits * 2: # Need at least 2 samples per fold
225
+ adjusted_n_splits = max(2, sample_count // 2) # At least 2 folds, each with multiple samples
226
+ logger.warning(f"Too few samples ({sample_count}) for {n_splits} folds. Adjusting to {adjusted_n_splits} folds.")
227
+ print(f"Random Forest: Starting {adjusted_n_splits}-fold cross-validation with {sample_count} samples")
228
+ n_splits = adjusted_n_splits
229
+ else:
230
+ logger.info(f"Running {n_splits}-fold cross-validation on {sample_count} samples")
231
+ print(f"Random Forest: Starting {n_splits}-fold cross-validation with {sample_count} samples")
232
+
233
+ # Use shuffled (non-stratified) KFold for cross-validation,
234
+ # or LeaveOneOut for very small datasets
235
+ if sample_count <= 5:
236
+ from sklearn.model_selection import LeaveOneOut
237
+ logger.warning(f"Using Leave-One-Out CV for small dataset with {sample_count} samples")
238
+ print(f"Random Forest: Using Leave-One-Out cross-validation due to small sample size ({sample_count})")
239
+ kf = LeaveOneOut()
240
+ cv_iterator = kf.split(X)
241
+ else:
242
+ kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
243
+ cv_iterator = kf.split(X)
244
 
245
  cv_scores = []
246
  predictions = np.zeros_like(treatment_outcomes)
247
  prediction_stds = np.zeros_like(treatment_outcomes)
248
  fold_metrics = []
249
 
250
+ for fold, (train_idx, test_idx) in enumerate(cv_iterator):
251
  X_train, X_test = X[train_idx], X[test_idx]
252
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
253
 
254
+ print(f"Random Forest: Training fold {fold+1}/{n_splits} - {len(X_train)} training samples, {len(X_test)} test samples")
255
+
256
  # Clone the model for this fold
257
  if self.prediction_type == "classification":
258
  fold_model = RandomForestClassifier(
259
  n_estimators=self.n_estimators,
260
  max_depth=self.max_depth,
261
+ random_state=self.random_state,
262
+ verbose=1 # Add verbosity
263
  )
264
  else:
265
  fold_model = RandomForestRegressor(
266
  n_estimators=self.n_estimators,
267
  max_depth=self.max_depth,
268
+ random_state=self.random_state,
269
+ verbose=1 # Add verbosity
270
  )
271
 
272
  # Train the model
 
281
  # Calculate metrics
282
  if self.prediction_type == "regression":
283
  rmse = np.sqrt(mean_squared_error(y_test, pred))
284
+
285
+ # R-squared requires at least 2 samples and some variance in the target
286
+ if len(y_test) >= 2 and np.var(y_test) > 1e-10:
287
+ r2 = r2_score(y_test, pred)
288
+ else:
289
+ r2 = np.nan
290
+ logger.warning(f"Fold {fold+1}: R² not calculated (insufficient samples or variance)")
291
+ print(f"Random Forest: Fold {fold+1} - R² not calculated (insufficient samples or variance)")
292
+
293
+ # MSE can always be calculated
294
+ mse = rmse**2
295
+
296
  metrics = {
297
  "r2": r2,
298
  "rmse": rmse,
299
+ "mse": mse
300
  }
301
 
302
+ # Add other useful metrics if there are enough samples
303
+ if len(y_test) >= 2 and np.var(y_test) > 1e-10:
304
+ from sklearn.metrics import explained_variance_score
305
+ try:
306
+ ev = explained_variance_score(y_test, pred)
307
+ metrics["explained_variance"] = ev
308
+ except:
309
+ # Skip if it can't be calculated
310
+ pass
311
+
312
  # Get prediction intervals using tree variance
313
  tree_predictions = np.array([tree.predict(X_test)
314
  for tree in fold_model.estimators_])
 
336
  fold_metrics.append(metrics)
337
  logger.info(f"Fold {fold+1} metrics: {metrics}")
338
 
339
+ # Print a more user-friendly version of the fold results
340
+ if self.prediction_type == "regression":
341
+ r2_val = metrics.get('r2', np.nan)
342
+ rmse_val = metrics.get('rmse', np.nan)
343
+ r2_text = f"R² = {r2_val:.4f}" if not np.isnan(r2_val) else "R² = N/A"
344
+ print(f"Random Forest: Fold {fold+1} results - {r2_text}, RMSE = {rmse_val:.4f}")
345
+ else:
346
+ acc_val = metrics.get('accuracy', 0)
347
+ f1_val = metrics.get('f1', 0)
348
+ print(f"Random Forest: Fold {fold+1} results - Accuracy = {acc_val:.4f}, F1 = {f1_val:.4f}")
349
+
350
  # Calculate average metrics
351
  avg_metrics = {}
352
  for key in fold_metrics[0].keys():
353
+ # Filter out nan values when calculating means
354
+ values = [fold[key] for fold in fold_metrics if key in fold and not (isinstance(fold[key], float) and np.isnan(fold[key]))]
355
+ if values: # Only calculate mean if we have valid values
356
+ avg_metrics[key] = np.mean(values)
357
+ else:
358
+ avg_metrics[key] = np.nan
359
 
360
  logger.info(f"Average CV metrics: {avg_metrics}")
361
 
362
+ # Print a summary of cross-validation performance
363
+ if self.prediction_type == "regression":
364
+ r2_avg = avg_metrics.get('r2', np.nan)
365
+ rmse_avg = avg_metrics.get('rmse', np.nan)
366
+ r2_text = f"R² = {r2_avg:.4f}" if not np.isnan(r2_avg) else "R² = N/A"
367
+ print(f"Random Forest: Cross-validation complete - Average {r2_text}, RMSE = {rmse_avg:.4f}")
368
+ else:
369
+ acc_avg = avg_metrics.get('accuracy', 0)
370
+ f1_avg = avg_metrics.get('f1', 0)
371
+ print(f"Random Forest: Cross-validation complete - Average Accuracy = {acc_avg:.4f}, F1 = {f1_avg:.4f}")
372
+
373
  # Train final model on all data
374
+ print(f"Random Forest: Training final model on all {len(X)} samples...")
375
+ self.model.verbose = 1
376
  self.model.fit(X, treatment_outcomes)
377
 
378
  # Calculate feature importance