# GTA1-32B / handler.py (Hugging Face repo file; uploaded by mengsong16, commit 208e255)
# handler.py
import os
import io
import base64
from typing import Any, Dict, List, Union
import torch
from PIL import Image
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
AutoModelForVision2Seq,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
# --- Env defaults ---
# Allow custom model/processor code shipped with the checkpoint to run.
os.environ.setdefault("HF_TRUST_REMOTE_CODE", "1")
# Avoid tokenizers fork-related deadlock warnings in server workers.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Reduce CUDA memory fragmentation for long-running endpoint processes.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# 4-bit quantization is on unless USE_4BIT is set to "0"/"false"/"False".
USE_4BIT = os.environ.get("USE_4BIT", "1") not in {"0", "false", "False"}
# NOTE(review): pins the handler to two GPUs — confirm this matches the endpoint hardware.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
# mrope-related keys (some code branches may read these keys from the kwargs
# passed to generate() instead of from generation_config).
MROPE_KEYS = ("mrope_section", "mrope_section_output", "mrope_theta", "mrope_t")
def _b64_to_pil(b64str: str) -> Image.Image:
    """Decode a base64-encoded image payload into an RGB PIL image."""
    raw = base64.b64decode(b64str)
    img = Image.open(io.BytesIO(raw))
    return img.convert("RGB")
def _normalize_images(img_field: Union[str, List[str], None]) -> List[Image.Image]:
    """Coerce the ``image_b64`` request field into a list of PIL images.

    Accepts a single base64 string, a list of base64 strings, or None
    (which yields an empty list). Anything else raises ValueError.
    """
    if img_field is None:
        return []
    if isinstance(img_field, str):
        img_field = [img_field]
    if not isinstance(img_field, list):
        raise ValueError("image_b64 must be a base64 string or a list of base64 strings.")
    return [_b64_to_pil(item) for item in img_field]
def _to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
out = {}
for k, v in batch.items():
if hasattr(v, "to"):
out[k] = v.to(device)
else:
out[k] = v
return out
class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler for GTA1-32B.

    Supports two request shapes:
    1) Inference API form:
       {"inputs": {...}, "parameters": {...}}
    2) Flat form:
       {"prompt": "...", "image_b64": "...|[...]", "max_new_tokens": ..., "temperature": ..., "top_p": ..., "force_text": ...}

    Returns: {"text": "..."}
    """
    def __init__(self, path: str = "/repository"):
        """Load config, processor/tokenizer, and the (optionally 4-bit) model from *path*."""
        self.model_id = path
        self.cfg = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        # Decide whether the checkpoint is multimodal (vision-language) by
        # matching its config model_type against a known-VL allowlist.
        vl_model_types = {
            "qwen2_5_vl", "qwen2_vl", "mllama", "fuyu",
            "phi4multimodal", "git", "gotocr2", "qwen2_5_vl_moe"
        }
        self.is_vl = getattr(self.cfg, "model_type", "").lower() in vl_model_types
        # Processor / tokenizer: VL models get an AutoProcessor (reusing its
        # tokenizer when present); text-only models get a plain AutoTokenizer.
        self.processor = None
        self.tokenizer = None
        if self.is_vl:
            self.processor = AutoProcessor.from_pretrained(self.model_id, trust_remote_code=True)
            self.tokenizer = getattr(self.processor, "tokenizer", None)
            if self.tokenizer is None:
                # Processor carries no tokenizer — load one directly.
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_id, trust_remote_code=True, use_fast=True
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, trust_remote_code=True, use_fast=True
            )
        # Precision / quantization: bf16 on GPU, fp32 on CPU; optional 4-bit NF4.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        quant_cfg = None
        if USE_4BIT:
            try:
                quant_cfg = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=dtype,
                    bnb_4bit_use_double_quant=True,
                )
            except Exception:
                # bitsandbytes unavailable/incompatible -> fall back to no quantization.
                quant_cfg = None
        # Per-device memory budget for device_map="auto" sharding.
        # NOTE(review): hard-coded for a 2x ~40GiB GPU layout — confirm it matches the endpoint hardware.
        self.max_memory = {0: "38GiB", 1: "38GiB", "cpu": "120GiB"}
        ModelClass = AutoModelForVision2Seq if self.is_vl else AutoModelForCausalLM
        self.model = ModelClass.from_pretrained(
            self.model_id,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto",
            max_memory=self.max_memory,
            quantization_config=quant_cfg,
        ).eval()
        torch.backends.cuda.matmul.allow_tf32 = True
        # Greedy decoding defaults; per-request values override these in __call__.
        self.default_gen_kwargs = dict(
            max_new_tokens=128,
            do_sample=False,
            temperature=0.0,
            top_p=1.0,
        )
        has_mrope = any(hasattr(self.cfg, k) for k in MROPE_KEYS)
        print(
            f"[handler] transformers={self._tf_version_safe()}, "
            f"model_type={getattr(self.cfg, 'model_type', 'unknown')}, "
            f"has_mrope_in_cfg={has_mrope}"
        )
        # Seed generation_config with the mrope keys up front.
        self._ensure_mrope_on(self.model.generation_config)
    # ---------- helpers ----------
    def _tf_version_safe(self) -> str:
        """Return the installed transformers version, or "unknown" on any failure."""
        try:
            import transformers as _tf  # noqa
            return getattr(_tf, "__version__", "unknown")
        except Exception:
            return "unknown"
    def _ensure_mrope_on(self, gen_cfg: Any) -> None:
        """Best-effort: ensure every MROPE_KEYS attribute exists (as None) on *gen_cfg*."""
        if gen_cfg is None:
            return
        for k in MROPE_KEYS:
            if not hasattr(gen_cfg, k):
                try:
                    setattr(gen_cfg, k, None)
                except Exception:
                    pass
            # Also patch the backing dicts best-effort (different
            # implementations back to_dict() differently).
            for attr_name in ("__dict__", "_internal_dict"):
                try:
                    d = getattr(gen_cfg, attr_name, None)
                    if isinstance(d, dict) and k not in d:
                        d[k] = None
                except Exception:
                    pass
    def _inject_mrope_into_kwargs(self, gen_kwargs: Dict[str, Any]) -> None:
        """
        Key workaround: inject mrope_* directly into the kwargs about to be
        passed to generate(), so lower-level code indexing
        kwargs['mrope_section'] etc. cannot raise KeyError.
        """
        for k in MROPE_KEYS:
            if k not in gen_kwargs:
                # Prefer the value on generation_config; otherwise None.
                v = getattr(self.model.generation_config, k, None)
                gen_kwargs[k] = v
    def _build_and_merge_payload(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten {"inputs": {...}, "parameters": {...}} into one dict.

        Keys in "inputs" win over "parameters"; any other shape is returned as-is.
        """
        if isinstance(data, dict) and "inputs" in data and not isinstance(data["inputs"], str):
            payload = data.get("inputs") or {}
            params = data.get("parameters") or {}
            merged = dict(payload)
            for k, v in params.items():
                merged.setdefault(k, v)
            return merged
        return data
    def _decode_outputs(self, outputs: Any) -> str:
        """Decode generated ids to text: processor first, then tokenizer, then raw repr."""
        if hasattr(self.processor, "batch_decode"):
            try:
                return self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
            except Exception:
                pass
        if self.tokenizer is not None:
            try:
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            except Exception:
                pass
        try:
            return str(outputs[0].tolist())
        except Exception:
            return str(outputs)
    # ---------- inference ----------
    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request dict and return {"text": generated_text}."""
        data = self._build_and_merge_payload(data)
        prompt = str(data.get("prompt", "")).strip()
        max_new_tokens = int(data.get("max_new_tokens", self.default_gen_kwargs["max_new_tokens"]))
        temperature = float(data.get("temperature", self.default_gen_kwargs["temperature"]))
        top_p = float(data.get("top_p", self.default_gen_kwargs["top_p"]))
        force_text = bool(data.get("force_text", False))
        # temperature > 0 implies sampling; temperature == 0 means greedy.
        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            do_sample=(temperature > 0.0),
            temperature=temperature,
            top_p=top_p,
        )
        # Before generating: belt-and-braces mrope patching (config + kwargs).
        self._ensure_mrope_on(self.model.generation_config)
        self._inject_mrope_into_kwargs(gen_kwargs)
        # Text-only path
        def run_text_path(text: str) -> str:
            # Resolve a tokenizer: self.tokenizer -> processor.tokenizer -> fresh load.
            tok = self.tokenizer
            if tok is None and getattr(self, "processor", None) is not None:
                tok = getattr(self.processor, "tokenizer", None)
            if tok is None:
                tok = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True, use_fast=True)
                self.tokenizer = tok
            txt = text if text else "Hello"
            inputs = tok(txt, return_tensors="pt")
            # Guard against an empty encoding (e.g. whitespace-only prompt).
            if "input_ids" not in inputs or inputs["input_ids"].numel() == 0:
                inputs = tok("Hello", return_tensors="pt")
            inputs = _to_device(inputs, self.model.device)
            # Re-inject so outer-scope mutation cannot drop the mrope keys.
            local_gen = dict(gen_kwargs)
            self._inject_mrope_into_kwargs(local_gen)
            try:
                out = self.model.generate(**inputs, **local_gen)
            except KeyError as e:
                # Known failure mode: generate() internals index a missing
                # mrope key — re-patch and retry once.
                if "mrope" in str(e).lower():
                    print("[handler] caught mrope in text path, rebuilding & retry")
                    self._ensure_mrope_on(self.model.generation_config)
                    self._inject_mrope_into_kwargs(local_gen)
                    out = self.model.generate(**inputs, **local_gen)
                else:
                    raise
            # NOTE(review): decodes via self.tokenizer rather than the locally
            # resolved `tok` — safe only because __init__ always sets
            # self.tokenizer; confirm if that invariant ever changes.
            return self.tokenizer.decode(out[0], skip_special_tokens=True)
        # Non-VL model, or caller forced the text path
        if (not self.is_vl) or force_text:
            return {"text": run_text_path(prompt)}
        # VL path
        images = _normalize_images(data.get("image_b64"))
        if hasattr(self.processor, "apply_chat_template"):
            # Build a chat message with one text part plus one image slot per image.
            content = [{"type": "text", "text": prompt or "Describe the image."}]
            for _ in images:
                content.append({"type": "image"})
            msgs = [{"role": "user", "content": content}]
            prompt_text = self.processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt_text = prompt or "Describe the image."
        proc_inputs = self.processor(
            text=prompt_text,
            images=images if images else None,
            return_tensors="pt",
        )
        # Empty encoding and no images -> fall back to the plain text path.
        if (("input_ids" in proc_inputs and hasattr(proc_inputs["input_ids"], "numel")
        and proc_inputs["input_ids"].numel() == 0) and not images):
            return {"text": run_text_path(prompt)}
        proc_inputs = _to_device(proc_inputs, self.model.device)
        # Inject again immediately before every generate() call.
        local_gen = dict(gen_kwargs)
        self._inject_mrope_into_kwargs(local_gen)
        try:
            outputs = self.model.generate(**proc_inputs, **local_gen)
        except KeyError as e:
            # Same mrope KeyError workaround as in the text path.
            if "mrope" in str(e).lower():
                print("[handler] caught mrope in VL path, rebuilding & retry")
                self._ensure_mrope_on(self.model.generation_config)
                self._inject_mrope_into_kwargs(local_gen)
                outputs = self.model.generate(**proc_inputs, **local_gen)
            else:
                raise
        text = self._decode_outputs(outputs)
        return {"text": text}