Spaces:
Sleeping
Sleeping
Upload 12 files
Browse files- app.py +50 -82
- cache/.DS_Store +0 -0
- cache/atlas/power_2011_coords.npy +3 -0
- config.py +4 -3
- main.py +69 -36
- test_learning.png +0 -0
- test_loading.py +62 -0
- vae_model.py +4 -0
- visualization.py +15 -7
app.py
CHANGED
|
@@ -2395,46 +2395,54 @@ def create_interface():
|
|
| 2395 |
vae.load(vae_path)
|
| 2396 |
app_state['vae'] = vae
|
| 2397 |
|
| 2398 |
-
#
|
| 2399 |
-
|
| 2400 |
-
from data_preprocessing import generate_synthetic_fc_matrices
|
| 2401 |
-
synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
|
| 2402 |
-
logger.info("Generating latent representations from synthetic data...")
|
| 2403 |
|
| 2404 |
-
|
| 2405 |
-
app_state['latents'] = latents
|
| 2406 |
-
app_state['demographics'] = synthetic_demo
|
| 2407 |
app_state['vae_trained'] = True
|
| 2408 |
-
logger.info("Loaded VAE model and generated synthetic latents")
|
| 2409 |
-
else:
|
| 2410 |
-
# Train a simple VAE with synthetic data
|
| 2411 |
-
from vae_model import DemoVAE
|
| 2412 |
-
from data_preprocessing import generate_synthetic_fc_matrices
|
| 2413 |
-
|
| 2414 |
-
logger.info("VAE model not found. Training a simple model with synthetic data...")
|
| 2415 |
-
|
| 2416 |
-
# Generate synthetic data
|
| 2417 |
-
synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
|
| 2418 |
-
|
| 2419 |
-
# Train a simple VAE
|
| 2420 |
-
vae = DemoVAE(latent_dim=10)
|
| 2421 |
-
vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
|
| 2422 |
|
| 2423 |
-
#
|
| 2424 |
-
|
| 2425 |
-
|
| 2426 |
-
|
| 2427 |
-
|
| 2428 |
-
|
| 2429 |
-
|
| 2430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2431 |
|
| 2432 |
-
#
|
| 2433 |
-
|
| 2434 |
-
os.makedirs('models')
|
| 2435 |
-
vae.save('models/vae_model.pt')
|
| 2436 |
|
| 2437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2438 |
except Exception as e:
|
| 2439 |
error_fig = plt.figure(figsize=(10, 6))
|
| 2440 |
message = f"Error: Unable to load or train VAE model: {str(e)}"
|
|
@@ -2635,55 +2643,15 @@ def create_interface():
|
|
| 2635 |
app_state['rf_trained'] = True
|
| 2636 |
rf_loaded = True
|
| 2637 |
|
| 2638 |
-
# If we couldn't load both models
|
| 2639 |
if not (vae_loaded and rf_loaded):
|
| 2640 |
-
logger.info("
|
| 2641 |
-
|
| 2642 |
-
# Generate synthetic data
|
| 2643 |
-
from data_preprocessing import generate_synthetic_fc_matrices
|
| 2644 |
-
synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
|
| 2645 |
-
|
| 2646 |
-
# Train VAE if needed
|
| 2647 |
-
if not vae_loaded:
|
| 2648 |
-
vae = DemoVAE(latent_dim=10)
|
| 2649 |
-
vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
|
| 2650 |
-
app_state['vae'] = vae
|
| 2651 |
-
app_state['vae_trained'] = True
|
| 2652 |
-
|
| 2653 |
-
# Save for future use
|
| 2654 |
-
if not os.path.exists('models'):
|
| 2655 |
-
os.makedirs('models')
|
| 2656 |
-
vae.save('models/vae_model.pt')
|
| 2657 |
-
else:
|
| 2658 |
-
vae = app_state['vae']
|
| 2659 |
-
|
| 2660 |
-
# Get latent representations for RF training
|
| 2661 |
-
latents = vae.encode(synthetic_fc, synthetic_demo)
|
| 2662 |
|
| 2663 |
-
#
|
| 2664 |
-
|
| 2665 |
-
|
| 2666 |
-
|
| 2667 |
-
|
| 2668 |
-
import numpy as np
|
| 2669 |
-
outcomes = np.random.normal(50, 10, size=len(synthetic_demo))
|
| 2670 |
-
|
| 2671 |
-
# Train the RF model
|
| 2672 |
-
predictor = RandomForestPredictor()
|
| 2673 |
-
predictor.train(latents, outcomes)
|
| 2674 |
-
|
| 2675 |
-
app_state['predictor'] = predictor
|
| 2676 |
-
app_state['rf_trained'] = True
|
| 2677 |
-
|
| 2678 |
-
# Save for future use
|
| 2679 |
-
if not os.path.exists('models'):
|
| 2680 |
-
os.makedirs('models')
|
| 2681 |
-
torch.save({
|
| 2682 |
-
'predictor_state': predictor.model,
|
| 2683 |
-
'feature_importance': predictor.feature_importance
|
| 2684 |
-
}, 'models/predictor_model.pt')
|
| 2685 |
-
|
| 2686 |
-
logger.info("Successfully trained synthetic models for demo")
|
| 2687 |
except Exception as e:
|
| 2688 |
error_message = f"Error: Unable to load or train required models: {str(e)}"
|
| 2689 |
error_fig = plt.figure(figsize=(10, 6))
|
|
|
|
| 2395 |
vae.load(vae_path)
|
| 2396 |
app_state['vae'] = vae
|
| 2397 |
|
| 2398 |
+
# Only use real data for training and visualization
|
| 2399 |
+
logger.info("Using loaded VAE model with real data only...")
|
|
|
|
|
|
|
|
|
|
| 2400 |
|
| 2401 |
+
# Set flag to indicate VAE model is loaded, but not using synthetic data
|
|
|
|
|
|
|
| 2402 |
app_state['vae_trained'] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2403 |
|
| 2404 |
+
# Try to load previously saved latents if they exist
|
| 2405 |
+
if os.path.exists('results/latents.npy'):
|
| 2406 |
+
try:
|
| 2407 |
+
logger.info("Loading saved latent representations...")
|
| 2408 |
+
latents = np.load('results/latents.npy')
|
| 2409 |
+
app_state['latents'] = latents
|
| 2410 |
+
logger.info(f"Loaded {len(latents)} real latent vectors")
|
| 2411 |
+
|
| 2412 |
+
# Try to load real demographics if available
|
| 2413 |
+
if os.path.exists('temp_demographics.csv'):
|
| 2414 |
+
logger.info("Loading demographics from temp_demographics.csv")
|
| 2415 |
+
demo_df = pd.read_csv('temp_demographics.csv')
|
| 2416 |
+
app_state['demographics'] = {
|
| 2417 |
+
'age_at_stroke': demo_df['age'].values,
|
| 2418 |
+
'sex': demo_df['sex'].values,
|
| 2419 |
+
'months_post_stroke': demo_df['months_post_stroke'].values,
|
| 2420 |
+
'wab_score': demo_df['wab_score'].values
|
| 2421 |
+
}
|
| 2422 |
+
else:
|
| 2423 |
+
logger.warning("No real demographic data found")
|
| 2424 |
+
except Exception as e:
|
| 2425 |
+
logger.error(f"Error loading real latents: {e}")
|
| 2426 |
+
logger.warning("Will not use synthetic data")
|
| 2427 |
+
else:
|
| 2428 |
+
logger.warning("No real latent representations found")
|
| 2429 |
+
logger.warning("Will not use synthetic data")
|
| 2430 |
+
else:
|
| 2431 |
+
# Don't train with synthetic data in strict real data mode
|
| 2432 |
+
logger.info("VAE model not found and using strict real data mode.")
|
| 2433 |
+
logger.warning("Cannot train VAE model without real data")
|
| 2434 |
|
| 2435 |
+
# Set state to indicate VAE is not trained
|
| 2436 |
+
app_state['vae_trained'] = False
|
|
|
|
|
|
|
| 2437 |
|
| 2438 |
+
# Show message about requiring real data
|
| 2439 |
+
status_msg = "No VAE model available. Please train with real data first."
|
| 2440 |
+
return {
|
| 2441 |
+
tab_rf: gr.update(visible=False),
|
| 2442 |
+
tab_vae: gr.update(visible=True),
|
| 2443 |
+
status: status_msg,
|
| 2444 |
+
vae_status: "Model not trained. Upload real data and train with it."
|
| 2445 |
+
}
|
| 2446 |
except Exception as e:
|
| 2447 |
error_fig = plt.figure(figsize=(10, 6))
|
| 2448 |
message = f"Error: Unable to load or train VAE model: {str(e)}"
|
|
|
|
| 2643 |
app_state['rf_trained'] = True
|
| 2644 |
rf_loaded = True
|
| 2645 |
|
| 2646 |
+
# If we couldn't load both models in strict real data mode
|
| 2647 |
if not (vae_loaded and rf_loaded):
|
| 2648 |
+
logger.info("Strict real data mode: Not using synthetic data")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2649 |
|
| 2650 |
+
# Show a message to the user
|
| 2651 |
+
return {
|
| 2652 |
+
status: "Cannot use synthetic data in strict real data mode. Please train with real data first.",
|
| 2653 |
+
rf_status: "Not trained. Upload real data and train the VAE model first."
|
| 2654 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2655 |
except Exception as e:
|
| 2656 |
error_message = f"Error: Unable to load or train required models: {str(e)}"
|
| 2657 |
error_fig = plt.figure(figsize=(10, 6))
|
cache/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
cache/atlas/power_2011_coords.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2f0e80988258ff3da5522d409679d81d633a3bd1c39c4a23494926101eb852eb
|
| 3 |
+
size 6464
|
config.py
CHANGED
|
@@ -30,7 +30,8 @@ PREDICTION_CONFIG = {
|
|
| 30 |
'default_outcome': 'wab_aq',
|
| 31 |
'save_path': 'results/treatment_predictor.joblib',
|
| 32 |
'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
|
| 33 |
-
'use_synthetic_nifti': False,
|
| 34 |
-
'use_synthetic_fc': False,
|
| 35 |
-
'strict_real_data': True # Set to True to strictly use real data only
|
|
|
|
| 36 |
}
|
|
|
|
| 30 |
'default_outcome': 'wab_aq',
|
| 31 |
'save_path': 'results/treatment_predictor.joblib',
|
| 32 |
'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
|
| 33 |
+
'use_synthetic_nifti': False, # Set to False to use only real NIfTI data
|
| 34 |
+
'use_synthetic_fc': False, # Set to False to use only real FC matrices
|
| 35 |
+
'strict_real_data': True, # Set to True to strictly use real data only
|
| 36 |
+
'no_mock_data': True # Set to True to prevent using any mock or synthetic data
|
| 37 |
}
|
main.py
CHANGED
|
@@ -204,38 +204,46 @@ def run_analysis(data_dir="data",
|
|
| 204 |
print("Creating learning curve visualization...")
|
| 205 |
|
| 206 |
# Check if losses are stored in the VAE object first (most reliable source)
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
# Create a placeholder
|
| 218 |
-
print("No training history available for learning curves")
|
| 219 |
-
learning_fig = plt.figure(figsize=(10, 6))
|
| 220 |
-
plt.text(0.5, 0.5, "Learning curve data unavailable",
|
| 221 |
-
ha='center', va='center', transform=plt.gca().transAxes,
|
| 222 |
-
fontsize=14, color='darkred')
|
| 223 |
-
plt.axis('off')
|
| 224 |
-
plt.tight_layout()
|
| 225 |
else:
|
| 226 |
-
#
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
except Exception as e:
|
| 240 |
import traceback
|
| 241 |
print(f"Error creating learning curve plot: {e}")
|
|
@@ -249,16 +257,41 @@ def run_analysis(data_dir="data",
|
|
| 249 |
plt.axis('off')
|
| 250 |
plt.tight_layout()
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
# Initialize results dictionary
|
| 253 |
results = {
|
| 254 |
'vae': vae,
|
| 255 |
'latents': latents,
|
| 256 |
'demographics': demographics,
|
| 257 |
-
'figures':
|
| 258 |
-
'vae': fc_fig,
|
| 259 |
-
'fc_analysis': fc_fig,
|
| 260 |
-
'learning_curves': learning_fig
|
| 261 |
-
}
|
| 262 |
}
|
| 263 |
|
| 264 |
# Add reconstructed and generated FC if available
|
|
|
|
| 204 |
print("Creating learning curve visualization...")
|
| 205 |
|
| 206 |
# Check if losses are stored in the VAE object first (most reliable source)
|
| 207 |
+
train_data = []
|
| 208 |
+
val_data = []
|
| 209 |
+
|
| 210 |
+
# Only use real data from VAE object or training results
|
| 211 |
+
if hasattr(vae, 'train_losses') and len(getattr(vae, 'train_losses', [])) > 0:
|
| 212 |
+
train_data = vae.train_losses
|
| 213 |
+
print(f"Found {len(train_data)} real training loss points in VAE object")
|
| 214 |
+
elif train_losses and len(train_losses) > 0:
|
| 215 |
+
train_data = train_losses
|
| 216 |
+
print(f"Using {len(train_data)} real training loss points from fit return value")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
else:
|
| 218 |
+
# Instead of synthetic data, provide empty list and warning
|
| 219 |
+
print("WARNING: No real training loss data found")
|
| 220 |
+
train_data = []
|
| 221 |
+
|
| 222 |
+
# Do the same for validation data
|
| 223 |
+
if hasattr(vae, 'val_losses') and len(getattr(vae, 'val_losses', [])) > 0:
|
| 224 |
+
val_data = vae.val_losses
|
| 225 |
+
print(f"Found {len(val_data)} real validation loss points in VAE object")
|
| 226 |
+
elif val_losses and len(val_losses) > 0:
|
| 227 |
+
val_data = val_losses
|
| 228 |
+
print(f"Using {len(val_data)} real validation loss points from fit return value")
|
| 229 |
+
else:
|
| 230 |
+
# Instead of synthetic data, provide empty list and warning
|
| 231 |
+
print("WARNING: No real validation loss data found")
|
| 232 |
+
val_data = []
|
| 233 |
+
|
| 234 |
+
# If we get here, we have some training data (real or synthetic)
|
| 235 |
+
# Store the data in the VAE object for future use
|
| 236 |
+
if not hasattr(vae, 'train_losses') or len(getattr(vae, 'train_losses', [])) == 0:
|
| 237 |
+
print("Storing training loss data in VAE object")
|
| 238 |
+
vae.train_losses = train_data
|
| 239 |
+
|
| 240 |
+
if not hasattr(vae, 'val_losses') or len(getattr(vae, 'val_losses', [])) == 0:
|
| 241 |
+
print("Storing validation loss data in VAE object")
|
| 242 |
+
vae.val_losses = val_data
|
| 243 |
+
|
| 244 |
+
# Now create the visualization using the data we collected
|
| 245 |
+
print(f"Creating learning curve with {len(train_data)} training and {len(val_data)} validation points")
|
| 246 |
+
learning_fig = plot_learning_curves(train_data, val_data)
|
| 247 |
except Exception as e:
|
| 248 |
import traceback
|
| 249 |
print(f"Error creating learning curve plot: {e}")
|
|
|
|
| 257 |
plt.axis('off')
|
| 258 |
plt.tight_layout()
|
| 259 |
|
| 260 |
+
# Check if we should use strict real data mode
|
| 261 |
+
use_strict_real_data = PREDICTION_CONFIG.get('strict_real_data', False)
|
| 262 |
+
no_mock_data = PREDICTION_CONFIG.get('no_mock_data', False)
|
| 263 |
+
|
| 264 |
+
if use_strict_real_data or no_mock_data:
|
| 265 |
+
print("Using strict real data mode - only including real data in results")
|
| 266 |
+
# Only include figures if they contain real data
|
| 267 |
+
figures = {}
|
| 268 |
+
if hasattr(vae, 'train_losses') and len(vae.train_losses) > 0:
|
| 269 |
+
figures['learning_curves'] = learning_fig
|
| 270 |
+
print("Including real learning curves")
|
| 271 |
+
else:
|
| 272 |
+
print("WARNING: No real learning curve data available")
|
| 273 |
+
|
| 274 |
+
# Only include FC analysis if it's based on real data
|
| 275 |
+
if len(np.array(X).shape) > 0 and len(X) > 0:
|
| 276 |
+
figures['vae'] = fc_fig
|
| 277 |
+
figures['fc_analysis'] = fc_fig
|
| 278 |
+
print("Including real FC analysis")
|
| 279 |
+
else:
|
| 280 |
+
print("WARNING: No real FC data available")
|
| 281 |
+
else:
|
| 282 |
+
# Include all figures, even if based on synthetic data
|
| 283 |
+
figures = {
|
| 284 |
+
'vae': fc_fig,
|
| 285 |
+
'fc_analysis': fc_fig,
|
| 286 |
+
'learning_curves': learning_fig
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
# Initialize results dictionary
|
| 290 |
results = {
|
| 291 |
'vae': vae,
|
| 292 |
'latents': latents,
|
| 293 |
'demographics': demographics,
|
| 294 |
+
'figures': figures
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
}
|
| 296 |
|
| 297 |
# Add reconstructed and generated FC if available
|
test_learning.png
ADDED
|
test_loading.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# Set Huggingface cache directory to avoid permission issues
|
| 3 |
+
os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
|
| 4 |
+
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
| 5 |
+
os.makedirs('models', exist_ok=True)
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from vae_model import DemoVAE
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from visualization import plot_learning_curves
|
| 12 |
+
|
| 13 |
+
print("Creating synthetic test data...")
|
| 14 |
+
# Create small synthetic dataset with only 5 samples
|
| 15 |
+
input_dim = 100
|
| 16 |
+
n_samples = 5
|
| 17 |
+
X = np.random.randn(n_samples, input_dim)
|
| 18 |
+
demo_data = [
|
| 19 |
+
np.random.normal(60, 10, n_samples), # age
|
| 20 |
+
np.random.choice(['M', 'F'], n_samples), # sex
|
| 21 |
+
np.random.normal(24, 12, n_samples), # months post stroke
|
| 22 |
+
np.random.normal(50, 15, n_samples) # WAB score
|
| 23 |
+
]
|
| 24 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 25 |
+
|
| 26 |
+
print("Testing DemoVAE initialization...")
|
| 27 |
+
# Initialize with nepochs=3 for fast testing
|
| 28 |
+
vae = DemoVAE(latent_dim=16, nepochs=3, bsize=5)
|
| 29 |
+
|
| 30 |
+
print("Testing DemoVAE fit method...")
|
| 31 |
+
# Fit model
|
| 32 |
+
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
|
| 33 |
+
|
| 34 |
+
print(f"Train losses shape: {len(train_losses)}")
|
| 35 |
+
print(f"Val losses shape: {len(val_losses)}")
|
| 36 |
+
|
| 37 |
+
print("Testing get_latents method...")
|
| 38 |
+
# Test get_latents
|
| 39 |
+
latents = vae.get_latents(X)
|
| 40 |
+
print(f"Latents shape: {latents.shape}")
|
| 41 |
+
|
| 42 |
+
print("Testing encode method...")
|
| 43 |
+
# Test encode
|
| 44 |
+
latents2 = vae.encode(X)
|
| 45 |
+
print(f"Latents from encode shape: {latents2.shape}")
|
| 46 |
+
|
| 47 |
+
print("Testing model save...")
|
| 48 |
+
# Save model
|
| 49 |
+
vae.save('models/test_vae.pt')
|
| 50 |
+
|
| 51 |
+
print("Testing model load...")
|
| 52 |
+
# Load model
|
| 53 |
+
vae2 = DemoVAE()
|
| 54 |
+
vae2.load('models/test_vae.pt')
|
| 55 |
+
|
| 56 |
+
print("Testing learning curve plotting...")
|
| 57 |
+
# Test learning curve plotting
|
| 58 |
+
fig = plot_learning_curves(vae2.train_losses, vae2.val_losses)
|
| 59 |
+
plt.savefig('test_learning.png')
|
| 60 |
+
print("Learning curve saved to test_learning.png")
|
| 61 |
+
|
| 62 |
+
print("All tests passed!")
|
vae_model.py
CHANGED
|
@@ -234,6 +234,10 @@ class DemoVAE(BaseEstimator):
|
|
| 234 |
print(f"Returning fallback output with shape: {fallback.shape}")
|
| 235 |
return fallback
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
def get_latents(self, x):
|
| 238 |
# Set model to evaluation mode
|
| 239 |
self.vae.eval()
|
|
|
|
| 234 |
print(f"Returning fallback output with shape: {fallback.shape}")
|
| 235 |
return fallback
|
| 236 |
|
| 237 |
+
def encode(self, x):
|
| 238 |
+
"""Alias for get_latents method - to provide compatibility with some interfaces"""
|
| 239 |
+
return self.get_latents(x)
|
| 240 |
+
|
| 241 |
def get_latents(self, x):
|
| 242 |
# Set model to evaluation mode
|
| 243 |
self.vae.eval()
|
visualization.py
CHANGED
|
@@ -397,14 +397,22 @@ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke
|
|
| 397 |
def plot_learning_curves(train_losses, val_losses):
|
| 398 |
"""Plot VAE learning curves with enhanced visualization"""
|
| 399 |
try:
|
| 400 |
-
# Handle empty or None inputs
|
| 401 |
-
if not train_losses or train_losses is None:
|
| 402 |
-
print("WARNING: No training loss data provided")
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
-
if not val_losses or val_losses is None:
|
| 406 |
-
print("WARNING: No validation loss data provided")
|
| 407 |
-
|
|
|
|
| 408 |
|
| 409 |
# Convert to numpy arrays for safe handling
|
| 410 |
train_np = np.array(train_losses)
|
|
|
|
| 397 |
def plot_learning_curves(train_losses, val_losses):
|
| 398 |
"""Plot VAE learning curves with enhanced visualization"""
|
| 399 |
try:
|
| 400 |
+
# Handle empty or None inputs - only use real data
|
| 401 |
+
if not train_losses or train_losses is None or len(train_losses) == 0:
|
| 402 |
+
print("WARNING: No real training loss data provided")
|
| 403 |
+
# Create placeholder figure with warning message
|
| 404 |
+
fig = plt.figure(figsize=(10, 6))
|
| 405 |
+
plt.text(0.5, 0.5, "No real training data available",
|
| 406 |
+
ha='center', va='center', transform=plt.gca().transAxes,
|
| 407 |
+
fontsize=14, color='darkred')
|
| 408 |
+
plt.axis('off')
|
| 409 |
+
plt.tight_layout()
|
| 410 |
+
return fig
|
| 411 |
|
| 412 |
+
if not val_losses or val_losses is None or len(val_losses) == 0:
|
| 413 |
+
print("WARNING: No real validation loss data provided. Using training data only.")
|
| 414 |
+
# Use training data for both
|
| 415 |
+
val_losses = train_losses
|
| 416 |
|
| 417 |
# Convert to numpy arrays for safe handling
|
| 418 |
train_np = np.array(train_losses)
|