| | |
| | import os |
| | import io |
| | import base64 |
| | from typing import Any, Dict, List, Union |
| |
|
| | import torch |
| | from PIL import Image |
| | from transformers import ( |
| | AutoConfig, |
| | AutoProcessor, |
| | AutoTokenizer, |
| | AutoModelForVision2Seq, |
| | AutoModelForCausalLM, |
| | BitsAndBytesConfig, |
| | ) |
| |
|
| | |
# Environment defaults — applied before any model/tokenizer loading so the
# transformers / tokenizers / CUDA allocator libraries pick them up.
os.environ.setdefault("HF_TRUST_REMOTE_CODE", "1")        # allow repo-provided custom model code
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")  # silence tokenizers fork warning
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")  # reduce CUDA fragmentation

# 4-bit (NF4) quantization is enabled by default; disable with USE_4BIT=0/false.
USE_4BIT = os.environ.get("USE_4BIT", "1") not in {"0", "false", "False"}
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")  # default to the first two GPUs

# Generation-config keys related to multimodal RoPE (mRoPE). Some remote model
# code indexes these keys directly, so the handler guarantees they exist
# (see EndpointHandler._ensure_mrope_on / _inject_mrope_into_kwargs).
MROPE_KEYS = ("mrope_section", "mrope_section_output", "mrope_theta", "mrope_t")
| |
|
| |
|
def _b64_to_pil(b64str: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    raw = base64.b64decode(b64str)
    img = Image.open(io.BytesIO(raw))
    return img.convert("RGB")
| |
|
| |
|
def _normalize_images(img_field: Union[str, List[str], None]) -> List[Image.Image]:
    """Coerce the request's ``image_b64`` field into a list of PIL images.

    Accepts ``None`` (no images), a single base64 string, or a list of
    base64 strings; anything else raises ``ValueError``.
    """
    if img_field is None:
        return []
    if isinstance(img_field, str):
        img_field = [img_field]
    if not isinstance(img_field, list):
        raise ValueError("image_b64 must be a base64 string or a list of base64 strings.")
    return [_b64_to_pil(item) for item in img_field]
| |
|
| |
|
| | def _to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: |
| | out = {} |
| | for k, v in batch.items(): |
| | if hasattr(v, "to"): |
| | out[k] = v.to(device) |
| | else: |
| | out[k] = v |
| | return out |
| |
|
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for text and vision-language models.

    Accepts two request shapes:
      1) Inference-API form:
         {"inputs": {...}, "parameters": {...}}
      2) Flat form:
         {"prompt": "...", "image_b64": "...|[...]", "max_new_tokens": ...,
          "temperature": ..., "top_p": ..., "force_text": ...}

    Returns: {"text": "..."}
    """

    def __init__(self, path: str = "/repository"):
        self.model_id = path
        self.cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)

        # Model types routed through the vision-language (processor-based) path.
        vl_model_types = {
            "qwen2_5_vl", "qwen2_vl", "mllama", "fuyu",
            "phi4multimodal", "git", "gotocr2", "qwen2_5_vl_moe",
        }
        self.is_vl = getattr(self.cfg, "model_type", "").lower() in vl_model_types

        # VL models load a processor (reusing its inner tokenizer when present);
        # text-only models load a plain tokenizer.
        self.processor = None
        self.tokenizer = None
        if self.is_vl:
            self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
            self.tokenizer = getattr(self.processor, "tokenizer", None)
            if self.tokenizer is None:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_id, trust_remote_code=True, use_fast=True
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, trust_remote_code=True, use_fast=True
            )

        # Prefer bf16 on GPU; optionally load NF4-quantized weights.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        quant_cfg = None
        if USE_4BIT:
            try:
                quant_cfg = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_use_double_quant=True,
                )
            except Exception:
                # bitsandbytes unavailable/misconfigured: fall back to full precision.
                quant_cfg = None

        # Per-device memory budget for device_map="auto" sharding across
        # two GPUs with CPU offload headroom.
        self.max_memory = {0: "38GiB", 1: "38GiB", "cpu": "120GiB"}
        ModelClass = AutoModelForVision2Seq if self.is_vl else AutoModelForCausalLM

        self.model = ModelClass.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto",
            max_memory=self.max_memory,
            quantization_config=quant_cfg,
        ).eval()

        torch.backends.cuda.matmul.allow_tf32 = True

        # Greedy decoding defaults; per-request parameters override these.
        self.default_gen_kwargs = dict(
            max_new_tokens=128,
            do_sample=False,
            temperature=0.0,
            top_p=1.0,
        )

        has_mrope = any(hasattr(self.cfg, k) for k in MROPE_KEYS)
        print(
            f"[handler] transformers={self._tf_version_safe()}, "
            f"model_type={getattr(self.cfg, 'model_type', 'unknown')}, "
            f"has_mrope_in_cfg={has_mrope}"
        )

        # Guarantee mrope_* attributes exist on the generation config up front.
        self._ensure_mrope_on(self.model.generation_config)

    def _tf_version_safe(self) -> str:
        """Return the installed transformers version, or "unknown" on failure."""
        try:
            import transformers as _tf
            return getattr(_tf, "__version__", "unknown")
        except Exception:
            return "unknown"

    def _ensure_mrope_on(self, gen_cfg: Any) -> None:
        """Best-effort: make every MROPE_KEYS name resolvable on *gen_cfg*.

        Sets missing attributes to None and also seeds the underlying dicts
        (``__dict__`` / ``_internal_dict``) so that code doing direct key
        lookups does not raise. All failures are intentionally swallowed —
        this is defensive plumbing, not a hard requirement.
        """
        if gen_cfg is None:
            return
        for k in MROPE_KEYS:
            if not hasattr(gen_cfg, k):
                try:
                    setattr(gen_cfg, k, None)
                except Exception:
                    pass

            for attr_name in ("__dict__", "_internal_dict"):
                try:
                    d = getattr(gen_cfg, attr_name, None)
                    if isinstance(d, dict) and k not in d:
                        d[k] = None
                except Exception:
                    pass

    def _inject_mrope_into_kwargs(self, gen_kwargs: Dict[str, Any]) -> None:
        """Key fix: inject mrope_* keys directly into the kwargs that will be
        passed to ``generate``, so downstream code indexing
        ``kwargs['mrope_section']`` (etc.) cannot raise KeyError.
        """
        for k in MROPE_KEYS:
            if k not in gen_kwargs:
                # Mirror whatever the generation config currently holds (may be None).
                v = getattr(self.model.generation_config, k, None)
                gen_kwargs[k] = v

    def _build_and_merge_payload(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten an Inference-API request ({"inputs": {...}, "parameters": {...}})
        into one dict; "inputs" keys win over "parameters" keys. A flat request
        (or one whose "inputs" is a plain string) is returned unchanged.
        """
        if isinstance(data, dict) and "inputs" in data and not isinstance(data["inputs"], str):
            payload = data.get("inputs") or {}
            params = data.get("parameters") or {}
            merged = dict(payload)
            for k, v in params.items():
                merged.setdefault(k, v)
            return merged
        return data

    def _decode_outputs(self, outputs: Any) -> str:
        """Decode generated token ids to text, trying processor, then tokenizer,
        then falling back to a raw string of the ids."""
        if hasattr(self.processor, "batch_decode"):
            try:
                return self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            except Exception:
                pass
        if self.tokenizer is not None:
            try:
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            except Exception:
                pass
        try:
            return str(outputs[0].tolist())
        except Exception:
            return str(outputs)

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request and return {"text": generated_text}."""
        data = self._build_and_merge_payload(data)

        prompt = str(data.get("prompt", "")).strip()
        max_new_tokens = int(data.get("max_new_tokens", self.default_gen_kwargs["max_new_tokens"]))
        temperature = float(data.get("temperature", self.default_gen_kwargs["temperature"]))
        top_p = float(data.get("top_p", self.default_gen_kwargs["top_p"]))
        force_text = bool(data.get("force_text", False))

        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            do_sample=(temperature > 0.0),  # sample only when temperature is positive
            temperature=temperature,
            top_p=top_p,
        )

        # Defensive mrope plumbing (see _ensure_mrope_on / _inject_mrope_into_kwargs).
        self._ensure_mrope_on(self.model.generation_config)
        self._inject_mrope_into_kwargs(gen_kwargs)

        def run_text_path(text: str) -> str:
            # Resolve a tokenizer: handler's own, then the processor's, then a
            # fresh one from the hub (cached back onto self for next time).
            tok = self.tokenizer
            if tok is None and getattr(self, "processor", None) is not None:
                tok = getattr(self.processor, "tokenizer", None)
            if tok is None:
                tok = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True, use_fast=True)
                self.tokenizer = tok

            # Never feed an empty prompt — fall back to a fixed placeholder.
            txt = text if text else "Hello"
            inputs = tok(txt, return_tensors="pt")
            if "input_ids" not in inputs or inputs["input_ids"].numel() == 0:
                inputs = tok("Hello", return_tensors="pt")
            inputs = _to_device(inputs, self.model.device)

            local_gen = dict(gen_kwargs)
            self._inject_mrope_into_kwargs(local_gen)

            try:
                out = self.model.generate(**inputs, **local_gen)
            except KeyError as e:
                if "mrope" in str(e).lower():
                    print("[handler] caught mrope in text path, rebuilding & retry")
                    self._ensure_mrope_on(self.model.generation_config)
                    self._inject_mrope_into_kwargs(local_gen)
                    out = self.model.generate(**inputs, **local_gen)
                else:
                    raise
            # BUGFIX: decode with the tokenizer actually resolved above —
            # self.tokenizer can still be None when tok came from the processor.
            return tok.decode(out[0], skip_special_tokens=True)

        # Text-only models, or an explicit force_text request, skip the VL path.
        if (not self.is_vl) or force_text:
            return {"text": run_text_path(prompt)}

        # Vision-language path: build a chat-formatted prompt with one image
        # placeholder per supplied image when the processor supports it.
        images = _normalize_images(data.get("image_b64"))
        if hasattr(self.processor, "apply_chat_template"):
            content = [{"type": "text", "text": prompt or "Describe the image."}]
            for _ in images:
                content.append({"type": "image"})
            msgs = [{"role": "user", "content": content}]
            prompt_text = self.processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_text = prompt or "Describe the image."

        proc_inputs = self.processor(
            text=prompt_text,
            images=images if images else None,
            return_tensors="pt",
        )

        # Empty tokenization with no images: degrade gracefully to the text path.
        if (("input_ids" in proc_inputs and hasattr(proc_inputs["input_ids"], "numel")
                and proc_inputs["input_ids"].numel() == 0) and not images):
            return {"text": run_text_path(prompt)}

        proc_inputs = _to_device(proc_inputs, self.model.device)

        local_gen = dict(gen_kwargs)
        self._inject_mrope_into_kwargs(local_gen)

        try:
            outputs = self.model.generate(**proc_inputs, **local_gen)
        except KeyError as e:
            if "mrope" in str(e).lower():
                print("[handler] caught mrope in VL path, rebuilding & retry")
                self._ensure_mrope_on(self.model.generation_config)
                self._inject_mrope_into_kwargs(local_gen)
                outputs = self.model.generate(**proc_inputs, **local_gen)
            else:
                raise

        text = self._decode_outputs(outputs)
        return {"text": text}
| |
|