# handler.py
import os
import io
import base64
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForVision2Seq,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

# --- Env defaults ---
os.environ.setdefault("HF_TRUST_REMOTE_CODE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
USE_4BIT = os.environ.get("USE_4BIT", "1") not in {"0", "false", "False"}
# NOTE(review): torch is imported above this line; CUDA_VISIBLE_DEVICES only
# takes effect if CUDA has not been initialized yet -- confirm ordering.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

# mrope-related keys (some model branches index these keys directly on the
# kwargs passed to generate() instead of reading them from generation_config).
MROPE_KEYS = ("mrope_section", "mrope_section_output", "mrope_theta", "mrope_t")


def _b64_to_pil(b64str: str) -> Image.Image:
    """Decode a base64 string into an RGB PIL image."""
    return Image.open(io.BytesIO(base64.b64decode(b64str))).convert("RGB")


def _normalize_images(img_field: Union[str, List[str], None]) -> List[Image.Image]:
    """Normalize the ``image_b64`` payload field into a list of PIL images.

    Accepts ``None`` (no images), a single base64 string, or a list of
    base64 strings. Raises ``ValueError`` for any other type.
    """
    if img_field is None:
        return []
    if isinstance(img_field, str):
        return [_b64_to_pil(img_field)]
    if isinstance(img_field, list):
        return [_b64_to_pil(s) for s in img_field]
    raise ValueError("image_b64 must be a base64 string or a list of base64 strings.")


def _to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Move every tensor-like value (anything with ``.to``) to ``device``."""
    out: Dict[str, Any] = {}
    for k, v in batch.items():
        out[k] = v.to(device) if hasattr(v, "to") else v
    return out


class EndpointHandler:
    """Inference endpoint handler supporting two input shapes.

    1) Inference API shape: ``{"inputs": {...}, "parameters": {...}}``
    2) Flat shape: ``{"prompt": "...", "image_b64": "...|[...]",
       "max_new_tokens": ..., "temperature": ..., "top_p": ...,
       "force_text": ...}``

    Returns: ``{"text": "..."}``
    """

    def __init__(self, path: str = "/repository"):
        self.model_id = path
        self.cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)

        # Decide whether the checkpoint is multimodal (vision-language).
        vl_model_types = {
            "qwen2_5_vl", "qwen2_vl", "mllama", "fuyu", "phi4multimodal",
            "git", "gotocr2", "qwen2_5_vl_moe",
        }
        self.is_vl = getattr(self.cfg, "model_type", "").lower() in vl_model_types

        # Processor / tokenizer. VL models get an AutoProcessor; a tokenizer
        # is always resolved so text-only generation works in either case.
        self.processor = None
        self.tokenizer = None
        if self.is_vl:
            self.processor = AutoProcessor.from_pretrained(
                self.model_id, trust_remote_code=True
            )
            self.tokenizer = getattr(self.processor, "tokenizer", None)
            if self.tokenizer is None:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_id, trust_remote_code=True, use_fast=True
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, trust_remote_code=True, use_fast=True
            )

        # Precision / quantization: bf16 on GPU, fp32 on CPU; 4-bit NF4 via
        # bitsandbytes when enabled (best-effort -- config construction may
        # fail if bitsandbytes is unavailable).
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        quant_cfg = None
        if USE_4BIT:
            try:
                quant_cfg = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_use_double_quant=True,
                )
            except Exception:
                quant_cfg = None

        # Per-device memory budget for device_map="auto" sharding.
        self.max_memory = {0: "38GiB", 1: "38GiB", "cpu": "120GiB"}

        ModelClass = AutoModelForVision2Seq if self.is_vl else AutoModelForCausalLM
        self.model = ModelClass.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto",
            max_memory=self.max_memory,
            quantization_config=quant_cfg,
        ).eval()
        torch.backends.cuda.matmul.allow_tf32 = True

        # Defaults are greedy decoding; sampling params are only forwarded to
        # generate() when do_sample is True (see __call__).
        self.default_gen_kwargs = dict(
            max_new_tokens=128,
            do_sample=False,
            temperature=0.0,
            top_p=1.0,
        )

        has_mrope = any(hasattr(self.cfg, k) for k in MROPE_KEYS)
        print(
            f"[handler] transformers={self._tf_version_safe()}, "
            f"model_type={getattr(self.cfg, 'model_type', 'unknown')}, "
            f"has_mrope_in_cfg={has_mrope}"
        )
        # Seed the mrope keys on generation_config up front.
        self._ensure_mrope_on(self.model.generation_config)

    # ---------- helpers ----------
    def _tf_version_safe(self) -> str:
        """Return the installed transformers version, or 'unknown'."""
        try:
            import transformers as _tf  # noqa
            return getattr(_tf, "__version__", "unknown")
        except Exception:
            return "unknown"

    def _ensure_mrope_on(self, gen_cfg: Any) -> None:
        """Make sure every mrope_* attribute exists on ``gen_cfg`` (as None).

        Best-effort: also patches any internal dict representation, since
        different GenerationConfig implementations store state differently.
        """
        if gen_cfg is None:
            return
        for k in MROPE_KEYS:
            if not hasattr(gen_cfg, k):
                try:
                    setattr(gen_cfg, k, None)
                except Exception:
                    pass
            for attr_name in ("__dict__", "_internal_dict"):
                try:
                    d = getattr(gen_cfg, attr_name, None)
                    if isinstance(d, dict) and k not in d:
                        d[k] = None
                except Exception:
                    pass

    def _inject_mrope_into_kwargs(self, gen_kwargs: Dict[str, Any]) -> None:
        """Inject mrope_* keys directly into the generate() kwargs.

        Key workaround: some model branches index ``kwargs['mrope_section']``
        etc. directly; pre-populating the keys (value from generation_config,
        else None) prevents a KeyError inside generate().
        """
        for k in MROPE_KEYS:
            if k not in gen_kwargs:
                gen_kwargs[k] = getattr(self.model.generation_config, k, None)

    def _build_and_merge_payload(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten the Inference-API shape into the flat payload shape.

        ``inputs`` wins over ``parameters`` on key collisions; a flat payload
        (or a string ``inputs``) is returned unchanged.
        """
        if isinstance(data, dict) and "inputs" in data and not isinstance(data["inputs"], str):
            payload = data.get("inputs") or {}
            params = data.get("parameters") or {}
            merged = dict(payload)
            for k, v in params.items():
                merged.setdefault(k, v)
            return merged
        return data

    def _decode_outputs(self, outputs: Any) -> str:
        """Decode generate() output to text: processor first, then tokenizer,
        then a raw string representation as a last resort."""
        if hasattr(self.processor, "batch_decode"):
            try:
                return self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            except Exception:
                pass
        if self.tokenizer is not None:
            try:
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            except Exception:
                pass
        try:
            return str(outputs[0].tolist())
        except Exception:
            return str(outputs)

    # ---------- inference ----------
    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        data = self._build_and_merge_payload(data)

        prompt = str(data.get("prompt", "")).strip()
        max_new_tokens = int(data.get("max_new_tokens", self.default_gen_kwargs["max_new_tokens"]))
        temperature = float(data.get("temperature", self.default_gen_kwargs["temperature"]))
        top_p = float(data.get("top_p", self.default_gen_kwargs["top_p"]))
        force_text = bool(data.get("force_text", False))

        do_sample = temperature > 0.0
        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
        )
        # FIX: only forward sampling params when sampling -- recent
        # transformers versions warn/raise on temperature=0.0 with
        # do_sample=False.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        # Before generating: belt and braces on both config and kwargs.
        self._ensure_mrope_on(self.model.generation_config)
        self._inject_mrope_into_kwargs(gen_kwargs)

        # Text-only path
        def run_text_path(text: str) -> str:
            tok = self.tokenizer
            if tok is None and getattr(self, "processor", None) is not None:
                tok = getattr(self.processor, "tokenizer", None)
            if tok is None:
                tok = AutoTokenizer.from_pretrained(
                    self.model_id, trust_remote_code=True, use_fast=True
                )
            # Cache whichever tokenizer we resolved for later calls.
            self.tokenizer = tok

            txt = text if text else "Hello"
            inputs = tok(txt, return_tensors="pt")
            if "input_ids" not in inputs or inputs["input_ids"].numel() == 0:
                inputs = tok("Hello", return_tensors="pt")
            inputs = _to_device(inputs, self.model.device)

            # Re-inject in case the outer kwargs were mutated.
            local_gen = dict(gen_kwargs)
            self._inject_mrope_into_kwargs(local_gen)
            try:
                out = self.model.generate(**inputs, **local_gen)
            except KeyError as e:
                if "mrope" in str(e).lower():
                    print("[handler] caught mrope in text path, rebuilding & retry")
                    self._ensure_mrope_on(self.model.generation_config)
                    self._inject_mrope_into_kwargs(local_gen)
                    out = self.model.generate(**inputs, **local_gen)
                else:
                    raise
            # FIX: decode with the locally-resolved tokenizer, not a possibly
            # different self.tokenizer.
            return tok.decode(out[0], skip_special_tokens=True)

        # Non-VL model, or caller forced the text path
        if (not self.is_vl) or force_text:
            return {"text": run_text_path(prompt)}

        # VL path
        images = _normalize_images(data.get("image_b64"))
        if hasattr(self.processor, "apply_chat_template"):
            content = [{"type": "text", "text": prompt or "Describe the image."}]
            for _ in images:
                content.append({"type": "image"})
            msgs = [{"role": "user", "content": content}]
            prompt_text = self.processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_text = prompt or "Describe the image."

        proc_inputs = self.processor(
            text=prompt_text,
            images=images if images else None,
            return_tensors="pt",
        )
        # Empty tokenization with no images: fall back to the text path.
        if (("input_ids" in proc_inputs
             and hasattr(proc_inputs["input_ids"], "numel")
             and proc_inputs["input_ids"].numel() == 0)
                and not images):
            return {"text": run_text_path(prompt)}

        proc_inputs = _to_device(proc_inputs, self.model.device)

        # Inject again right before every generate() call.
        local_gen = dict(gen_kwargs)
        self._inject_mrope_into_kwargs(local_gen)
        try:
            outputs = self.model.generate(**proc_inputs, **local_gen)
        except KeyError as e:
            if "mrope" in str(e).lower():
                print("[handler] caught mrope in VL path, rebuilding & retry")
                self._ensure_mrope_on(self.model.generation_config)
                self._inject_mrope_into_kwargs(local_gen)
                outputs = self.model.generate(**proc_inputs, **local_gen)
            else:
                raise

        text = self._decode_outputs(outputs)
        return {"text": text}