import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import base64
import io


class EndpointHandler():
    # Custom handler for Hugging Face Inference Endpoints; it is instantiated
    # at startup with the path to the model repository.
    def __init__(self, model_path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto" if torch.cuda.is_available() else None
        )
        self.model.eval()

    def __call__(self, data):
        inputs = data.get("inputs", {})
        prompt = inputs.get("prompt", "Generate a caption for this image.")
        images_b64 = inputs.get("images")

        # Accept either a single base64-encoded image or a list of them
        if isinstance(images_b64, str):
            images_b64 = [images_b64]
        if not images_b64:
            return {"error": "No images provided in the payload."}

        try:
            images = [
                Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
                for img_b64 in images_b64
            ]
        except Exception as e:
            return {"error": f"Failed to decode image: {str(e)}"}

        # Build the conversation template for captioning
        conversation = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt}
        ]
        convo_string = self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        if not isinstance(convo_string, str):
            return {"error": "Failed to create conversation string."}

        # Prepare the model inputs. The prompt is repeated once per image so
        # the processor pairs each image with its own copy of the conversation
        # and generate() can produce one caption per image in a single batch.
        model_inputs = self.processor(
            text=[convo_string] * len(images),
            images=images,
            return_tensors="pt"
        )
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
        if "pixel_values" in model_inputs:
            # Cast pixel values to match the bfloat16 weights loaded in __init__
            model_inputs["pixel_values"] = model_inputs["pixel_values"].to(torch.bfloat16)

        # Generate caption tokens for all images at once
        generate_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.6,
            top_p=0.9
        )

        # Trim off the prompt tokens and decode all captions
        generate_ids = generate_ids[:, model_inputs["input_ids"].shape[1]:]
        captions = [
            self.processor.tokenizer.decode(
                ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            ).strip()
            for ids in generate_ids
        ]
        # Return a bare string for a single image, a list otherwise
        return {"captions": captions if len(captions) > 1 else captions[0]}