| |
| """ |
| SAM2 Loader with Hugging Face Hub integration |
| Provides SAM2Predictor class with memory management and optimization features |
| Updated to use Hugging Face Hub models instead of direct downloads |
| (Enhanced logging and exception safety) |
| """ |
|
|
| import os |
| import gc |
| import torch |
| import logging |
| import numpy as np |
| from pathlib import Path |
| from typing import Optional, Any, Dict, List, Tuple |
|
|
# Logger shared by everything in this module.
logger = logging.getLogger(name=__name__)

# NOTE(review): basicConfig configures the *root* logger as an import-time side
# effect (it is a no-op if the root logger already has handlers). Applications
# that manage their own logging may prefer to move this call to their entry point.
logging.basicConfig(level=logging.INFO)
|
|
class SAM2Predictor:
    """
    T4-optimized wrapper around the SAM2 video predictor, with memory management.

    On construction the checkpoint is fetched from the Hugging Face Hub (falling
    back to ``./checkpoints/`` relative to the working directory), the predictor
    is built with ``sam2.build_sam.build_sam2_video_predictor`` and fp16 +
    channels_last is applied on CUDA. The remaining methods forward prompting and
    propagation calls to the underlying predictor, adding logging and error
    reporting.
    """

    # Hugging Face Hub repository for each supported model size.
    _HF_REPOS: Dict[str, str] = {
        "small": "facebook/sam2-hiera-small",
        "base": "facebook/sam2-hiera-base-plus",
        "large": "facebook/sam2-hiera-large",
    }
    # Checkpoint filename per model size. It is shared by the Hub download and
    # the local fallback so that the two lookups cannot drift apart.
    _CHECKPOINT_FILES: Dict[str, str] = {
        "small": "sam2_hiera_small.pt",
        "base": "sam2_hiera_base_plus.pt",
        "large": "sam2_hiera_large.pt",
    }
    # Config name passed to build_sam2_video_predictor for each model size.
    _MODEL_CONFIGS: Dict[str, str] = {
        "small": "sam2_hiera_s.yaml",
        "base": "sam2_hiera_b+.yaml",
        "large": "sam2_hiera_l.yaml",
    }
    # Config used when model_size is not one of the keys above.
    _DEFAULT_CONFIG = "sam2_hiera_s.yaml"
    # Directory searched when the Hub download fails (relative to the CWD).
    _LOCAL_CHECKPOINT_DIR = "./checkpoints"

    def __init__(self, device: torch.device, model_size: str = "small"):
        """
        Create the wrapper and eagerly load the SAM2 predictor.

        Args:
            device: Target device. A device string such as ``"cuda"`` or
                ``"cuda:0"`` is also accepted and converted to ``torch.device``.
            model_size: One of ``"small"``, ``"base"`` or ``"large"``.

        Raises:
            RuntimeError: If ``sam2`` cannot be imported or no checkpoint is
                found. Other errors raised while building the predictor
                propagate unchanged.
        """
        logger.info(f"[SAM2Predictor.__init__] device={device}, model_size={model_size}")
        # Later methods read ``self.device.type``, which a plain string lacks.
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.model_size = model_size
        self.predictor = None
        self.model = None
        self._load_predictor()

    def _load_predictor(self):
        """
        Resolve the checkpoint and config, build the predictor, then optimize it.

        Raises:
            RuntimeError: If ``sam2`` cannot be imported or no checkpoint is found.
        """
        try:
            logger.info("[SAM2Predictor._load_predictor] Loading SAM2 predictor...")
            from sam2.build_sam import build_sam2_video_predictor

            checkpoint_path = self._get_hf_checkpoint()
            if not checkpoint_path:
                # The generic handler below logs this error (with traceback) before re-raising.
                raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")

            model_cfg = self._get_model_config()
            logger.info(f"[SAM2Predictor._load_predictor] Using model_cfg: {model_cfg}")

            self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
            self._optimize_for_t4()
            logger.info(f"SAM2 {self.model_size} predictor loaded successfully from HF Hub")
        except ImportError as e:
            logger.error(f"SAM2 import failed: {e}")
            raise RuntimeError("SAM2 not available - check sam2 installation") from e
        except Exception as e:
            logger.error(f"SAM2 loading failed: {e}", exc_info=True)
            raise

    def _get_hf_checkpoint(self) -> Optional[str]:
        """
        Download the checkpoint from the Hugging Face Hub, or reuse the cached copy.

        Returns:
            The local path of the checkpoint. If the download (or the
            ``huggingface_hub`` import) fails, the result of
            :meth:`_fallback_local_checkpoint` is returned instead. For an
            unknown model size the result is ``None``.
        """
        try:
            logger.info("[SAM2Predictor._get_hf_checkpoint] Downloading checkpoint...")
            from huggingface_hub import hf_hub_download

            repo_id = self._HF_REPOS.get(self.model_size)
            if repo_id is None:
                logger.error(f"Unknown model size: {self.model_size}")
                return None
            filename = self._CHECKPOINT_FILES[self.model_size]
            logger.info(f"Downloading SAM2 {self.model_size} from HF Hub: {repo_id}")
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=None,
                force_download=False,
                token=None,
            )
            logger.info(f"SAM2 checkpoint downloaded to: {checkpoint_path}")
            return checkpoint_path
        except Exception as e:
            logger.error(f"HF Hub download failed: {e}")
            return self._fallback_local_checkpoint()

    def _fallback_local_checkpoint(self) -> Optional[str]:
        """
        Look for the checkpoint under ``./checkpoints/`` (relative to the CWD).

        The Hub filename is tried first. The legacy ``sam2_hiera_<model_size>.pt``
        name is also accepted so that existing local setups keep working.

        Returns:
            The first candidate path that exists, or ``None`` if none does.
        """
        try:
            filenames = [
                name
                for name in (
                    self._CHECKPOINT_FILES.get(self.model_size),
                    f"sam2_hiera_{self.model_size}.pt",
                )
                if name
            ]
            # dict.fromkeys removes duplicates while preserving order.
            candidates = [f"{self._LOCAL_CHECKPOINT_DIR}/{name}" for name in dict.fromkeys(filenames)]
            for checkpoint_path in candidates:
                if Path(checkpoint_path).exists():
                    logger.info(f"Using local checkpoint: {checkpoint_path}")
                    return checkpoint_path
            logger.error(f"Local checkpoint not found: {', '.join(candidates)}")
            return None
        except Exception as e:
            logger.error(f"Local checkpoint fallback failed: {e}")
            return None

    def _get_model_config(self) -> str:
        """Return the config name for ``self.model_size``, defaulting to the small config."""
        cfg = self._MODEL_CONFIGS.get(self.model_size)
        if cfg is None:
            logger.warning(
                f"Unknown model size {self.model_size!r}; defaulting to {self._DEFAULT_CONFIG}"
            )
            cfg = self._DEFAULT_CONFIG
        logger.info(f"[SAM2Predictor._get_model_config] Returning config: {cfg}")
        return cfg

    def _optimize_for_t4(self):
        """
        Cast ``predictor.model`` to fp16 and use the channels_last layout (CUDA only).

        This is best effort: failures are logged as warnings and otherwise
        ignored. On non-CUDA devices nothing is changed, because fp16 support
        there is limited. ``self.model`` is set only when the predictor exposes
        a non-None ``model`` attribute.
        """
        try:
            logger.info("[SAM2Predictor._optimize_for_t4] Optimizing for T4...")
            if self.device.type != "cuda":
                logger.info("SAM2: skipping fp16/channels_last optimization on non-CUDA device")
                return
            model = getattr(self.predictor, "model", None)
            if model is None:
                logger.info("SAM2: predictor exposes no 'model' attribute; no optimization applied")
                return
            self.model = model
            # nn.Module.half() converts in place, so predictor.model is cast as well.
            # NOTE(review): confirm that the inputs sam2 feeds the model are fp16
            # (or run under autocast); otherwise layers may raise a dtype mismatch.
            self.model = self.model.half().to(self.device)
            self.model = self.model.to(memory_format=torch.channels_last)
            logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
        except Exception as e:
            logger.warning(f"SAM2 T4 optimization warning: {e}", exc_info=True)

    def _require_predictor(self, caller: str):
        """Log and raise ``RuntimeError`` if the predictor was never loaded."""
        if self.predictor is None:
            logger.error(f"Predictor not loaded in {caller}")
            raise RuntimeError("Predictor not loaded")

    def init_state(self, video_path: str):
        """
        Create a new inference state for ``video_path``.

        Returns:
            Whatever ``predictor.init_state`` returns.

        Raises:
            RuntimeError: If the predictor is not loaded. Errors from the
                predictor are logged and re-raised.
        """
        logger.info(f"[SAM2Predictor.init_state] Initializing video state for: {video_path}")
        self._require_predictor("init_state")
        try:
            state = self.predictor.init_state(video_path=video_path)
            logger.info("[SAM2Predictor.init_state] Video state initialized OK")
            return state
        except Exception as e:
            logger.error(f"Failed to initialize video state: {e}", exc_info=True)
            raise

    def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
                       points: np.ndarray, labels: np.ndarray):
        """
        Forward point prompts for ``obj_id`` on ``frame_idx`` to ``predictor.add_new_points``.

        Raises:
            RuntimeError: If the predictor is not loaded. Errors from the
                predictor are logged and re-raised.
        """
        logger.info(f"[SAM2Predictor.add_new_points] Adding points for frame {frame_idx}, obj {obj_id}")
        self._require_predictor("add_new_points")
        try:
            out = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
            )
            logger.info("[SAM2Predictor.add_new_points] Points added OK")
            return out
        except Exception as e:
            logger.error(f"Failed to add new points: {e}", exc_info=True)
            raise

    def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
                              points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
        """
        Forward point prompts, preferring ``predictor.add_new_points_or_box``.

        If the predictor lacks that method, ``add_new_points`` is used instead.
        ``clear_old_points`` is forwarded on both paths so that the caller's
        choice is honoured.
        NOTE(review): this assumes the older ``add_new_points`` accepts
        ``clear_old_points``; confirm this for the pinned sam2 version.

        Raises:
            RuntimeError: If the predictor is not loaded. Errors from the
                predictor are logged and re-raised.
        """
        logger.info(f"[SAM2Predictor.add_new_points_or_box] Adding points/box for frame {frame_idx}, obj {obj_id}")
        self._require_predictor("add_new_points_or_box")
        try:
            if hasattr(self.predictor, 'add_new_points_or_box'):
                out = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    points=points,
                    labels=labels,
                    clear_old_points=clear_old_points,
                )
                logger.info("[SAM2Predictor.add_new_points_or_box] Used new API, points/box added OK")
                return out
            out = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                clear_old_points=clear_old_points,
            )
            logger.info("[SAM2Predictor.add_new_points_or_box] Used fallback, points added OK")
            return out
        except Exception as e:
            logger.error(f"Failed to add new points or box: {e}", exc_info=True)
            raise

    def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
        """
        Forward to ``predictor.propagate_in_video(inference_state, **kwargs)``.

        Args:
            inference_state: State returned by :meth:`init_state`.
            scale: Accepted for caller compatibility; currently unused.
            **kwargs: Passed through unchanged.

        Returns:
            Whatever the predictor returns. NOTE(review): if this is a
            generator, errors raised while iterating it are not caught or
            logged here.
        """
        logger.info("[SAM2Predictor.propagate_in_video] Propagating in video...")
        self._require_predictor("propagate_in_video")
        try:
            out = self.predictor.propagate_in_video(inference_state, **kwargs)
            logger.info("[SAM2Predictor.propagate_in_video] Propagation OK")
            return out
        except Exception as e:
            logger.error(f"Failed to propagate in video: {e}", exc_info=True)
            raise

    @staticmethod
    def _state_field(inference_state, key: str):
        """Fetch ``key`` from dict-style or attribute-style state; ``None`` if absent."""
        if isinstance(inference_state, dict):
            return inference_state.get(key)
        return getattr(inference_state, key, None)

    def prune_state(self, inference_state, keep: int):
        """
        Trim per-frame data held in ``inference_state`` (best effort).

        Keeps the ``keep`` most recently inserted entries of ``cached_features``
        and, for each object, the ``keep`` highest frame indices in
        ``point_inputs_per_obj``. On CUDA it then releases cached memory. Both
        dict-style state and attribute-style state are supported.
        NOTE(review): sam2's ``init_state`` is believed to return a plain dict;
        confirm this against the installed version. Containers are modified in
        place. Failures are logged at DEBUG and swallowed.

        Args:
            inference_state: State returned by :meth:`init_state`.
            keep: Number of entries to retain; values <= 0 drop everything.
        """
        logger.info(f"[SAM2Predictor.prune_state] Pruning state to keep {keep} frames...")
        try:
            keep = max(int(keep), 0)

            cached = self._state_field(inference_state, 'cached_features')
            if cached is not None:
                cached_keys = list(cached.keys())
                # Slice by count: ``[:-keep]`` would be empty when keep == 0.
                keys_to_remove = cached_keys[:max(len(cached_keys) - keep, 0)]
                for key in keys_to_remove:
                    cached.pop(key, None)
                if keys_to_remove:
                    logger.debug(f"Pruned {len(keys_to_remove)} old cached features")

            point_inputs = self._state_field(inference_state, 'point_inputs_per_obj')
            if point_inputs is not None:
                for obj_inputs in point_inputs.values():
                    if len(obj_inputs) > keep:
                        stale = sorted(obj_inputs.keys())[:len(obj_inputs) - keep]
                        for frame_idx in stale:
                            del obj_inputs[frame_idx]

            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
        except Exception as e:
            logger.debug(f"State pruning warning: {e}", exc_info=True)

    def clear_memory(self):
        """Run Python GC and, on CUDA, release cached allocator memory (best effort)."""
        logger.info("[SAM2Predictor.clear_memory] Clearing GPU memory")
        try:
            # Collect unreachable objects first, so that the tensors they hold
            # are freed before empty_cache() returns unused blocks to the driver.
            gc.collect()
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        except Exception as e:
            logger.warning(f"Memory clearing warning: {e}", exc_info=True)

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Report CUDA memory for ``self.device`` in GiB.

        Returns:
            A dict with keys ``allocated_gb``, ``reserved_gb``, ``free_gb`` and
            ``total_gb``. All values are 0.0 on non-CUDA devices or on error,
            so the same keys are always present.
        """
        logger.info("[SAM2Predictor.get_memory_usage] Checking memory usage")
        empty = {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0, "total_gb": 0.0}
        if self.device.type != 'cuda':
            return empty
        try:
            gib = 1024 ** 3
            allocated = torch.cuda.memory_allocated(self.device) / gib
            reserved = torch.cuda.memory_reserved(self.device) / gib
            free, total = torch.cuda.mem_get_info(self.device)
            return {
                "allocated_gb": allocated,
                "reserved_gb": reserved,
                "free_gb": free / gib,
                "total_gb": total / gib,
            }
        except Exception as e:
            logger.warning(f"Error checking memory usage: {e}", exc_info=True)
            return empty

    def __del__(self):
        """Drop references to the predictor and model, then clear memory (best effort)."""
        try:
            # Logging is inside the try because module globals may already be
            # torn down at interpreter shutdown.
            logger.info("[SAM2Predictor.__del__] Cleaning up...")
            self.predictor = None
            self.model = None
            self.clear_memory()
        except Exception as e:
            logger.warning(f"Error in __del__: {e}", exc_info=True)
|
|