SreekarB committed 79a3849 (verified) · Parent(s): 06059fe

Upload 3 files

Files changed (1):
  vae_model.py  +27 -14

vae_model.py CHANGED
@@ -318,22 +318,35 @@ class DemoVAE(BaseEstimator):
 
     def load(self, path):
         try:
-            # First try loading with weights_only=True (PyTorch 2.6+ default)
+            # Try different loading methods based on PyTorch version
+            print(f"Attempting to load model from {path}")
             try:
-                print(f"Attempting to load model with weights_only=True from {path}")
-                checkpoint = torch.load(path, weights_only=True)
-                print("Successfully loaded checkpoint with weights_only=True")
-            except (TypeError, ValueError):
-                # If weights_only parameter is not supported (older PyTorch) or fails
-                print("Falling back to load with default parameters")
-                # Add necessary global variables to the safe list if using PyTorch 2.6+
-                if hasattr(torch.serialization, 'add_safe_globals'):
+                # For PyTorch 2.6+, explicitly set weights_only=False for backward compatibility
+                if hasattr(torch, '__version__') and torch.__version__.startswith('2.6'):
                     import numpy as np
-                    # Add numpy scalar to safe globals
-                    torch.serialization.add_safe_globals(['numpy._core.multiarray.scalar',
-                                                          'numpy.core.multiarray.scalar'])
-                checkpoint = torch.load(path)
-                print("Successfully loaded checkpoint with default parameters")
+                    # Add all necessary numpy types to safe globals list
+                    if hasattr(torch.serialization, 'add_safe_globals'):
+                        torch.serialization.add_safe_globals([
+                            'numpy._core.multiarray.scalar',
+                            'numpy.core.multiarray.scalar',
+                            'numpy.ndarray',
+                            'numpy._globals._NoValue'
+                        ])
+                    with torch.serialization.safe_globals(['numpy._core.multiarray.scalar']):
+                        checkpoint = torch.load(path, weights_only=False)
+                else:
+                    # For older PyTorch versions
+                    checkpoint = torch.load(path)
+            except Exception as e:
+                print(f"Primary loading method failed: {str(e)}")
+                # Last resort - try with context manager if available
+                if hasattr(torch.serialization, 'safe_globals'):
+                    with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.core.multiarray.scalar']):
+                        checkpoint = torch.load(path, weights_only=False)
+                else:
+                    # Fall back to default with no safety
+                    checkpoint = torch.load(path)
+            print("Successfully loaded checkpoint")
 
             # Initialize from checkpoint
             self.set_params(**checkpoint['params'])
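
A minimal usage sketch of the updated loader, assuming DemoVAE can be constructed with its default parameters and that a previously saved DemoVAE checkpoint exists at an illustrative path (neither the path nor the constructor call is taken from this commit):

    from vae_model import DemoVAE

    model = DemoVAE()
    # Exercises the version-aware torch.load path in the hunk above:
    # safe-globals handling with weights_only=False on PyTorch 2.6+,
    # a plain torch.load(path) on older releases.
    model.load("demo_vae_checkpoint.pt")  # illustrative path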