SreekarB committed on
Commit e88139d · verified · 1 Parent(s): 763369a

Upload 12 files

app.py CHANGED
@@ -2395,46 +2395,54 @@ def create_interface():
  vae.load(vae_path)
  app_state['vae'] = vae

- # We also need latent representations for RF training
- # Use synthetic data if no real data is available
- from data_preprocessing import generate_synthetic_fc_matrices
- synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
- logger.info("Generating latent representations from synthetic data...")
- latents = vae.encode(synthetic_fc, synthetic_demo)
- app_state['latents'] = latents
- app_state['demographics'] = synthetic_demo
+ # Only use real data for training and visualization
+ logger.info("Using loaded VAE model with real data only...")
+
+ # Set flag to indicate VAE model is loaded, but not using synthetic data
  app_state['vae_trained'] = True
- logger.info("Loaded VAE model and generated synthetic latents")
- else:
- # Train a simple VAE with synthetic data
- from vae_model import DemoVAE
- from data_preprocessing import generate_synthetic_fc_matrices
-
- logger.info("VAE model not found. Training a simple model with synthetic data...")
-
- # Generate synthetic data
- synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
-
- # Train a simple VAE
- vae = DemoVAE(latent_dim=10)
- vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
-
- # Get latent representations
- latents = vae.encode(synthetic_fc, synthetic_demo)
-
- # Save in app_state
- app_state['vae'] = vae
- app_state['latents'] = latents
- app_state['demographics'] = synthetic_demo
- app_state['vae_trained'] = True
-
- # Save the model for future use
- if not os.path.exists('models'):
- os.makedirs('models')
- vae.save('models/vae_model.pt')
-
- logger.info("Trained and saved a simple VAE model with synthetic data")
+
+ # Try to load previously saved latents if they exist
+ if os.path.exists('results/latents.npy'):
+ try:
+ logger.info("Loading saved latent representations...")
+ latents = np.load('results/latents.npy')
+ app_state['latents'] = latents
+ logger.info(f"Loaded {len(latents)} real latent vectors")
+
+ # Try to load real demographics if available
+ if os.path.exists('temp_demographics.csv'):
+ logger.info("Loading demographics from temp_demographics.csv")
+ demo_df = pd.read_csv('temp_demographics.csv')
+ app_state['demographics'] = {
+ 'age_at_stroke': demo_df['age'].values,
+ 'sex': demo_df['sex'].values,
+ 'months_post_stroke': demo_df['months_post_stroke'].values,
+ 'wab_score': demo_df['wab_score'].values
+ }
+ else:
+ logger.warning("No real demographic data found")
+ except Exception as e:
+ logger.error(f"Error loading real latents: {e}")
+ logger.warning("Will not use synthetic data")
+ else:
+ logger.warning("No real latent representations found")
+ logger.warning("Will not use synthetic data")
+ else:
+ # Don't train with synthetic data in strict real data mode
+ logger.info("VAE model not found and using strict real data mode.")
+ logger.warning("Cannot train VAE model without real data")
+
+ # Set state to indicate VAE is not trained
+ app_state['vae_trained'] = False
+
+ # Show message about requiring real data
+ status_msg = "No VAE model available. Please train with real data first."
+ return {
+ tab_rf: gr.update(visible=False),
+ tab_vae: gr.update(visible=True),
+ status: status_msg,
+ vae_status: "Model not trained. Upload real data and train with it."
+ }
  except Exception as e:
  error_fig = plt.figure(figsize=(10, 6))
  message = f"Error: Unable to load or train VAE model: {str(e)}"
@@ -2635,55 +2643,15 @@ def create_interface():
  app_state['rf_trained'] = True
  rf_loaded = True

- # If we couldn't load both models, train quick synthetic models
+ # If we couldn't load both models in strict real data mode
  if not (vae_loaded and rf_loaded):
- logger.info("Training synthetic models for demo purposes...")
-
- # Generate synthetic data
- from data_preprocessing import generate_synthetic_fc_matrices
- synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
-
- # Train VAE if needed
- if not vae_loaded:
- vae = DemoVAE(latent_dim=10)
- vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
- app_state['vae'] = vae
- app_state['vae_trained'] = True
-
- # Save for future use
- if not os.path.exists('models'):
- os.makedirs('models')
- vae.save('models/vae_model.pt')
- else:
- vae = app_state['vae']
-
- # Get latent representations for RF training
- latents = vae.encode(synthetic_fc, synthetic_demo)
-
- # Train RF if needed
- if not rf_loaded:
- from main import RandomForestPredictor
-
- # Create synthetic outcome data
- import numpy as np
- outcomes = np.random.normal(50, 10, size=len(synthetic_demo))
-
- # Train the RF model
- predictor = RandomForestPredictor()
- predictor.train(latents, outcomes)
-
- app_state['predictor'] = predictor
- app_state['rf_trained'] = True
-
- # Save for future use
- if not os.path.exists('models'):
- os.makedirs('models')
- torch.save({
- 'predictor_state': predictor.model,
- 'feature_importance': predictor.feature_importance
- }, 'models/predictor_model.pt')
-
- logger.info("Successfully trained synthetic models for demo")
+ logger.info("Strict real data mode: Not using synthetic data")
+
+ # Show a message to the user
+ return {
+ status: "Cannot use synthetic data in strict real data mode. Please train with real data first.",
+ rf_status: "Not trained. Upload real data and train the VAE model first."
+ }
  except Exception as e:
  error_message = f"Error: Unable to load or train required models: {str(e)}"
  error_fig = plt.figure(figsize=(10, 6))
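
Note: the new loading path above only works if a real-data run has already written latents to results/latents.npy and demographics to temp_demographics.csv. The commit does not show that save step; the helper below is a hypothetical sketch of what it could look like, assuming a fitted DemoVAE and a demographics DataFrame with the age, sex, months_post_stroke and wab_score columns that app.py reads back.

import os
import numpy as np

def save_real_latents(vae, fc_features, demo_df,
                      latents_path='results/latents.npy',
                      demo_path='temp_demographics.csv'):
    """Persist real-data latents and demographics to the paths app.py loads."""
    # Hypothetical helper, not part of this commit.
    os.makedirs(os.path.dirname(latents_path), exist_ok=True)
    latents = vae.get_latents(fc_features)  # encode() is the new alias for the same call
    np.save(latents_path, latents)
    demo_df.to_csv(demo_path, index=False)
    return latents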
cache/.DS_Store ADDED
Binary file (6.15 kB).
 
cache/atlas/power_2011_coords.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f0e80988258ff3da5522d409679d81d633a3bd1c39c4a23494926101eb852eb
+ size 6464
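
Note: the new cache/atlas/power_2011_coords.npy file is an LFS-tracked NumPy array of coordinates for the Power 2011 atlas. A minimal reading sketch; the (264, 3) shape is an assumption based on the 264-ROI Power atlas and the 6464-byte file size, not something stated in the commit:

import numpy as np

coords = np.load('cache/atlas/power_2011_coords.npy')
print(coords.shape)  # assumed to be (264, 3): one (x, y, z) coordinate per ROI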
config.py CHANGED
@@ -30,7 +30,8 @@ PREDICTION_CONFIG = {
  'default_outcome': 'wab_aq',
  'save_path': 'results/treatment_predictor.joblib',
  'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
- 'use_synthetic_nifti': False, # Set to False to use only real NIfTI data
- 'use_synthetic_fc': False, # Set to False to use only real FC matrices
- 'strict_real_data': True # Set to True to strictly use real data only
+ 'use_synthetic_nifti': False, # Set to False to use only real NIfTI data
+ 'use_synthetic_fc': False, # Set to False to use only real FC matrices
+ 'strict_real_data': True, # Set to True to strictly use real data only
+ 'no_mock_data': True # Set to True to prevent using any mock or synthetic data
  }
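
Note: downstream code can consult the new flags with PREDICTION_CONFIG.get, which is what the main.py change below does; a minimal sketch, assuming PREDICTION_CONFIG is imported from config:

from config import PREDICTION_CONFIG

strict_real = PREDICTION_CONFIG.get('strict_real_data', False)
no_mock = PREDICTION_CONFIG.get('no_mock_data', False)

if strict_real or no_mock:
    # Any synthetic/mock fallback should be skipped in this mode.
    print("Strict real data mode: synthetic fallbacks disabled")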
main.py CHANGED
@@ -204,38 +204,46 @@ def run_analysis(data_dir="data",
  print("Creating learning curve visualization...")

  # Check if losses are stored in the VAE object first (most reliable source)
- if hasattr(vae, 'train_losses') and hasattr(vae, 'val_losses'):
- if len(vae.train_losses) > 0 and len(vae.val_losses) > 0:
- print(f"Using learning curves from VAE object: {len(vae.train_losses)} train, {len(vae.val_losses)} validation points")
- learning_fig = plot_learning_curves(vae.train_losses, vae.val_losses)
- else:
- # Fall back to the losses passed directly
- if train_losses and val_losses:
- print(f"Using passed learning curves: {len(train_losses)} train, {len(val_losses)} validation points")
- learning_fig = plot_learning_curves(train_losses, val_losses)
- else:
- # Create a placeholder
- print("No training history available for learning curves")
- learning_fig = plt.figure(figsize=(10, 6))
- plt.text(0.5, 0.5, "Learning curve data unavailable",
- ha='center', va='center', transform=plt.gca().transAxes,
- fontsize=14, color='darkred')
- plt.axis('off')
- plt.tight_layout()
+ train_data = []
+ val_data = []
+
+ # Only use real data from VAE object or training results
+ if hasattr(vae, 'train_losses') and len(getattr(vae, 'train_losses', [])) > 0:
+ train_data = vae.train_losses
+ print(f"Found {len(train_data)} real training loss points in VAE object")
+ elif train_losses and len(train_losses) > 0:
+ train_data = train_losses
+ print(f"Using {len(train_data)} real training loss points from fit return value")
  else:
- # Fall back to the losses passed directly
- if train_losses and val_losses:
- print(f"Using passed learning curves: {len(train_losses)} train, {len(val_losses)} validation points")
- learning_fig = plot_learning_curves(train_losses, val_losses)
- else:
- # Create a placeholder
- print("No training history available for learning curves")
- learning_fig = plt.figure(figsize=(10, 6))
- plt.text(0.5, 0.5, "Learning curve data unavailable",
- ha='center', va='center', transform=plt.gca().transAxes,
- fontsize=14, color='darkred')
- plt.axis('off')
- plt.tight_layout()
+ # Instead of synthetic data, provide empty list and warning
+ print("WARNING: No real training loss data found")
+ train_data = []
+
+ # Do the same for validation data
+ if hasattr(vae, 'val_losses') and len(getattr(vae, 'val_losses', [])) > 0:
+ val_data = vae.val_losses
+ print(f"Found {len(val_data)} real validation loss points in VAE object")
+ elif val_losses and len(val_losses) > 0:
+ val_data = val_losses
+ print(f"Using {len(val_data)} real validation loss points from fit return value")
+ else:
+ # Instead of synthetic data, provide empty list and warning
+ print("WARNING: No real validation loss data found")
+ val_data = []
+
+ # If we get here, we have some training data (real or synthetic)
+ # Store the data in the VAE object for future use
+ if not hasattr(vae, 'train_losses') or len(getattr(vae, 'train_losses', [])) == 0:
+ print("Storing training loss data in VAE object")
+ vae.train_losses = train_data
+
+ if not hasattr(vae, 'val_losses') or len(getattr(vae, 'val_losses', [])) == 0:
+ print("Storing validation loss data in VAE object")
+ vae.val_losses = val_data
+
+ # Now create the visualization using the data we collected
+ print(f"Creating learning curve with {len(train_data)} training and {len(val_data)} validation points")
+ learning_fig = plot_learning_curves(train_data, val_data)
  except Exception as e:
  import traceback
  print(f"Error creating learning curve plot: {e}")
@@ -249,16 +257,41 @@ def run_analysis(data_dir="data",
  plt.axis('off')
  plt.tight_layout()

+ # Check if we should use strict real data mode
+ use_strict_real_data = PREDICTION_CONFIG.get('strict_real_data', False)
+ no_mock_data = PREDICTION_CONFIG.get('no_mock_data', False)
+
+ if use_strict_real_data or no_mock_data:
+ print("Using strict real data mode - only including real data in results")
+ # Only include figures if they contain real data
+ figures = {}
+ if hasattr(vae, 'train_losses') and len(vae.train_losses) > 0:
+ figures['learning_curves'] = learning_fig
+ print("Including real learning curves")
+ else:
+ print("WARNING: No real learning curve data available")
+
+ # Only include FC analysis if it's based on real data
+ if len(np.array(X).shape) > 0 and len(X) > 0:
+ figures['vae'] = fc_fig
+ figures['fc_analysis'] = fc_fig
+ print("Including real FC analysis")
+ else:
+ print("WARNING: No real FC data available")
+ else:
+ # Include all figures, even if based on synthetic data
+ figures = {
+ 'vae': fc_fig,
+ 'fc_analysis': fc_fig,
+ 'learning_curves': learning_fig
+ }
+
  # Initialize results dictionary
  results = {
  'vae': vae,
  'latents': latents,
  'demographics': demographics,
- 'figures': {
- 'vae': fc_fig,
- 'fc_analysis': fc_fig,
- 'learning_curves': learning_fig
- }
+ 'figures': figures
  }

  # Add reconstructed and generated FC if available
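
Note: the new loss-selection logic prefers histories recorded on the VAE object, then the values returned by fit(), and otherwise falls back to an empty list plus a warning. The helper below is an illustrative condensation of that chain, not code from the commit:

def collect_real_losses(vae, train_losses=None, val_losses=None):
    """Return (train, val) loss histories, using only real recorded data."""
    def pick(stored, passed, label):
        if stored:            # real history stored on the VAE object
            return list(stored)
        if passed:            # real history returned by fit()
            return list(passed)
        print(f"WARNING: No real {label} loss data found")
        return []

    train = pick(getattr(vae, 'train_losses', None), train_losses, 'training')
    val = pick(getattr(vae, 'val_losses', None), val_losses, 'validation')
    return train, val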
test_learning.png ADDED
test_loading.py ADDED
@@ -0,0 +1,62 @@
+ import os
+ # Set Huggingface cache directory to avoid permission issues
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
+ os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
+ os.makedirs('models', exist_ok=True)
+
+ import numpy as np
+ import torch
+ from vae_model import DemoVAE
+ import matplotlib.pyplot as plt
+ from visualization import plot_learning_curves
+
+ print("Creating synthetic test data...")
+ # Create small synthetic dataset with only 5 samples
+ input_dim = 100
+ n_samples = 5
+ X = np.random.randn(n_samples, input_dim)
+ demo_data = [
+ np.random.normal(60, 10, n_samples), # age
+ np.random.choice(['M', 'F'], n_samples), # sex
+ np.random.normal(24, 12, n_samples), # months post stroke
+ np.random.normal(50, 15, n_samples) # WAB score
+ ]
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
+
+ print("Testing DemoVAE initialization...")
+ # Initialize with nepochs=3 for fast testing
+ vae = DemoVAE(latent_dim=16, nepochs=3, bsize=5)
+
+ print("Testing DemoVAE fit method...")
+ # Fit model
+ train_losses, val_losses = vae.fit(X, demo_data, demo_types)
+
+ print(f"Train losses shape: {len(train_losses)}")
+ print(f"Val losses shape: {len(val_losses)}")
+
+ print("Testing get_latents method...")
+ # Test get_latents
+ latents = vae.get_latents(X)
+ print(f"Latents shape: {latents.shape}")
+
+ print("Testing encode method...")
+ # Test encode
+ latents2 = vae.encode(X)
+ print(f"Latents from encode shape: {latents2.shape}")
+
+ print("Testing model save...")
+ # Save model
+ vae.save('models/test_vae.pt')
+
+ print("Testing model load...")
+ # Load model
+ vae2 = DemoVAE()
+ vae2.load('models/test_vae.pt')
+
+ print("Testing learning curve plotting...")
+ # Test learning curve plotting
+ fig = plot_learning_curves(vae2.train_losses, vae2.val_losses)
+ plt.savefig('test_learning.png')
+ print("Learning curve saved to test_learning.png")
+
+ print("All tests passed!")
vae_model.py CHANGED
@@ -234,6 +234,10 @@ class DemoVAE(BaseEstimator):
  print(f"Returning fallback output with shape: {fallback.shape}")
  return fallback

+ def encode(self, x):
+ """Alias for get_latents method - to provide compatibility with some interfaces"""
+ return self.get_latents(x)
+
  def get_latents(self, x):
  # Set model to evaluation mode
  self.vae.eval()
visualization.py CHANGED
@@ -397,14 +397,22 @@ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke
  def plot_learning_curves(train_losses, val_losses):
  """Plot VAE learning curves with enhanced visualization"""
  try:
- # Handle empty or None inputs
- if not train_losses or train_losses is None:
- print("WARNING: No training loss data provided")
- train_losses = [0.0]
+ # Handle empty or None inputs - only use real data
+ if not train_losses or train_losses is None or len(train_losses) == 0:
+ print("WARNING: No real training loss data provided")
+ # Create placeholder figure with warning message
+ fig = plt.figure(figsize=(10, 6))
+ plt.text(0.5, 0.5, "No real training data available",
+ ha='center', va='center', transform=plt.gca().transAxes,
+ fontsize=14, color='darkred')
+ plt.axis('off')
+ plt.tight_layout()
+ return fig

- if not val_losses or val_losses is None:
- print("WARNING: No validation loss data provided")
- val_losses = [0.0]
+ if not val_losses or val_losses is None or len(val_losses) == 0:
+ print("WARNING: No real validation loss data provided. Using training data only.")
+ # Use training data for both
+ val_losses = train_losses

  # Convert to numpy arrays for safe handling
  train_np = np.array(train_losses)
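
Note: with the revised guard, an empty history yields a placeholder figure instead of a flat zero curve, and a missing validation series reuses the training series. A short usage sketch:

from visualization import plot_learning_curves

# No history at all: returns the "No real training data available" placeholder.
fig_empty = plot_learning_curves([], [])
fig_empty.savefig('learning_curves_empty.png')

# Training losses only: the validation curve falls back to the training values.
fig_train_only = plot_learning_curves([3.2, 2.1, 1.7, 1.5], [])
fig_train_only.savefig('learning_curves_train_only.png')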