Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- vae_model.py +27 -14
vae_model.py
CHANGED
|
@@ -318,22 +318,35 @@ class DemoVAE(BaseEstimator):
|
|
| 318 |
|
| 319 |
def load(self, path):
|
| 320 |
try:
|
| 321 |
-
#
|
|
|
|
| 322 |
try:
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
print("Successfully loaded checkpoint with weights_only=True")
|
| 326 |
-
except (TypeError, ValueError):
|
| 327 |
-
# If weights_only parameter is not supported (older PyTorch) or fails
|
| 328 |
-
print("Falling back to load with default parameters")
|
| 329 |
-
# Add necessary global variables to the safe list if using PyTorch 2.6+
|
| 330 |
-
if hasattr(torch.serialization, 'add_safe_globals'):
|
| 331 |
import numpy as np
|
| 332 |
-
# Add numpy
|
| 333 |
-
torch.serialization
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
# Initialize from checkpoint
|
| 339 |
self.set_params(**checkpoint['params'])
|
|
|
|
| 318 |
|
| 319 |
def load(self, path):
|
| 320 |
try:
|
| 321 |
+
# Try different loading methods based on PyTorch version
|
| 322 |
+
print(f"Attempting to load model from {path}")
|
| 323 |
try:
|
| 324 |
+
# For PyTorch 2.6+, explicitly set weights_only=False for backward compatibility
|
| 325 |
+
if hasattr(torch, '__version__') and torch.__version__.startswith('2.6'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
import numpy as np
|
| 327 |
+
# Add all necessary numpy types to safe globals list
|
| 328 |
+
if hasattr(torch.serialization, 'add_safe_globals'):
|
| 329 |
+
torch.serialization.add_safe_globals([
|
| 330 |
+
'numpy._core.multiarray.scalar',
|
| 331 |
+
'numpy.core.multiarray.scalar',
|
| 332 |
+
'numpy.ndarray',
|
| 333 |
+
'numpy._globals._NoValue'
|
| 334 |
+
])
|
| 335 |
+
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar']):
|
| 336 |
+
checkpoint = torch.load(path, weights_only=False)
|
| 337 |
+
else:
|
| 338 |
+
# For older PyTorch versions
|
| 339 |
+
checkpoint = torch.load(path)
|
| 340 |
+
except Exception as e:
|
| 341 |
+
print(f"Primary loading method failed: {str(e)}")
|
| 342 |
+
# Last resort - try with context manager if available
|
| 343 |
+
if hasattr(torch.serialization, 'safe_globals'):
|
| 344 |
+
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.core.multiarray.scalar']):
|
| 345 |
+
checkpoint = torch.load(path, weights_only=False)
|
| 346 |
+
else:
|
| 347 |
+
# Fall back to default with no safety
|
| 348 |
+
checkpoint = torch.load(path)
|
| 349 |
+
print("Successfully loaded checkpoint")
|
| 350 |
|
| 351 |
# Initialize from checkpoint
|
| 352 |
self.set_params(**checkpoint['params'])
|