aditya-g07 commited on
Commit
a0cfc96
Β·
1 Parent(s): e586088

Add model loading test function for better debugging

Browse files
Files changed (1) hide show
  1. app.py +55 -1
app.py CHANGED
@@ -240,9 +240,63 @@ def detect_faces(image, model_type="mobilenet", confidence_threshold=0.5, nms_th
240
  except Exception as e:
241
  return None, f"Error: {str(e)}"
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Load models on startup
244
  print("Loading RetinaFace models...")
245
- model_loaded = load_models()
 
 
 
 
 
 
 
246
 
247
  # Create simple Gradio interface
248
  def create_interface():
 
240
  except Exception as e:
241
  return None, f"Error: {str(e)}"
242
 
243
+ # Simple test function to debug model loading
244
+ def test_model_loading():
245
+ """Test model loading step by step"""
246
+ try:
247
+ print("=== Testing Model Loading ===")
248
+
249
+ # Test basic imports
250
+ print("Testing RetinaFace import...")
251
+ test_cfg = {
252
+ 'name': 'mobilenet0.25',
253
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
254
+ 'steps': [8, 16, 32],
255
+ 'variance': [0.1, 0.2],
256
+ 'clip': False,
257
+ 'pretrain': False,
258
+ 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
259
+ 'in_channel': 32,
260
+ 'out_channel': 64
261
+ }
262
+
263
+ print("Creating RetinaFace instance...")
264
+ model = RetinaFace(cfg=test_cfg, phase='test')
265
+ print(f"βœ… Model created successfully: {type(model)}")
266
+
267
+ print("Checking model file...")
268
+ if os.path.exists('mobilenet0.25_Final.pth'):
269
+ print("βœ… Model file exists")
270
+
271
+ print("Loading state dict...")
272
+ state_dict = torch.load('mobilenet0.25_Final.pth', map_location='cpu')
273
+ print(f"βœ… State dict loaded, keys: {len(state_dict.keys())}")
274
+
275
+ print("Loading state dict into model...")
276
+ model.load_state_dict(state_dict)
277
+ print("βœ… State dict loaded successfully!")
278
+
279
+ return True
280
+ else:
281
+ print("❌ Model file not found")
282
+ return False
283
+
284
+ except Exception as e:
285
+ import traceback
286
+ print(f"❌ Test failed: {e}")
287
+ print(f"❌ Traceback: {traceback.format_exc()}")
288
+ return False
289
+
290
  # Load models on startup
291
  print("Loading RetinaFace models...")
292
+ print("Running model loading test...")
293
+ test_result = test_model_loading()
294
+ if test_result:
295
+ print("Test passed, proceeding with full model loading...")
296
+ model_loaded = load_models()
297
+ else:
298
+ print("Test failed, skipping model loading...")
299
+ model_loaded = False
300
 
301
  # Create simple Gradio interface
302
  def create_interface():