import logging
import time
from contextlib import nullcontext
from typing import Any, AsyncIterable, Dict, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

# ZeroGPU support is optional: the `spaces` package is only available on
# Hugging Face Spaces, so fall back to plain CPU execution without it.
try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None


MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")

# Preload the tokenizer at import time; a failure is recorded in `load_error`
# and re-raised on the first chat request instead of crashing the import.
tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        use_fast=False,
    )
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)


def _pick_cpu_dtype() -> torch.dtype:
    """Prefer bfloat16 on CPUs that support it, otherwise use float32."""
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                return torch.bfloat16
        except Exception:
            pass
    logger.info("Falling back to torch.float32 on CPU")
    return torch.float32
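# Illustrative only (hardware-dependent, an assumption rather than a guarantee):
# a recent x86 CPU with native BF16 support returns torch.bfloat16 here, while
# older machines get torch.float32.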


# Loaded models are cached by (device, effective dtype) so repeat requests
# reuse the same weights instead of reloading from disk.
_MODEL_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}


def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    key = (device, dtype)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key], dtype

    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        logger.warning("Removing quantization_config from model config")
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
            low_cpu_mem_usage=False,
        )
    except Exception as e:
        if device == "cpu" and dtype == torch.bfloat16:
            # Some CPU stacks advertise BF16 but still fail at load time; retry in FP32.
            logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
            eff_dtype = torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=eff_dtype,
                trust_remote_code=True,
                device_map={"": "cpu"},
                low_cpu_mem_usage=False,
            )
        else:
            raise

    if device == "cpu":
        model = model.to(device=device, dtype=eff_dtype)
    else:
        model = model.to(device=device)

    model.eval()
    _MODEL_CACHE[(device, eff_dtype)] = model
    return model, eff_dtype
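# Note: cache lookups use the *requested* dtype, but entries are stored under
# the *effective* dtype. After a BF16 -> FP32 fallback on CPU, a later bfloat16
# request misses the cache and repeats the load-and-retry path instead of
# reusing the cached FP32 model.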


class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        # Forward the caller's ZeroGPU quota token when running on Spaces.
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")

        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.debug("Applied chat template for prompt")
            except Exception as e:
                logger.warning(f"Failed to apply chat template: {e}, using fallback")
                prompt = messages[-1]["content"] if messages else "(empty)"
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"

        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
            model, eff_dtype = _get_model(device, req_dtype)

            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

            with torch.inference_mode():
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
                elif eff_dtype == torch.bfloat16:
                    # torch.cpu.amp.autocast is deprecated; use the unified torch.autocast API.
                    autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16)
                else:
                    autocast_ctx = nullcontext()

                with autocast_ctx:
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                        use_cache=True,
                    )

            # Decode only the newly generated tokens, skipping the echoed prompt.
            input_len = inputs["input_ids"].shape[-1]
            generated_ids = outputs[0][input_len:]
            return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if spaces:
            # On Spaces, generation runs inside a ZeroGPU allocation (up to 120 s).
            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                if torch.cuda.is_available():
                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

            text = run_once(prompt)
        else:
            text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

        # Generation is not token-streamed; the whole completion goes out as a
        # single OpenAI-style chunk with finish_reason="stop".
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
            ],
        }


class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        # Base64 of a 1x1 placeholder PNG.
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
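

if __name__ == "__main__":
    # Minimal smoke test (a sketch: it assumes the model download succeeds and
    # simply prints the single completion chunk's content).
    import asyncio

    async def _demo() -> None:
        backend = HFChatBackend()
        req = {"messages": [{"role": "user", "content": "Hello!"}]}
        async for chunk in backend.stream(req):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())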