""" Professional WebRTC handler for real-time video streaming and movement analysis Using FastRTC (the current WebRTC standard, replaces deprecated gradio-webrtc) Based on: https://fastrtc.org and https://www.gradio.app/guides/object-detection-from-webcam-with-webrtc """ import cv2 import numpy as np from typing import Optional, Dict, Any, Tuple from collections import deque import time import logging import os from .pose_estimation import get_pose_estimator from .notation_engine import MovementAnalyzer from .visualizer import PoseVisualizer logger = logging.getLogger(__name__) # Official Gradio WebRTC approach (compatible with NumPy 1.x) try: from gradio_webrtc import WebRTC HAS_WEBRTC_COMPONENT = True except ImportError: HAS_WEBRTC_COMPONENT = False class RealtimeMovementAnalyzer: """Real-time movement analyzer for WebRTC streams following Gradio 5 best practices""" # Gradio component compatibility events = {} def __init__(self, model: str = "mediapipe-lite", buffer_size: int = 30): """ Initialize real-time movement analyzer. Args: model: Pose estimation model optimized for real-time processing buffer_size: Number of frames to buffer for analysis """ self.model = model self.pose_estimator = get_pose_estimator(model) self.movement_analyzer = MovementAnalyzer(fps=30.0) self.visualizer = PoseVisualizer( trail_length=10, show_skeleton=True, show_trails=True, show_direction_arrows=True, show_metrics=True ) # Real-time buffers self.pose_buffer = deque(maxlen=buffer_size) self.metrics_buffer = deque(maxlen=buffer_size) # Performance tracking self.frame_count = 0 self.last_fps_update = time.time() self.current_fps = 0.0 # Current metrics for display self.current_metrics = { "direction": "stationary", "intensity": "low", "fluidity": 0.0, "expansion": 0.5, "fps": 0.0 } def process_frame(self, image: np.ndarray, conf_threshold: float = 0.5) -> np.ndarray: """ Process a single frame from WebRTC stream for real-time movement analysis. Args: image: Input frame from webcam as numpy array (RGB format from WebRTC) conf_threshold: Confidence threshold for pose detection Returns: Processed frame with pose overlay and movement metrics """ if image is None: return None # Convert RGB to BGR for OpenCV processing frame_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Update frame count and FPS self.frame_count += 1 current_time = time.time() if current_time - self.last_fps_update >= 1.0: self.current_fps = self.frame_count / (current_time - self.last_fps_update) self.frame_count = 0 self.last_fps_update = current_time self.current_metrics["fps"] = self.current_fps # Pose detection pose_results = self.pose_estimator.detect(frame_bgr) # Store pose data self.pose_buffer.append(pose_results) # Calculate movement metrics if we have enough frames if len(self.pose_buffer) >= 2: recent_poses = list(self.pose_buffer)[-10:] # Last 10 frames for analysis try: # Analyze movement from recent poses movement_metrics = self.movement_analyzer.analyze_movement(recent_poses) if movement_metrics: latest_metrics = movement_metrics[-1] self.current_metrics.update({ "direction": latest_metrics.direction.value if latest_metrics.direction else "stationary", "intensity": latest_metrics.intensity.value if latest_metrics.intensity else "low", "fluidity": latest_metrics.fluidity if latest_metrics.fluidity is not None else 0.0, "expansion": latest_metrics.expansion if latest_metrics.expansion is not None else 0.5 }) self.metrics_buffer.append(self.current_metrics.copy()) except Exception as e: logger.warning(f"Movement analysis error: {e}") # Apply visualization overlays output_frame = self._apply_visualization(frame_bgr, pose_results, self.current_metrics) # Convert back to RGB for WebRTC output output_rgb = cv2.cvtColor(output_frame, cv2.COLOR_BGR2RGB) return output_rgb def _apply_visualization(self, frame: np.ndarray, pose_results: list, metrics: dict) -> np.ndarray: """Apply pose and movement visualization overlays""" output_frame = frame.copy() # Draw pose skeleton if detected if pose_results: for pose_result in pose_results: # Draw skeleton if hasattr(self.visualizer, 'draw_skeleton'): output_frame = self.visualizer.draw_skeleton(output_frame, pose_result.keypoints) # Draw keypoints for keypoint in pose_result.keypoints: if keypoint.confidence > 0.5: x = int(keypoint.x * frame.shape[1]) y = int(keypoint.y * frame.shape[0]) cv2.circle(output_frame, (x, y), 5, (0, 255, 0), -1) # Draw real-time metrics overlay self._draw_metrics_overlay(output_frame, metrics) return output_frame def _draw_metrics_overlay(self, frame: np.ndarray, metrics: dict): """Draw real-time metrics overlay following professional UI standards""" h, w = frame.shape[:2] # Semi-transparent background overlay = frame.copy() cv2.rectangle(overlay, (10, 10), (320, 160), (0, 0, 0), -1) cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame) # Header cv2.putText(frame, "Real-time Movement Analysis", (20, 35), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) # Metrics y_offset = 60 spacing = 22 cv2.putText(frame, f"Direction: {metrics['direction']}", (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) y_offset += spacing cv2.putText(frame, f"Intensity: {metrics['intensity']}", (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) y_offset += spacing cv2.putText(frame, f"Fluidity: {metrics['fluidity']:.2f}", (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1) y_offset += spacing cv2.putText(frame, f"FPS: {metrics['fps']:.1f}", (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1) def get_current_metrics(self) -> dict: """Get current movement metrics for external display""" return self.current_metrics.copy() def get_rtc_configuration(): """ Get RTC configuration for WebRTC. Uses Twilio TURN servers if credentials are available, otherwise uses default. """ # For local development, no TURN servers needed # For cloud deployment, set TWILIO_ACCOUNT_SID and TWILIO_AUTH_TOKEN twilio_account_sid = os.getenv("TWILIO_ACCOUNT_SID") twilio_auth_token = os.getenv("TWILIO_AUTH_TOKEN") if twilio_account_sid and twilio_auth_token: # Use Twilio TURN servers for cloud deployment return { "iceServers": [ {"urls": ["stun:global.stun.twilio.com:3478"]}, { "urls": ["turn:global.turn.twilio.com:3478?transport=udp"], "username": twilio_account_sid, "credential": twilio_auth_token, }, { "urls": ["turn:global.turn.twilio.com:3478?transport=tcp"], "username": twilio_account_sid, "credential": twilio_auth_token, }, ] } else: # Default configuration for local development return { "iceServers": [ {"urls": ["stun:stun.l.google.com:19302"]} ] } # Global analyzer instance for demo _analyzer = None def get_analyzer(model: str = "mediapipe-lite") -> RealtimeMovementAnalyzer: """Get or create analyzer instance""" global _analyzer if _analyzer is None or _analyzer.model != model: _analyzer = RealtimeMovementAnalyzer(model) return _analyzer def webrtc_detection(image: np.ndarray, model: str, conf_threshold: float = 0.5) -> np.ndarray: """ Main detection function for WebRTC streaming. Compatible with Gradio 5 WebRTC streaming API. Args: image: Input frame from webcam (RGB format) model: Pose estimation model name conf_threshold: Confidence threshold for pose detection Returns: Processed frame with pose overlay and metrics """ analyzer = get_analyzer(model) return analyzer.process_frame(image, conf_threshold) def get_webrtc_interface(): """ Create streaming interface using built-in Gradio components. Avoids NumPy 2.x dependency conflicts with FastRTC. Returns: Tuple of (streaming_config, rtc_configuration) """ rtc_config = get_rtc_configuration() # Use built-in Gradio streaming capabilities streaming_config = { "sources": ["webcam"], "streaming": True, "mirror_webcam": False } return streaming_config, rtc_config # Compatibility exports with Gradio component attributes class WebRTCMovementAnalyzer(RealtimeMovementAnalyzer): """Real-time movement analyzer for WebRTC streams following Gradio 5 best practices""" events = {} # Gradio component compatibility class WebRTCGradioInterface: """Create streaming interface using built-in Gradio components. Avoids NumPy 2.x dependency conflicts with FastRTC.""" events = {} # Gradio component compatibility @staticmethod def get_config(): return get_webrtc_interface()