|
|
| import torch |
| from transformers import SamModel, SamProcessor |
| from PIL import Image |
| import numpy as np |
| import io |
| import base64 |
| import json |
|
|
class InferenceHandler:
    """Serves a SAM (Segment Anything) model.

    Decodes a base64 image plus bounding-box prompts from a JSON request
    body, runs mask prediction, and returns a JSON string of base64-encoded
    PNG masks with their IoU scores.
    """

    def __init__(self):
        # Prefer GPU when available. Model weights and processor config are
        # expected in the current working directory (".").
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SamModel.from_pretrained(".").to(self.device)
        self.model.eval()  # inference-only handler; disable dropout etc.
        self.processor = SamProcessor.from_pretrained(".")

    def preprocess(self, request_body):
        """Parse a JSON request into model-ready tensors.

        Expects ``{"image": <base64 PNG/JPEG>, "boxes": [[x1,y1,x2,y2], ...]}``.
        Returns ``(inputs, (height, width))`` where ``inputs`` is the
        processor's encoding moved to ``self.device``.
        """
        data = json.loads(request_body)

        image_bytes = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        boxes = data.get('boxes', [])
        boxes = [[float(coord) for coord in box] for box in boxes]

        # FIX 1: SamProcessor expects input_boxes nested per image:
        # (batch, n_boxes, 4) — the original passed a flat (n_boxes, 4) list.
        # FIX 2: let the processor apply its default rescale + normalize to
        # the raw 0-255 PIL image; the original forced do_rescale=False,
        # do_normalize=False and fed unnormalized pixels to the model.
        inputs = self.processor(
            image,
            input_boxes=[boxes] if boxes else None,
            return_tensors="pt",
        ).to(self.device)

        # reshaped_input_sizes is produced by the processor (it is NOT an
        # attribute of the model output); stash it for postprocess().
        self._reshaped_input_sizes = inputs["reshaped_input_sizes"].cpu()

        # FIX 3: PIL's .size is (width, height); post_process_masks expects
        # (height, width), so swap before returning.
        return inputs, (image.size[1], image.size[0])

    def inference(self, inputs):
        """Run the model without gradient tracking; single mask per box."""
        with torch.no_grad():
            outputs = self.model(**inputs, multimask_output=False)
        return outputs

    def postprocess(self, outputs, original_size):
        """Convert model outputs into a JSON string of PNG masks + scores.

        ``original_size`` is the (height, width) tuple returned by
        ``preprocess()``.
        """
        # FIX 4: the original read outputs.reshaped_input_sizes, which does
        # not exist on the SAM model output — it comes from the processor
        # encoding saved in preprocess().
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),
            [original_size],
            self._reshaped_input_sizes,
        )
        scores = outputs.iou_scores.cpu()

        # FIX 5: post_process_masks returns one boolean tensor per image,
        # shaped (n_boxes, masks_per_box, H, W) — not dicts — and the IoU
        # scores live on outputs.iou_scores.
        # NOTE(review): as in the original, the squeeze-then-fromarray path
        # assumes a single box per request; multiple boxes would need a
        # per-box loop — confirm against callers.
        results = []
        for mask_tensor, score_tensor in zip(masks, scores):
            mask_np = mask_tensor.squeeze().numpy().astype(np.uint8) * 255
            buffered = io.BytesIO()
            Image.fromarray(mask_np).save(buffered, format="PNG")
            encoded_mask = base64.b64encode(buffered.getvalue()).decode('utf-8')
            results.append({
                "mask": encoded_mask,
                "score": float(score_tensor.max()),
            })
        return json.dumps(results)
|
|
|
|
| |
if __name__ == '__main__':
    handler = InferenceHandler()

    # Build a synthetic grayscale test image (preprocess converts to RGB).
    dummy_image_size = (256, 256)
    dummy_image_np = np.random.randint(0, 256, dummy_image_size, dtype=np.uint8)
    image = Image.fromarray(dummy_image_np)

    # Base64-encode it as a PNG, mirroring what a real client would send.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    # One prompt box in (x1, y1, x2, y2) pixel coordinates.
    example_boxes = [[50, 50, 200, 200]]

    dummy_request_body = json.dumps({"image": encoded_image, "boxes": example_boxes})

    # BUG FIX: the original string literal was split across two physical
    # lines (a SyntaxError); use an explicit \n escape instead.
    print('\n--- Testing InferenceHandler locally ---')
    inputs, original_size = handler.preprocess(dummy_request_body)
    outputs = handler.inference(inputs)
    processed_response = handler.postprocess(outputs, original_size)
    print('Local test successful. Response structure (truncated):', processed_response[:200], '...')
|
|