ogulcanaydogan commited on
Commit
ba0c288
·
1 Parent(s): 8acbc10

feat: upgrade to advanced counting system with RT-DETR and proper line crossing

Browse files

- Add RT-DETR model support for dense/crowded scenes
- Implement proper geometric line crossing detection
- Add multi-class detection modes (people, vehicles, animals, sheep)
- Add configurable track buffer and activation threshold
- Increase GPU duration to 180s for longer videos
- Add unique tracks and max simultaneous count metrics

Files changed (1) hide show
  1. app.py +319 -56
app.py CHANGED
@@ -1,3 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import cv2
@@ -5,126 +18,376 @@ import numpy as np
5
  import tempfile
6
  import os
7
  from collections import defaultdict
 
8
 
9
  import supervision as sv
10
- from ultralytics import YOLO
11
 
12
- COCO_CLASSES = {
13
- 0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  }
15
 
16
- MODEL_CACHE = {}
17
 
18
  def get_model(model_name: str):
 
19
  if model_name not in MODEL_CACHE:
20
  model_map = {
21
- "YOLOv8n (Fast)": "yolov8n.pt",
22
- "YOLOv8s (Balanced)": "yolov8s.pt",
 
 
 
23
  }
24
- MODEL_CACHE[model_name] = YOLO(model_map.get(model_name, "yolov8n.pt"))
 
 
 
 
25
  return MODEL_CACHE[model_name]
26
 
27
 
28
- @spaces.GPU(duration=120)
29
- def process_video(video_path, detection_model, confidence, line_position):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  if video_path is None:
31
- return None, "Please upload a video."
32
 
 
33
  model = get_model(detection_model)
 
 
 
 
 
34
  cap = cv2.VideoCapture(video_path)
35
  if not cap.isOpened():
36
- return None, "Failed to open video."
37
 
38
  fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
39
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
40
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
41
 
 
42
  output_path = tempfile.mktemp(suffix=".mp4")
43
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
44
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
45
 
46
- tracker = sv.ByteTrack(track_activation_threshold=0.25, lost_track_buffer=30, minimum_matching_threshold=0.8, frame_rate=fps)
 
 
 
 
 
 
 
 
47
  line_y = int(height * line_position)
48
- line_zone = sv.LineZone(start=sv.Point(0, line_y), end=sv.Point(width, line_y))
 
 
 
49
 
 
50
  box_annotator = sv.BoxAnnotator(thickness=2)
51
- label_annotator = sv.LabelAnnotator(text_scale=0.5, text_thickness=1)
52
- trace_annotator = sv.TraceAnnotator(thickness=2, trace_length=30)
53
- line_annotator = sv.LineZoneAnnotator(thickness=2, text_scale=0.5)
54
 
55
- total_in, total_out, frame_idx = 0, 0, 0
56
- class_counts = defaultdict(lambda: {"in": 0, "out": 0})
 
 
 
 
 
 
 
 
 
57
 
58
  while True:
59
  ret, frame = cap.read()
60
  if not ret:
61
  break
 
 
62
  results = model.predict(frame, conf=confidence, verbose=False)[0]
63
- detections = sv.Detections.from_ultralytics(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  detections = tracker.update_with_detections(detections)
65
- crossed_in, crossed_out = line_zone.trigger(detections)
66
 
67
- if crossed_in.any():
68
- for idx in np.where(crossed_in)[0]:
69
- class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
70
- class_name = COCO_CLASSES.get(class_id, f"class_{class_id}")
71
- class_counts[class_name]["in"] += 1
72
- total_in += 1
73
 
74
- if crossed_out.any():
75
- for idx in np.where(crossed_out)[0]:
 
 
 
76
  class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
77
- class_name = COCO_CLASSES.get(class_id, f"class_{class_id}")
78
- class_counts[class_name]["out"] += 1
79
- total_out += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
 
81
  annotated = frame.copy()
 
 
 
 
 
 
 
82
  annotated = trace_annotator.annotate(annotated, detections)
83
  annotated = box_annotator.annotate(annotated, detections)
 
84
  labels = []
85
- for idx in range(len(detections)):
86
- class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
87
- class_name = COCO_CLASSES.get(class_id, f"class_{class_id}")
88
- track_id = detections.tracker_id[idx] if detections.tracker_id is not None else 0
89
- labels.append(f"{class_name} #{track_id}")
 
90
  annotated = label_annotator.annotate(annotated, detections, labels)
91
- annotated = line_annotator.annotate(annotated, line_zone)
92
- cv2.putText(annotated, f"IN: {total_in} | OUT: {total_out}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
 
 
 
 
 
 
 
 
 
93
  out.write(annotated)
94
  frame_idx += 1
95
 
96
  cap.release()
97
  out.release()
 
 
98
  final_path = tempfile.mktemp(suffix=".mp4")
99
- os.system(f"ffmpeg -y -i {output_path} -c:v libx264 -preset fast -crf 23 {final_path} -loglevel quiet")
100
  if os.path.exists(final_path) and os.path.getsize(final_path) > 0:
101
  os.remove(output_path)
102
  output_path = final_path
103
 
104
- stats = "## Results\n\n"
105
- stats += f"**Entered:** {total_in}\n"
106
- stats += f"**Exited:** {total_out}\n"
107
- stats += f"**Net:** {total_in - total_out}\n\n"
108
- for cls, counts in sorted(class_counts.items()):
109
- stats += f"- {cls}: IN={counts['in']}, OUT={counts['out']}\n"
110
- stats += f"\n**Frames:** {frame_idx}"
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  return output_path, stats
112
 
113
 
114
- with gr.Blocks(analytics_enabled=False) as demo:
115
- gr.Markdown("# CCTV Customer Analytics")
116
- gr.Markdown("Upload a video to detect, track, and count objects crossing a line.")
 
 
 
 
 
 
 
 
 
 
 
 
117
  with gr.Row():
118
- with gr.Column():
119
  video_input = gr.Video(label="Upload Video")
120
- model_dropdown = gr.Dropdown(choices=["YOLOv8n (Fast)", "YOLOv8s (Balanced)"], value="YOLOv8s (Balanced)", label="Model")
121
- confidence_slider = gr.Slider(0.1, 0.9, value=0.3, step=0.05, label="Confidence")
122
- line_slider = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Line Position")
123
- submit_btn = gr.Button("Process Video", variant="primary")
124
- with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  video_output = gr.Video(label="Processed Video")
126
  stats_output = gr.Markdown(label="Statistics")
127
- submit_btn.click(fn=process_video, inputs=[video_input, model_dropdown, confidence_slider, line_slider], outputs=[video_output, stats_output], api_name=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  if __name__ == "__main__":
130
  demo.launch()
 
1
+ """CCTV Customer Analytics - Advanced Object Counting System
2
+
3
+ This Space provides accurate object detection, tracking, and counting
4
+ across a user-defined line. Optimized for counting large numbers of
5
+ animals (sheep, cows) and vehicles in crowded scenes.
6
+
7
+ Key Features:
8
+ - RT-DETR and YOLOv8 model support
9
+ - Optimized ByteTrack for dense scenes
10
+ - Proper geometric line crossing detection
11
+ - Multi-class object support
12
+ """
13
+
14
  import gradio as gr
15
  import spaces
16
  import cv2
 
18
  import tempfile
19
  import os
20
  from collections import defaultdict
21
+ from typing import Dict, List, Tuple, Optional
22
 
23
  import supervision as sv
24
+ from ultralytics import YOLO, RTDETR
25
 
26
+ # Detection modes with COCO class IDs
27
+ DETECTION_MODES = {
28
+ "All Objects (Street)": {
29
+ "class_ids": [0, 1, 2, 3, 5, 7, 17, 18, 19],
30
+ "labels": {0: "person", 1: "bicycle", 2: "car", 3: "motorcycle",
31
+ 5: "bus", 7: "truck", 17: "horse", 18: "sheep", 19: "cow"},
32
+ },
33
+ "People Only": {
34
+ "class_ids": [0],
35
+ "labels": {0: "person"},
36
+ },
37
+ "Vehicles Only": {
38
+ "class_ids": [1, 2, 3, 5, 7],
39
+ "labels": {1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"},
40
+ },
41
+ "Animals (Sheep/Cow/Horse)": {
42
+ "class_ids": [17, 18, 19],
43
+ "labels": {17: "horse", 18: "sheep", 19: "cow"},
44
+ },
45
+ "Sheep Only": {
46
+ "class_ids": [18],
47
+ "labels": {18: "sheep"},
48
+ },
49
  }
50
 
51
+ MODEL_CACHE: Dict[str, object] = {}
52
 
53
  def get_model(model_name: str):
54
+ """Load and cache detection model."""
55
  if model_name not in MODEL_CACHE:
56
  model_map = {
57
+ "YOLOv8n (Fast)": ("yolov8n.pt", "yolo"),
58
+ "YOLOv8s (Balanced)": ("yolov8s.pt", "yolo"),
59
+ "YOLOv8m (Accurate)": ("yolov8m.pt", "yolo"),
60
+ "YOLOv8x (Best YOLO)": ("yolov8x.pt", "yolo"),
61
+ "RT-DETR-L (Dense Scenes)": ("rtdetr-l.pt", "rtdetr"),
62
  }
63
+ model_file, model_type = model_map.get(model_name, ("yolov8s.pt", "yolo"))
64
+ if model_type == "rtdetr":
65
+ MODEL_CACHE[model_name] = RTDETR(model_file)
66
+ else:
67
+ MODEL_CACHE[model_name] = YOLO(model_file)
68
  return MODEL_CACHE[model_name]
69
 
70
 
71
+ def point_side(point: Tuple[float, float], line: Tuple[Tuple[float, float], Tuple[float, float]]) -> float:
72
+ """Return the sign of a point relative to a line using cross product."""
73
+ (x1, y1), (x2, y2) = line
74
+ x, y = point
75
+ return (x - x1) * (y2 - y1) - (y - y1) * (x2 - x1)
76
+
77
+
78
+ def crossed_line(prev_point: Tuple[float, float], curr_point: Tuple[float, float],
79
+ line: Tuple[Tuple[float, float], Tuple[float, float]]) -> bool:
80
+ """Check if movement from prev_point to curr_point crosses the line."""
81
+ prev_side = point_side(prev_point, line)
82
+ curr_side = point_side(curr_point, line)
83
+ return prev_side * curr_side < 0
84
+
85
+
86
+ def bbox_center(bbox: Tuple[int, int, int, int]) -> Tuple[float, float]:
87
+ """Get center point of bounding box."""
88
+ x1, y1, x2, y2 = bbox
89
+ return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
90
+
91
+
92
+ def determine_outside_side(line: Tuple[Tuple[float, float], Tuple[float, float]],
93
+ frame_height: int) -> float:
94
+ """Determine which side of the line is 'outside' based on line position."""
95
+ (x1, y1), (x2, y2) = line
96
+ mid_y = (y1 + y2) / 2.0
97
+ mid_x = (x1 + x2) / 2.0
98
+ # If line is in upper half, outside is above (y=0)
99
+ # If line is in lower half, outside is below (y=height)
100
+ if mid_y < frame_height / 2.0:
101
+ reference_point = (mid_x, 0.0)
102
+ else:
103
+ reference_point = (mid_x, float(frame_height))
104
+ return point_side(reference_point, line)
105
+
106
+
107
+ @spaces.GPU(duration=180)
108
+ def process_video(
109
+ video_path: str,
110
+ detection_model: str,
111
+ detection_mode: str,
112
+ confidence: float,
113
+ line_position: float,
114
+ track_buffer: int,
115
+ activation_threshold: float,
116
+ ):
117
+ """Process video with advanced tracking and counting."""
118
  if video_path is None:
119
+ return None, "Please upload a video file."
120
 
121
+ # Get model and detection config
122
  model = get_model(detection_model)
123
+ mode_config = DETECTION_MODES.get(detection_mode, DETECTION_MODES["All Objects (Street)"])
124
+ target_class_ids = set(mode_config["class_ids"])
125
+ class_labels = mode_config["labels"]
126
+
127
+ # Open video
128
  cap = cv2.VideoCapture(video_path)
129
  if not cap.isOpened():
130
+ return None, "Failed to open video file."
131
 
132
  fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
133
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
134
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
135
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
136
 
137
+ # Setup output video
138
  output_path = tempfile.mktemp(suffix=".mp4")
139
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
140
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
141
 
142
+ # Initialize tracker with optimized parameters for dense scenes
143
+ tracker = sv.ByteTrack(
144
+ track_activation_threshold=activation_threshold,
145
+ lost_track_buffer=track_buffer,
146
+ minimum_matching_threshold=0.7,
147
+ frame_rate=fps,
148
+ )
149
+
150
+ # Setup counting line (absolute coordinates)
151
  line_y = int(height * line_position)
152
+ line_start = (0, line_y)
153
+ line_end = (width, line_y)
154
+ abs_line = ((0.0, float(line_y)), (float(width), float(line_y)))
155
+ outside_side = determine_outside_side(abs_line, height)
156
 
157
+ # Annotators
158
  box_annotator = sv.BoxAnnotator(thickness=2)
159
+ label_annotator = sv.LabelAnnotator(text_scale=0.4, text_thickness=1)
160
+ trace_annotator = sv.TraceAnnotator(thickness=1, trace_length=50)
 
161
 
162
+ # Tracking state
163
+ track_last_center: Dict[int, Tuple[float, float]] = {}
164
+ track_class: Dict[int, str] = {}
165
+ counted_tracks: set = set()
166
+
167
+ # Counters
168
+ total_in, total_out = 0, 0
169
+ class_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: {"in": 0, "out": 0})
170
+
171
+ frame_idx = 0
172
+ max_simultaneous = 0
173
 
174
  while True:
175
  ret, frame = cap.read()
176
  if not ret:
177
  break
178
+
179
+ # Run detection
180
  results = model.predict(frame, conf=confidence, verbose=False)[0]
181
+
182
+ # Filter detections by target classes
183
+ boxes = results.boxes
184
+ if boxes is not None and len(boxes) > 0:
185
+ mask = np.array([int(cls) in target_class_ids for cls in boxes.cls])
186
+ if mask.any():
187
+ filtered_boxes = boxes[mask]
188
+ detections = sv.Detections(
189
+ xyxy=filtered_boxes.xyxy.cpu().numpy(),
190
+ confidence=filtered_boxes.conf.cpu().numpy(),
191
+ class_id=filtered_boxes.cls.cpu().numpy().astype(int),
192
+ )
193
+ else:
194
+ detections = sv.Detections.empty()
195
+ else:
196
+ detections = sv.Detections.empty()
197
+
198
+ # Track objects
199
  detections = tracker.update_with_detections(detections)
 
200
 
201
+ # Update max simultaneous count
202
+ if len(detections) > max_simultaneous:
203
+ max_simultaneous = len(detections)
 
 
 
204
 
205
+ # Check line crossings with proper geometry
206
+ if detections.tracker_id is not None:
207
+ for idx in range(len(detections)):
208
+ track_id = int(detections.tracker_id[idx])
209
+ x1, y1, x2, y2 = detections.xyxy[idx]
210
  class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
211
+ class_name = class_labels.get(class_id, f"class_{class_id}")
212
+
213
+ current_center = bbox_center((int(x1), int(y1), int(x2), int(y2)))
214
+ track_class[track_id] = class_name
215
+
216
+ if track_id in track_last_center and track_id not in counted_tracks:
217
+ prev_center = track_last_center[track_id]
218
+
219
+ if crossed_line(prev_center, current_center, abs_line):
220
+ prev_side = point_side(prev_center, abs_line)
221
+ curr_side = point_side(current_center, abs_line)
222
+
223
+ # Determine direction based on which side is "outside"
224
+ if prev_side * outside_side >= 0 and curr_side * outside_side < 0:
225
+ total_in += 1
226
+ class_counts[class_name]["in"] += 1
227
+ elif prev_side * outside_side < 0 and curr_side * outside_side >= 0:
228
+ total_out += 1
229
+ class_counts[class_name]["out"] += 1
230
+
231
+ counted_tracks.add(track_id)
232
+
233
+ track_last_center[track_id] = current_center
234
 
235
+ # Annotate frame
236
  annotated = frame.copy()
237
+
238
+ # Draw counting line
239
+ cv2.line(annotated, line_start, line_end, (0, 0, 255), 3)
240
+ cv2.putText(annotated, "COUNTING LINE", (10, line_y - 10),
241
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
242
+
243
+ # Draw traces, boxes, and labels
244
  annotated = trace_annotator.annotate(annotated, detections)
245
  annotated = box_annotator.annotate(annotated, detections)
246
+
247
  labels = []
248
+ if detections.tracker_id is not None:
249
+ for idx in range(len(detections)):
250
+ class_id = int(detections.class_id[idx]) if detections.class_id is not None else 0
251
+ class_name = class_labels.get(class_id, f"class_{class_id}")
252
+ track_id = int(detections.tracker_id[idx])
253
+ labels.append(f"{class_name} #{track_id}")
254
  annotated = label_annotator.annotate(annotated, detections, labels)
255
+
256
+ # Draw stats overlay
257
+ overlay_h = 80
258
+ cv2.rectangle(annotated, (5, 5), (300, overlay_h), (0, 0, 0), -1)
259
+ cv2.putText(annotated, f"IN: {total_in} | OUT: {total_out}", (15, 30),
260
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
261
+ cv2.putText(annotated, f"Net: {total_in - total_out} | Now: {len(detections)}", (15, 55),
262
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
263
+ cv2.putText(annotated, f"Frame: {frame_idx}/{total_frames}", (15, 75),
264
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
265
+
266
  out.write(annotated)
267
  frame_idx += 1
268
 
269
  cap.release()
270
  out.release()
271
+
272
+ # Convert to H.264 for browser compatibility
273
  final_path = tempfile.mktemp(suffix=".mp4")
274
+ os.system(f'ffmpeg -y -i {output_path} -c:v libx264 -preset fast -crf 23 {final_path} -loglevel quiet')
275
  if os.path.exists(final_path) and os.path.getsize(final_path) > 0:
276
  os.remove(output_path)
277
  output_path = final_path
278
 
279
+ # Generate statistics report
280
+ unique_tracks = len(track_last_center)
281
+ stats = "## Counting Results\n\n"
282
+ stats += f"**Total Entered:** {total_in}\n"
283
+ stats += f"**Total Exited:** {total_out}\n"
284
+ stats += f"**Net Count:** {total_in - total_out}\n"
285
+ stats += f"**Unique Tracks:** {unique_tracks}\n"
286
+ stats += f"**Max Simultaneous:** {max_simultaneous}\n\n"
287
+
288
+ if class_counts:
289
+ stats += "### By Class\n"
290
+ for cls, counts in sorted(class_counts.items()):
291
+ net = counts['in'] - counts['out']
292
+ stats += f"- **{cls}**: IN={counts['in']}, OUT={counts['out']}, Net={net}\n"
293
+
294
+ stats += f"\n### Video Info\n"
295
+ stats += f"- Frames: {frame_idx}\n"
296
+ stats += f"- Resolution: {width}x{height}\n"
297
+ stats += f"- FPS: {fps}\n"
298
+
299
  return output_path, stats
300
 
301
 
302
+ # Build Gradio interface
303
+ with gr.Blocks(analytics_enabled=False, title="CCTV Customer Analytics") as demo:
304
+ gr.Markdown("""
305
+ # CCTV Customer Analytics
306
+
307
+ Advanced object detection, tracking, and counting system.
308
+ Optimized for counting large numbers of animals and vehicles in crowded scenes.
309
+
310
+ **Tips for best results:**
311
+ - Use **RT-DETR** model for dense/crowded scenes (sheep flocks, traffic)
312
+ - Lower **confidence** (0.15-0.25) to detect more objects
313
+ - Increase **track buffer** (60-90) for objects that temporarily disappear
314
+ - Adjust **line position** to where objects cross most clearly
315
+ """)
316
+
317
  with gr.Row():
318
+ with gr.Column(scale=1):
319
  video_input = gr.Video(label="Upload Video")
320
+
321
+ model_dropdown = gr.Dropdown(
322
+ choices=[
323
+ "YOLOv8n (Fast)",
324
+ "YOLOv8s (Balanced)",
325
+ "YOLOv8m (Accurate)",
326
+ "YOLOv8x (Best YOLO)",
327
+ "RT-DETR-L (Dense Scenes)",
328
+ ],
329
+ value="YOLOv8s (Balanced)",
330
+ label="Detection Model",
331
+ )
332
+
333
+ mode_dropdown = gr.Dropdown(
334
+ choices=list(DETECTION_MODES.keys()),
335
+ value="All Objects (Street)",
336
+ label="Detection Mode",
337
+ )
338
+
339
+ confidence_slider = gr.Slider(
340
+ 0.05, 0.9, value=0.25, step=0.05,
341
+ label="Confidence Threshold",
342
+ info="Lower = more detections, higher = fewer false positives"
343
+ )
344
+
345
+ line_slider = gr.Slider(
346
+ 0.1, 0.9, value=0.5, step=0.05,
347
+ label="Line Position",
348
+ info="Vertical position of counting line (0=top, 1=bottom)"
349
+ )
350
+
351
+ with gr.Accordion("Advanced Tracking Settings", open=False):
352
+ track_buffer = gr.Slider(
353
+ 10, 120, value=45, step=5,
354
+ label="Track Buffer",
355
+ info="Frames to keep lost tracks (higher for crowded scenes)"
356
+ )
357
+
358
+ activation_threshold = gr.Slider(
359
+ 0.1, 0.5, value=0.2, step=0.05,
360
+ label="Track Activation Threshold",
361
+ info="Lower = easier to start new tracks"
362
+ )
363
+
364
+ submit_btn = gr.Button("Process Video", variant="primary", size="lg")
365
+
366
+ with gr.Column(scale=1):
367
  video_output = gr.Video(label="Processed Video")
368
  stats_output = gr.Markdown(label="Statistics")
369
+
370
+ submit_btn.click(
371
+ fn=process_video,
372
+ inputs=[
373
+ video_input, model_dropdown, mode_dropdown,
374
+ confidence_slider, line_slider, track_buffer, activation_threshold
375
+ ],
376
+ outputs=[video_output, stats_output],
377
+ api_name=False,
378
+ )
379
+
380
+ gr.Markdown("""
381
+ ---
382
+ **Models:**
383
+ - **YOLOv8n/s/m/x**: General purpose, good for most scenarios
384
+ - **RT-DETR-L**: Transformer-based, better for dense/crowded scenes (recommended for sheep counting)
385
+
386
+ **Detection Modes:**
387
+ - **All Objects**: People + vehicles + animals
388
+ - **Animals**: Sheep, cows, horses
389
+ - **Sheep Only**: Optimized for sheep counting
390
+ """)
391
 
392
  if __name__ == "__main__":
393
  demo.launch()