# Lorenzob's picture
# Removed sentence-transformers and fixed SyntaxError in handler.py
# 204d192 verified
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
import io
import base64
import json
class InferenceHandler:
    """Box-prompted SAM (Segment Anything) segmentation handler for a
    Hugging Face Inference Endpoint.

    Request:  JSON string with 'image' (base64-encoded image bytes) and
              'boxes' (list of [x1, y1, x2, y2] pixel-coordinate floats).
    Response: JSON string — a list of {"mask": <base64 PNG>, "score": <float>}.
    """

    def __init__(self):
        # Prefer the GPU when available; CPU inference works but is slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # "." is the model repository root inside the HF Inference Endpoint
        # container, so weights and processor config are loaded from there.
        self.model = SamModel.from_pretrained(".").to(self.device)
        self.processor = SamProcessor.from_pretrained(".")

    def preprocess(self, request_body):
        """Decode the JSON request body into model-ready tensors.

        Returns (inputs, image.size); note PIL's .size is (width, height).
        Raises KeyError if 'image' is absent, and ValueError/binascii.Error
        on malformed JSON or base64.
        """
        data = json.loads(request_body)
        image_bytes = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Coerce every box to a list of 4 floats.
        input_boxes = [[float(coord) for coord in box]
                       for box in data.get('boxes', [])]
        # SamProcessor expects input_boxes nested per image — three levels:
        # [[box, box, ...]] for a batch of one image. Pass None when the
        # client supplied no boxes.
        # NOTE(review): do_rescale/do_normalize are disabled, so the model
        # receives raw 0-255 pixel values — confirm this matches how the
        # checkpoint was exported.
        inputs = self.processor(
            image,
            input_boxes=[input_boxes] if input_boxes else None,
            return_tensors="pt",
            do_rescale=False,
            do_normalize=False,
        ).to(self.device)
        return inputs, image.size

    def inference(self, inputs):
        """Run the SAM forward pass without tracking gradients.

        multimask_output=False yields a single mask per prompt box.
        """
        with torch.no_grad():
            outputs = self.model(**inputs, multimask_output=False)
        return outputs

    def postprocess(self, outputs, original_size):
        """Resize predicted masks to the original image and serialize them.

        original_size is PIL's (width, height) from preprocess();
        post_process_masks expects (height, width), so it is reversed here
        (the original code documented the swap but never performed it).
        """
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),
            torch.tensor([original_size[::-1]]),  # (W, H) -> (H, W)
            outputs.reshaped_input_sizes.cpu(),
        )
        # post_process_masks returns one boolean tensor per image, shaped
        # (num_boxes, num_masks, H, W) — not dicts; IoU scores come from
        # outputs.iou_scores and align with the mask layout.
        iou_scores = outputs.iou_scores.cpu()
        results = []
        for image_masks, image_scores in zip(masks, iou_scores):
            for box_mask, box_score in zip(image_masks, image_scores):
                # Boolean mask -> 0/255 grayscale PNG, base64-encoded.
                mask_np = box_mask.squeeze().numpy().astype(np.uint8) * 255
                buffered = io.BytesIO()
                Image.fromarray(mask_np).save(buffered, format="PNG")
                encoded_mask = base64.b64encode(buffered.getvalue()).decode('utf-8')
                results.append({
                    "mask": encoded_mask,
                    "score": float(box_score.squeeze()),
                })
        return json.dumps(results)
# Example of how to use the handler locally (for testing).
if __name__ == '__main__':
    handler = InferenceHandler()

    # Build a random grayscale dummy image; preprocess() converts it to RGB.
    dummy_image_size = (256, 256)
    dummy_image_np = np.random.randint(0, 256, dummy_image_size, dtype=np.uint8)
    image = Image.fromarray(dummy_image_np)

    # Encode the dummy image to base64, exactly as a real client would.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    # One example bounding box: [x1, y1, x2, y2] in pixel coordinates.
    example_boxes = [[50, 50, 200, 200]]
    dummy_request_body = json.dumps({"image": encoded_image, "boxes": example_boxes})

    # '\n' must be written as an escape: the original split this literal
    # across a raw newline, which is a SyntaxError in a quoted string.
    print('\n--- Testing InferenceHandler locally ---')
    inputs, original_size = handler.preprocess(dummy_request_body)
    outputs = handler.inference(inputs)
    processed_response = handler.postprocess(outputs, original_size)
    print('Local test successful. Response structure (truncated):',
          processed_response[:200], '...')