from typing import Dict, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()
        self.device = next(self.model.parameters()).device
        print(f"✅ Model loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle inference request."""
        inputs = data.get("inputs", data.get("input", ""))
        params = data.get("parameters", {})

        # Tokenize
        encoded = self.tokenizer(
            inputs, return_tensors="pt", truncation=True, max_length=2048
        ).to(self.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=params.get("max_new_tokens", 256),
                temperature=params.get("temperature", 0.7),
                top_p=params.get("top_p", 0.9),
                do_sample=params.get("do_sample", True),
                repetition_penalty=params.get("repetition_penalty", 1.1),
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode (remove input tokens)
        generated = outputs[0][encoded["input_ids"].shape[1]:]
        text = self.tokenizer.decode(generated, skip_special_tokens=True)

        return {"generated_text": text}
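

# --- Local smoke test (not part of the endpoint contract) ---
# A minimal sketch for exercising the handler outside the Inference Endpoints
# runtime. The model path "./model" and the example payload are assumptions
# for illustration; point the path at any local or Hub model directory.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local model path
    payload = {
        "inputs": "Explain what a tokenizer does in one sentence.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.2},
    }
    result = handler(payload)
    print(result["generated_text"])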