Lorenzob committed on
Commit
ac0346c
·
verified ·
1 Parent(s): 6d38dae

Update with README.md, requirements.txt, and inference.py for Inference Endpoint

Files changed (4)
  1. README.md +121 -0
  2. inference.py +82 -0
  3. model.safetensors +1 -1
  4. requirements.txt +8 -0
README.md ADDED
@@ -0,0 +1,121 @@
---
pretty_name: SAM Brain Tumor Segmentation
tags:
- sam-brain-tumor-segmentation
pipeline_tag: image-segmentation
library_name: transformers
license: mit
---

# SAM Brain Tumor Segmentation Model

This model is a fine-tuned [Segment Anything Model (SAM)](https://huggingface.co/facebook/sam-vit-base) for brain tumor segmentation from medical imaging data. It was trained on a simulated dataset of 2D slices derived from 3D NIfTI (.nii.gz) images and their corresponding segmentation masks.
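The 2D slices described above can be extracted from a NIfTI volume with `nibabel`. A minimal sketch follows; the actual preprocessing script is not part of this repo, and the volume here is simulated so the snippet runs standalone:

```python
import numpy as np
# For real data: import nibabel as nib; volume = nib.load("scan.nii.gz").get_fdata()

# Simulated (X, Y, Z) volume standing in for a loaded NIfTI scan
volume = np.random.rand(64, 64, 16) * 1000.0

# Extract axial slices along the last axis, normalizing each to the 0-1 range
slices = []
for z in range(volume.shape[-1]):
    sl = volume[:, :, z].astype(np.float32)
    sl = (sl - sl.min()) / (sl.max() - sl.min() + 1e-8)
    slices.append(sl)

print(len(slices), slices[0].shape)
```

Each normalized slice would then be duplicated across 3 channels before being fed to SAM, as noted in the training details below.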
## Model Description

The original SAM model is a powerful general-purpose image segmentation model. This fine-tuned version specializes in identifying brain tumors, leveraging the prompt-based segmentation capabilities of SAM. The model is prompted with bounding boxes around the tumor regions (derived from ground truth masks during training) to generate precise segmentation masks.

### Training Details

- **Base Model**: `facebook/sam-vit-base`
- **Dataset**: Simulated 2D axial slices from 3D NIfTI images, normalized to the 0-1 range.
- **Image Preprocessing**: Grayscale images were duplicated across 3 channels to match SAM's expected input. Bounding box prompts were generated from ground truth masks.
- **Loss Functions**: Binary Cross-Entropy (BCE) Loss and Dice Loss.
- **Optimizer**: AdamW with a learning rate of 1e-5.
- **Epochs**: 5
- **Average Dice Score on Validation Set**: 0.9756 (on simulated data)

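The BCE + Dice combination listed above is commonly implemented along the following lines. This is a sketch only; the exact weighting and smoothing constant used during training are not stated in this repo:

```python
import torch
import torch.nn.functional as F

def bce_dice_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Combined BCE + Dice loss for binary segmentation.

    logits:  raw model outputs, shape (B, 1, H, W)
    targets: binary ground-truth masks, same shape, values in {0, 1}
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    probs = torch.sigmoid(logits)
    # Per-sample Dice coefficient, then averaged over the batch
    intersection = (probs * targets).sum(dim=(1, 2, 3))
    union = probs.sum(dim=(1, 2, 3)) + targets.sum(dim=(1, 2, 3))
    dice = (2 * intersection + eps) / (union + eps)
    return bce + (1 - dice).mean()

loss = bce_dice_loss(torch.randn(2, 1, 64, 64), torch.randint(0, 2, (2, 1, 64, 64)).float())
print(float(loss))
```

Dice loss directly optimizes region overlap, which compensates for the class imbalance typical of small tumor regions, while BCE provides stable per-pixel gradients.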
## Usage

To use this model for inference, load it with the `transformers` library and provide an image along with a bounding box prompt for the region of interest. The model then predicts a segmentation mask.

```python
from transformers import SamModel, SamProcessor
from PIL import Image
import torch
import numpy as np

# Load the fine-tuned model and processor
processor = SamProcessor.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")
model = SamModel.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Example: create a dummy image (replace with your actual medical image).
# This should be a 2D grayscale image converted to 3 channels.
# For a real image, load it and ensure it is normalized to 0-1 (or uint8).
image_size = 256  # example size
dummy_image_data = np.random.rand(image_size, image_size) * 255
dummy_image = Image.fromarray(dummy_image_data.astype(np.uint8)).convert("RGB")

# Example: define a bounding box for the tumor region (x_min, y_min, x_max, y_max).
# In a real scenario, the bounding box would be provided by an expert or a detection model.
# Note the nesting: a list of boxes per image, for a batch of one image.
input_boxes = [[[100, 100, 200, 200]]]

# Preprocess the image and bounding box
inputs = processor(dummy_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Perform inference
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Post-process the predicted masks back to the original image size
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# `masks` is a list with one boolean tensor per input image,
# of shape (num_boxes, num_masks, H, W).
predicted_mask = masks[0].squeeze().numpy()  # shape (H, W)

print("Predicted mask shape:", predicted_mask.shape)
# You can visualize `predicted_mask` with matplotlib or another image library:
# import matplotlib.pyplot as plt
# plt.imshow(predicted_mask, cmap='gray')
# plt.title('Predicted Segmentation Mask')
# plt.show()
```
## Inference Endpoint Configuration (Optional)

If you wish to deploy this model as a Hugging Face Inference Endpoint, here is a sample configuration you might use in your `README.md` (or directly in the UI):

```yaml
widget:
- src: "app.py"
  example_title: "Brain Tumor Segmentation Example"
  inputs:
  - filename: "image.png"
    image: https://huggingface.co/datasets/huggingface/sample-images/resolve/main/segmentation_image_input.png
    input_boxes: [[100, 100, 200, 200]]

# Optional settings for the endpoint itself
parameters:
  do_normalize: false      # assuming inputs are already normalized to 0-1
  do_rescale: false        # assuming inputs are already scaled correctly
  multimask_output: false  # return the single best mask

# Example hardware/software config for advanced users
# inference:
#   accelerator: cuda
#   container: pytorch_latest
#   hardware: gpu_small
#   task: image-segmentation
```

**Note**: The example image and `input_boxes` in the YAML configuration are placeholders. For a real medical-image endpoint, provide a relevant example image and a bounding box corresponding to a tumor within that image.

## Limitations

- The model was fine-tuned on a simulated dataset. Its performance on real, diverse clinical data may vary and needs rigorous validation.
- The model relies on a bounding box prompt, so its accuracy depends heavily on the quality and precision of the provided box.
- The model currently handles 2D slices; adapting it to full 3D volume segmentation would require further development.

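Until native 3D support exists, a naive workaround is to run the 2D model slice by slice and stack the outputs into a volume. A sketch with a stand-in `predict_mask` function (in practice, each call would invoke the SAM pipeline from the Usage section):

```python
import numpy as np

def predict_mask(slice_2d: np.ndarray) -> np.ndarray:
    """Stand-in for a per-slice model call; returns a dummy binary mask."""
    return (slice_2d > slice_2d.mean()).astype(np.uint8)

volume = np.random.rand(64, 64, 8)  # simulated (X, Y, Z) volume

# Segment each axial slice and stack the 2D masks back into a 3D mask
mask_3d = np.stack(
    [predict_mask(volume[:, :, z]) for z in range(volume.shape[-1])],
    axis=-1,
)
print(mask_3d.shape)
```

This ignores inter-slice consistency, which is one reason direct 3D handling is listed as future work below.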
## Future Work

- Evaluate and fine-tune the model on large, real-world medical imaging datasets (e.g., BraTS, TCIA).
- Explore methods for automatic bounding box generation for tumor regions.
- Extend the model to handle 3D medical images directly.
- Add quantitative evaluation metrics (e.g., IoU, Hausdorff distance) when evaluating on real data.
inference.py ADDED
@@ -0,0 +1,82 @@
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
import io
import base64
import json


class InferenceHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SamModel.from_pretrained("./sam_brain_tumor_model").to(self.device)
        self.processor = SamProcessor.from_pretrained("./sam_brain_tumor_model")

    def preprocess(self, request_body):
        # Expect request_body to be a JSON string with 'image' (base64-encoded PNG)
        # and 'boxes' (a list of [x_min, y_min, x_max, y_max] lists)
        data = json.loads(request_body)

        # Decode the image from base64
        image_bytes = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Coerce boxes to lists of 4 floats
        input_boxes = [[float(coord) for coord in box] for box in data.get('boxes', [])]

        # Prepare model inputs; note the extra nesting: one list of boxes per image.
        # Rescaling/normalization are disabled because inputs are assumed to be
        # pre-normalized, matching the training setup.
        inputs = self.processor(
            image,
            input_boxes=[input_boxes],
            return_tensors="pt",
            do_rescale=False,
            do_normalize=False,
        ).to(self.device)
        return inputs

    def inference(self, inputs):
        with torch.no_grad():
            outputs = self.model(**inputs, multimask_output=False)
        return outputs

    def postprocess(self, outputs, inputs):
        # `post_process_masks` returns one boolean tensor per input image,
        # of shape (num_boxes, num_masks, H, W), resized to the original image size
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores.cpu()[0]  # shape (num_boxes, num_masks)

        # Convert each mask to a binary PNG, base64-encoded for the JSON response
        results = []
        for box_idx in range(masks[0].shape[0]):
            mask_np = masks[0][box_idx].squeeze().numpy().astype(np.uint8) * 255  # 0/255
            buffered = io.BytesIO()
            Image.fromarray(mask_np).save(buffered, format="PNG")
            encoded_mask = base64.b64encode(buffered.getvalue()).decode('utf-8')
            results.append({"mask": encoded_mask, "score": float(scores[box_idx].max())})
        return json.dumps(results)


# Example of how to use the handler locally (for testing)
if __name__ == '__main__':
    handler = InferenceHandler()

    # Create a dummy grayscale image
    dummy_image_size = (256, 256)
    dummy_image_np = np.random.randint(0, 256, dummy_image_size, dtype=np.uint8)
    image = Image.fromarray(dummy_image_np)

    # Encode the dummy image to base64
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    # Example bounding box
    example_boxes = [[50, 50, 200, 200]]

    # Build a dummy request body
    dummy_request_body = json.dumps({"image": encoded_image, "boxes": example_boxes})

    print("\n--- Testing InferenceHandler locally ---")
    inputs = handler.preprocess(dummy_request_body)
    outputs = handler.inference(inputs)
    processed_response = handler.postprocess(outputs, inputs)
    print("Local test successful. Response structure (truncated):", processed_response[:200], "...")
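On the client side, the JSON produced by `postprocess` can be decoded back into mask images. A minimal sketch; the payload is built inline here (mirroring the handler's output shape) so the snippet is self-contained:

```python
import base64
import io
import json

import numpy as np
from PIL import Image

# Build a stand-in response payload with the same structure as the handler's output
mask = (np.random.rand(256, 256) > 0.5).astype(np.uint8) * 255
buf = io.BytesIO()
Image.fromarray(mask).save(buf, format="PNG")
payload = json.dumps(
    [{"mask": base64.b64encode(buf.getvalue()).decode("utf-8"), "score": 0.97}]
)

# Decode: the reverse of InferenceHandler.postprocess
results = json.loads(payload)
decoded = np.array(Image.open(io.BytesIO(base64.b64decode(results[0]["mask"]))))
print(decoded.shape, decoded.dtype)
```

Because the masks are stored as lossless grayscale PNGs, the round trip is exact: the decoded array matches the original 0/255 mask pixel for pixel.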
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:8604a4ffe7d3c99df24b1224a6c9593a4f1b82cd8b455eb64b6feb55103f498d
+ oid sha256:2395d9f09a56238ae54dcb447b7c5230aedf8c7c46fff0644c28666901c6bc11
  size 374979376
requirements.txt ADDED
@@ -0,0 +1,8 @@
torch==2.8.0+cpu
transformers==4.57.1
huggingface_hub==0.36.0
nibabel==5.3.2
numpy==2.0.2
Pillow==10.3.0
tqdm==4.67.1