aditya-g07 committed
Commit 2d98925 · Parent(s): 6feecf0

Fix Gradio JSON schema error: Simplify interface and use stable version


- Downgrade to Gradio 4.36.0 (stable version without JSON schema issues)
- Completely rewrite app.py with simplified interface structure
- Remove complex API functions that were causing schema parsing errors
- Use straightforward input/output types that Gradio can handle properly
- Maintain core face detection functionality while fixing runtime errors
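In practice, the fix reduces app.py to a single event handler over plain component types. Below is a minimal sketch of that pattern; the detector body is stubbed out here, and the real pipeline appears in the app.py diff further down:

```python
import gradio as gr
from PIL import Image

def detect_faces(image: Image.Image, model_type: str,
                 confidence: float, nms: float):
    # Stub standing in for the real RetinaFace pipeline in app.py;
    # the key point is the simple (Image, str) return signature.
    return image, f"Ran {model_type} at confidence={confidence}, nms={nms}"

with gr.Blocks(title="RetinaFace Face Detection") as demo:
    inp = gr.Image(type="pil", label="Upload Image")
    model = gr.Dropdown(choices=["mobilenet", "resnet"], value="mobilenet", label="Model")
    conf = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Confidence")
    nms = gr.Slider(0.1, 1.0, value=0.4, step=0.1, label="NMS Threshold")
    btn = gr.Button("Detect Faces")
    out_img = gr.Image(label="Results")
    out_md = gr.Markdown()
    btn.click(fn=detect_faces, inputs=[inp, model, conf, nms], outputs=[out_img, out_md])

if __name__ == "__main__":
    demo.launch()
```

Returning a plain (image, markdown) tuple keeps each endpoint's schema simple enough for Gradio's client-side JSON-schema parser, which is the failure this commit works around.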

Files changed (3)

1. README.md +1 -1
2. app.py +189 -367
3. requirements.txt +1 -1
README.md CHANGED
```diff
@@ -4,7 +4,7 @@ emoji: 🔍
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version: 4.44.1
+sdk_version: 4.36.0
 app_file: app.py
 pinned: false
 license: apache-2.0
```
app.py CHANGED
````diff
@@ -9,6 +9,7 @@ import tempfile
 import time
 from PIL import Image, ImageDraw
 import json
+import io
 
 # Import RetinaFace model components
 from models.retinaface import RetinaFace
@@ -26,6 +27,47 @@ def load_models():
     global mobilenet_model, resnet_model
 
     try:
+        # Model configurations
+        mobilenet_cfg = {
+            'name': 'mobilenet0.25',
+            'min_sizes': [[16, 32], [64, 128], [256, 512]],
+            'steps': [8, 16, 32],
+            'variance': [0.1, 0.2],
+            'clip': False,
+            'loc_weight': 2.0,
+            'gpu_train': True,
+            'batch_size': 32,
+            'ngpu': 1,
+            'epoch': 250,
+            'decay1': 190,
+            'decay2': 220,
+            'image_size': 640,
+            'pretrain': False,
+            'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
+            'in_channel': 32,
+            'out_channel': 64
+        }
+
+        resnet_cfg = {
+            'name': 'Resnet50',
+            'min_sizes': [[16, 32], [64, 128], [256, 512]],
+            'steps': [8, 16, 32],
+            'variance': [0.1, 0.2],
+            'clip': False,
+            'loc_weight': 2.0,
+            'gpu_train': True,
+            'batch_size': 24,
+            'ngpu': 4,
+            'epoch': 100,
+            'decay1': 70,
+            'decay2': 90,
+            'image_size': 840,
+            'pretrain': False,
+            'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
+            'in_channel': 256,
+            'out_channel': 256
+        }
+
         # Load MobileNet model
         mobilenet_model = RetinaFace(cfg=mobilenet_cfg, phase='test')
         mobilenet_model.load_state_dict(torch.load('mobilenet0.25_Final.pth', map_location=device))
@@ -38,422 +80,202 @@ def load_models():
         resnet_model.eval()
         resnet_model = resnet_model.to(device)
 
-        print("Models loaded successfully!")
-        return "✅ Models loaded successfully!"
+        print("✅ Models loaded successfully!")
+        return True
 
     except Exception as e:
-        error_msg = f"❌ Error loading models: {e}"
-        print(error_msg)
-        return error_msg
+        print(f"❌ Error loading models: {e}")
+        return False
 
-# Model configurations
-mobilenet_cfg = {
-    'name': 'mobilenet0.25',
-    'min_sizes': [[16, 32], [64, 128], [256, 512]],
-    'steps': [8, 16, 32],
-    'variance': [0.1, 0.2],
-    'clip': False,
-    'loc_weight': 2.0,
-    'gpu_train': True,
-    'batch_size': 32,
-    'ngpu': 1,
-    'epoch': 250,
-    'decay1': 190,
-    'decay2': 220,
-    'image_size': 640,
-    'pretrain': False,  # Don't load pretrained weights
-    'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
-    'in_channel': 32,
-    'out_channel': 64
-}
-
-resnet_cfg = {
-    'name': 'Resnet50',
-    'min_sizes': [[16, 32], [64, 128], [256, 512]],
-    'steps': [8, 16, 32],
-    'variance': [0.1, 0.2],
-    'clip': False,
-    'loc_weight': 2.0,
-    'gpu_train': True,
-    'batch_size': 24,
-    'ngpu': 4,
-    'epoch': 100,
-    'decay1': 70,
-    'decay2': 90,
-    'image_size': 840,
-    'pretrain': False,  # Don't load pretrained weights
-    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
-    'in_channel': 256,
-    'out_channel': 256
-}
-
-def detect_faces_core(image, model, cfg, confidence_threshold=0.02, nms_threshold=0.4):
+def detect_faces(image, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
     """Core face detection function"""
-    start_time = time.time()
-
-    # Preprocessing
-    img = np.float32(image)
-    im_height, im_width, _ = img.shape
-    scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
-    img -= (104, 117, 123)
-    img = img.transpose(2, 0, 1)
-    img = torch.from_numpy(img).unsqueeze(0)
-    img = img.to(device)
-    scale = scale.to(device)
-
-    # Forward pass
-    with torch.no_grad():
-        loc, conf, landms = model(img)
-
-    # Post-processing
-    priorbox = PriorBox(cfg, image_size=(im_height, im_width))
-    priors = priorbox.forward()
-    priors = priors.to(device)
-    prior_data = priors.data
-    boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
-    boxes = boxes * scale / 1
-    boxes = boxes.cpu().numpy()
-    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
-    landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
-    scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
-                           img.shape[3], img.shape[2], img.shape[3], img.shape[2],
-                           img.shape[3], img.shape[2]])
-    scale1 = scale1.to(device)
-    landms = landms * scale1 / 1
-    landms = landms.cpu().numpy()
-
-    # Ignore low scores
-    inds = np.where(scores > confidence_threshold)[0]
-    boxes = boxes[inds]
-    landms = landms[inds]
-    scores = scores[inds]
-
-    # Keep top-K before NMS
-    order = scores.argsort()[::-1][:5000]
-    boxes = boxes[order]
-    landms = landms[order]
-    scores = scores[order]
-
-    # Do NMS
-    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
-    keep = py_cpu_nms(dets, nms_threshold)
-    dets = dets[keep, :]
-    landms = landms[keep]
-
-    # Format results
-    faces = []
-    for i in range(dets.shape[0]):
-        if dets[i, 4] < confidence_threshold:
-            continue
-
-        face = {
-            "bbox": {
-                "x1": float(dets[i, 0]),
-                "y1": float(dets[i, 1]),
-                "x2": float(dets[i, 2]),
-                "y2": float(dets[i, 3])
-            },
-            "confidence": float(dets[i, 4]),
-            "landmarks": {
-                "right_eye": [float(landms[i, 0]), float(landms[i, 1])],
-                "left_eye": [float(landms[i, 2]), float(landms[i, 3])],
-                "nose": [float(landms[i, 4]), float(landms[i, 5])],
-                "right_mouth": [float(landms[i, 6]), float(landms[i, 7])],
-                "left_mouth": [float(landms[i, 8]), float(landms[i, 9])]
-            }
-        }
-        faces.append(face)
-
-    processing_time = time.time() - start_time
-    return faces, processing_time
-
-def draw_faces_on_image(image, faces):
-    """Draw bounding boxes and landmarks on image"""
-    if isinstance(image, np.ndarray):
-        # Convert numpy array to PIL Image
-        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-
-    draw = ImageDraw.Draw(image)
-
-    for face in faces:
-        bbox = face["bbox"]
-        confidence = face["confidence"]
-        landmarks = face["landmarks"]
-
-        # Draw bounding box
-        draw.rectangle([bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]],
-                       outline="red", width=2)
-
-        # Draw confidence score
-        draw.text((bbox["x1"], bbox["y1"] - 15),
-                  f'{confidence:.2f}', fill="red")
-
-        # Draw landmarks
-        for landmark_name, (x, y) in landmarks.items():
-            draw.ellipse([x-2, y-2, x+2, y+2], fill="blue")
-
-    return image
-
-def gradio_detect_faces(image, model_type, confidence_threshold, nms_threshold):
-    """Gradio interface function for face detection"""
-    if mobilenet_model is None or resnet_model is None:
-        return None, "❌ Models not loaded. Please wait for models to load.", ""
-
     try:
-        # Convert PIL to OpenCV format
-        if isinstance(image, Image.Image):
-            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+        start_time = time.time()
 
-        # Select model
-        if model_type.lower() == "resnet":
+        # Choose model
+        if model_type == "resnet":
             model = resnet_model
-            cfg = resnet_cfg
-            model_name = "ResNet50"
+            cfg = {
+                'min_sizes': [[16, 32], [64, 128], [256, 512]],
+                'steps': [8, 16, 32],
+                'variance': [0.1, 0.2],
+                'clip': False,
+                'image_size': 840
+            }
         else:
             model = mobilenet_model
-            cfg = mobilenet_cfg
-            model_name = "MobileNet"
+            cfg = {
+                'min_sizes': [[16, 32], [64, 128], [256, 512]],
+                'steps': [8, 16, 32],
+                'variance': [0.1, 0.2],
+                'clip': False,
+                'image_size': 640
+            }
 
-        # Detect faces
-        faces, processing_time = detect_faces_core(
-            image, model, cfg, confidence_threshold, nms_threshold
-        )
-
-        # Draw results on image
-        result_image = draw_faces_on_image(image.copy(), faces)
-
-        # Create results text
-        results_text = f"🎯 **Detection Results**\n"
-        results_text += f"📱 Model: {model_name}\n"
-        results_text += f"⏱️ Processing Time: {processing_time:.3f}s\n"
-        results_text += f"👥 Faces Detected: {len(faces)}\n\n"
-
-        for i, face in enumerate(faces):
-            results_text += f"**Face {i+1}:**\n"
-            results_text += f"  Confidence: {face['confidence']:.3f}\n"
-            bbox = face['bbox']
-            results_text += f"  Location: ({bbox['x1']:.0f}, {bbox['y1']:.0f}) - ({bbox['x2']:.0f}, {bbox['y2']:.0f})\n\n"
-
-        # Create JSON output for API use
-        json_output = {
-            "faces": faces,
-            "processing_time": processing_time,
-            "model_used": model_name.lower(),
-            "total_faces": len(faces)
-        }
-
-        return result_image, results_text, json.dumps(json_output, indent=2)
-
-    except Exception as e:
-        error_msg = f"❌ Detection failed: {str(e)}"
-        return None, error_msg, ""
-
-def api_detect_live(image_base64, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
-    """API function for live detection (Thunkable compatible)"""
-    try:
-        # Decode base64 image
-        image_data = base64.b64decode(image_base64)
-        nparr = np.frombuffer(image_data, np.uint8)
-        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
-
-        if image is None:
-            return {"error": "Invalid image data"}
-
-        # Select model
-        if model_type.lower() == "resnet":
-            model = resnet_model
-            cfg = resnet_cfg
-            model_name = "resnet"
-        else:
-            model = mobilenet_model
-            cfg = mobilenet_cfg
-            model_name = "mobilenet"
-
-        if model is None:
-            return {"error": f"{model_name} model not loaded"}
-
-        # Detect faces
-        faces, processing_time = detect_faces_core(
-            image, model, cfg, confidence_threshold, nms_threshold
-        )
-
-        return {
-            "faces": faces,
-            "processing_time": processing_time,
-            "model_used": model_name,
-            "total_faces": len(faces)
-        }
+        if model is None:
+            return None, "Models not loaded"
+
+        # Convert PIL to numpy array
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+
+        # Preprocessing
+        img = np.float32(image)
+        im_height, im_width, _ = img.shape
+        scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
+        img -= (104, 117, 123)
+        img = img.transpose(2, 0, 1)
+        img = torch.from_numpy(img).unsqueeze(0)
+        img = img.to(device)
+        scale = scale.to(device)
+
+        # Forward pass
+        with torch.no_grad():
+            loc, conf, landms = model(img)
+
+        # Generate priors
+        priorbox = PriorBox(cfg, image_size=(im_height, im_width))
+        priors = priorbox.forward()
+        priors = priors.to(device)
+        prior_data = priors.data
+        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
+        boxes = boxes * scale
+        boxes = boxes.cpu().numpy()
+        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+        landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
+        scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
+                               img.shape[3], img.shape[2], img.shape[3], img.shape[2],
+                               img.shape[3], img.shape[2]])
+        scale1 = scale1.to(device)
+        landms = landms * scale1
+        landms = landms.cpu().numpy()
+
+        # Ignore low scores
+        inds = np.where(scores > confidence_threshold)[0]
+        boxes = boxes[inds]
+        landms = landms[inds]
+        scores = scores[inds]
+
+        # Keep top-K before NMS
+        order = scores.argsort()[::-1][:5000]
+        boxes = boxes[order]
+        landms = landms[order]
+        scores = scores[order]
+
+        # Apply NMS
+        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+        keep = py_cpu_nms(dets, nms_threshold)
+        dets = dets[keep, :]
+        landms = landms[keep]
+
+        # Draw results
+        result_image = Image.fromarray(image)
+        draw = ImageDraw.Draw(result_image)
+
+        faces = []
+        for b, landmarks in zip(dets, landms):
+            if b[4] < confidence_threshold:
+                continue
+
+            # Draw bounding box
+            draw.rectangle([b[0], b[1], b[2], b[3]], outline="red", width=2)
+
+            # Draw confidence score
+            draw.text((b[0], b[1] - 15), f'{b[4]:.2f}', fill="red")
+
+            # Draw landmarks
+            for i in range(0, 10, 2):
+                draw.ellipse([landmarks[i]-2, landmarks[i+1]-2, landmarks[i]+2, landmarks[i+1]+2], fill="blue")
+
+            faces.append({
+                "bbox": {"x1": float(b[0]), "y1": float(b[1]), "x2": float(b[2]), "y2": float(b[3])},
+                "confidence": float(b[4]),
+                "landmarks": {
+                    "left_eye": [float(landmarks[0]), float(landmarks[1])],
+                    "right_eye": [float(landmarks[2]), float(landmarks[3])],
+                    "nose": [float(landmarks[4]), float(landmarks[5])],
+                    "left_mouth": [float(landmarks[6]), float(landmarks[7])],
+                    "right_mouth": [float(landmarks[8]), float(landmarks[9])]
+                }
+            })
+
+        processing_time = time.time() - start_time
+
+        result_text = f"""
+        **Detection Results:**
+        - **Faces Detected:** {len(faces)}
+        - **Model Used:** {model_type}
+        - **Processing Time:** {processing_time:.3f}s
+        - **Confidence Threshold:** {confidence_threshold}
+        - **NMS Threshold:** {nms_threshold}
+        """
+
+        return result_image, result_text
 
     except Exception as e:
-        return {"error": f"Detection failed: {str(e)}"}
+        return None, f"Error: {str(e)}"
 
 # Load models on startup
 print("Loading RetinaFace models...")
-load_status = load_models()
+model_loaded = load_models()
 
-# Create Gradio interface
-with gr.Blocks(title="RetinaFace Face Detection API", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("""
-    # 🔥 RetinaFace Face Detection API
-
-    **Real-time face detection using RetinaFace with MobileNet and ResNet backbones**
-
-    - 📱 **Mobile App Ready**: Compatible with Thunkable and other mobile frameworks
-    - ⚡ **Dual Models**: MobileNet (fast) and ResNet (accurate)
-    - 🎯 **High Accuracy**: Detects faces with bounding boxes and 5-point landmarks
-    - 🌐 **API Endpoints**: Use `/api/predict` for programmatic access
-    """)
-
-    with gr.Row():
-        gr.Markdown(f"**Status**: {load_status}")
-
-    with gr.Tab("🖼️ Image Detection"):
+# Create simple Gradio interface
+def create_interface():
+    with gr.Blocks(title="RetinaFace Face Detection") as demo:
+        gr.Markdown("# 🔥 RetinaFace Face Detection API")
+        gr.Markdown("Real-time face detection using RetinaFace with MobileNet and ResNet backbones")
+
+        if model_loaded:
+            gr.Markdown("✅ **Status**: Models loaded successfully!")
+        else:
+            gr.Markdown("❌ **Status**: Error loading models")
+
         with gr.Row():
             with gr.Column():
                 input_image = gr.Image(type="pil", label="Upload Image")
                 model_choice = gr.Dropdown(
                     choices=["mobilenet", "resnet"],
                     value="mobilenet",
-                    label="Model Type"
+                    label="Model"
                 )
-                confidence_slider = gr.Slider(
+                confidence = gr.Slider(
                     minimum=0.1, maximum=1.0, value=0.5, step=0.1,
-                    label="Confidence Threshold"
+                    label="Confidence"
                 )
-                nms_slider = gr.Slider(
+                nms = gr.Slider(
                     minimum=0.1, maximum=1.0, value=0.4, step=0.1,
                     label="NMS Threshold"
                 )
                 detect_btn = gr.Button("🔍 Detect Faces", variant="primary")
 
             with gr.Column():
-                output_image = gr.Image(label="Detection Results")
-                results_text = gr.Markdown(label="Results")
+                output_image = gr.Image(label="Results")
+                output_text = gr.Markdown()
 
         detect_btn.click(
-            fn=gradio_detect_faces,
-            inputs=[input_image, model_choice, confidence_slider, nms_slider],
-            outputs=[output_image, results_text]
+            fn=detect_faces,
+            inputs=[input_image, model_choice, confidence, nms],
+            outputs=[output_image, output_text]
         )
 
-    with gr.Tab("🔗 API Documentation"):
-        gr.Markdown("""
-        ## API Endpoints for Thunkable Integration
-
-        ### 1. Live Detection Endpoint
-        ```
-        POST /api/predict
-        ```
-
-        **Request Body (JSON):**
-        ```json
-        {
-            "data": [
-                "base64_encoded_image_string",
-                "mobilenet",
-                0.5,
-                0.4
-            ]
-        }
-        ```
-
-        **Response:**
-        ```json
-        {
-            "data": [
-                {
-                    "faces": [...],
-                    "processing_time": 0.1,
-                    "model_used": "mobilenet",
-                    "total_faces": 2
-                }
-            ]
-        }
-        ```
-
-        ### 2. Thunkable Integration Example
-
-        **Web API Component Setup:**
-        - URL: `https://your-space-name.hf.space/api/predict`
-        - Method: `POST`
-        - Headers: `Content-Type: application/json`
-        - Body:
+        gr.Markdown("""
+        ## API Usage
+        Use `/api/predict` endpoint with:
         ```json
         {
-            "data": [
-                "{{base64_image}}",
-                "mobilenet",
-                0.5,
-                0.4
-            ]
+            "data": [image, "mobilenet", 0.5, 0.4]
         }
         ```
-
-        ### 3. Model Performance
-
-        | Model | Speed | Accuracy | Best For |
-        |-------|-------|----------|----------|
-        | MobileNet | ⚡ Fast | 🎯 Good | Real-time mobile apps |
-        | ResNet50 | 🐌 Slower | 🎯🎯 High | High-accuracy applications |
-
-        ### 4. Response Format
-
-        Each detected face includes:
-        - **bbox**: Bounding box coordinates (x1, y1, x2, y2)
-        - **confidence**: Detection confidence score (0-1)
-        - **landmarks**: 5-point facial landmarks (eyes, nose, mouth corners)
         """)
 
-    with gr.Tab("📊 API Testing"):
-        gr.Markdown("### Test the API with base64 encoded images")
-
-        with gr.Row():
-            with gr.Column():
-                test_image_b64 = gr.Textbox(
-                    label="Base64 Encoded Image",
-                    placeholder="Paste base64 encoded image here...",
-                    lines=3
-                )
-                test_model = gr.Dropdown(
-                    choices=["mobilenet", "resnet"],
-                    value="mobilenet",
-                    label="Model"
-                )
-                test_conf = gr.Number(value=0.5, label="Confidence")
-                test_nms = gr.Number(value=0.4, label="NMS Threshold")
-                test_btn = gr.Button("🧪 Test API", variant="secondary")
-
-            with gr.Column():
-                api_output = gr.JSON(label="API Response")
-
-        def test_api_function(image_b64, model, conf, nms):
-            if not image_b64.strip():
-                return {"error": "Please provide base64 encoded image"}
-
-            # Remove data URL prefix if present
-            if image_b64.startswith('data:image'):
-                image_b64 = image_b64.split(',')[1]
-
-            result = api_detect_live(image_b64, model, conf, nms)
-            return result
-
-        test_btn.click(
-            fn=test_api_function,
-            inputs=[test_image_b64, test_model, test_conf, test_nms],
-            outputs=[api_output]
-        )
+    return demo
 
-# Custom API function for external calls
-def predict_api(image_base64, model_type="mobilenet", confidence_threshold=0.5, nms_threshold=0.4):
-    """API prediction function that matches Gradio's expected format"""
-    result = api_detect_live(image_base64, model_type, confidence_threshold, nms_threshold)
-    return result
+# Create and launch the interface
+demo = create_interface()
 
-# Launch the app
 if __name__ == "__main__":
     demo.launch(
         server_name="0.0.0.0",
         server_port=7860,
-        share=True,
-        show_error=True
+        share=True
    )
````
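For programmatic access after this simplification, a `gradio_client` call is the least error-prone route. This is a hedged sketch: the Space URL is a placeholder, and the endpoint name assumes Gradio 4's default of deriving it from the handler function, so `/detect_faces` here.

```python
from gradio_client import Client

# Placeholder Space URL; substitute the real one. Depending on the
# gradio_client version, image inputs may need to be wrapped with
# gradio_client.file() or handle_file() instead of a bare filepath.
client = Client("https://your-space-name.hf.space")
result = client.predict(
    "face.jpg",   # image: filepath for the gr.Image input
    "mobilenet",  # model_type
    0.5,          # confidence threshold
    0.4,          # NMS threshold
    api_name="/detect_faces",  # assumed default name from the handler function
)
print(result)  # expected: (path to annotated image, markdown results string)
```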
requirements.txt CHANGED
```diff
@@ -1,4 +1,4 @@
-gradio==4.44.1
+gradio==4.36.0
 torch==2.0.1
 torchvision==0.15.2
 opencv-python==4.8.1.78
```