Update with README.md, requirements.txt, and inference.py for Inference Endpoint

- README.md +121 -0
- inference.py +82 -0
- model.safetensors +1 -1
- requirements.txt +8 -0
README.md
ADDED
@@ -0,0 +1,121 @@
---
pretty_name: SAM Brain Tumor Segmentation
tags:
- sam-brain-tumor-segmentation
pipeline_tag: image-segmentation
library_name: transformers
license: mit
---

# SAM Brain Tumor Segmentation Model

This model is a fine-tuned [Segment Anything Model (SAM)](https://huggingface.co/facebook/sam-vit-base) for brain tumor segmentation from medical imaging data. It was trained on a simulated dataset of 2D slices derived from 3D NIfTI (.nii.gz) images and their corresponding segmentation masks.

## Model Description

The original SAM is a powerful general-purpose image segmentation model. This fine-tuned version specializes in identifying brain tumors, leveraging SAM's prompt-based segmentation capabilities. The model is prompted with bounding boxes around the tumor regions (derived from ground-truth masks during training) to generate precise segmentation masks.

### Training Details

- **Base Model**: `facebook/sam-vit-base`
- **Dataset**: Simulated 2D axial slices from 3D NIfTI images, normalized to the 0-1 range.
- **Image Preprocessing**: Grayscale images were duplicated across 3 channels to match SAM's expected input. Bounding box prompts were generated from ground-truth masks.
- **Loss Functions**: Binary Cross-Entropy (BCE) loss and Dice loss.
- **Optimizer**: AdamW with a learning rate of 1e-5.
- **Epochs**: 5
- **Average Dice Score on Validation Set**: 0.9756 (on simulated data)

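The BCE + Dice training objective listed above can be sketched as follows. The relative weighting of the two terms is not documented, so the equal split (and the function names) here are assumptions:

```python
import torch
import torch.nn.functional as F

def dice_loss(pred_logits, target, eps=1e-6):
    """Soft Dice loss computed on sigmoid probabilities over the spatial dims."""
    probs = torch.sigmoid(pred_logits)
    intersection = (probs * target).sum(dim=(-2, -1))
    union = probs.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    return 1.0 - (2.0 * intersection + eps) / (union + eps)

def segmentation_loss(pred_logits, target, bce_weight=0.5):
    """BCE + Dice, as in the training details; the 0.5/0.5 weighting is assumed."""
    bce = F.binary_cross_entropy_with_logits(pred_logits, target)
    dice = dice_loss(pred_logits, target).mean()
    return bce_weight * bce + (1.0 - bce_weight) * dice
```

A loss near zero indicates near-perfect overlap between the predicted and ground-truth masks; the Dice term is the complement of the Dice score reported above.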
## Usage

To use this model for inference, load it with the `transformers` library and provide an image along with a bounding box prompt for the region of interest. The model then predicts a segmentation mask.

```python
from transformers import SamModel, SamProcessor
from PIL import Image
import torch
import numpy as np

# Load the fine-tuned model and processor
processor = SamProcessor.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")
model = SamModel.from_pretrained("Lorenzob/sam-brain-tumor-segmentation")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Example: create a dummy image (replace with your actual medical image).
# This should be a 2D grayscale image converted to 3 channels.
# For a real image, load it and ensure it is normalized to 0-1 and uint8 or float.
image_size = 256  # example size
dummy_image_data = np.random.rand(image_size, image_size) * 255
dummy_image = Image.fromarray(dummy_image_data.astype(np.uint8)).convert("RGB")

# Example: define a bounding box for the tumor region (x_min, y_min, x_max, y_max).
# The processor expects one list of boxes per image, hence the triple nesting.
# In a real scenario, this box would come from an expert or a detection model.
input_boxes = [[[100, 100, 200, 200]]]

# Preprocess the image and bounding box
inputs = processor(dummy_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Perform inference
with torch.no_grad():
    outputs = model(**inputs, multimask_output=False)

# Post-process the predicted mask back to the original image size
masks = processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

# `post_process_masks` returns a list with one boolean tensor per image.
# With multimask_output=False there is a single mask per box.
predicted_mask = masks[0].squeeze().numpy()  # shape (H, W)

print("Predicted mask shape:", predicted_mask.shape)
# You can visualize `predicted_mask` using matplotlib or other image libraries:
# import matplotlib.pyplot as plt
# plt.imshow(predicted_mask, cmap='gray')
# plt.title('Predicted Segmentation Mask')
# plt.show()
```

## Inference Endpoint Configuration (Optional)

If you wish to deploy this model as an Inference Endpoint on Hugging Face, here's a sample configuration you might use in your `README.md` (or directly in the UI):

```yaml
widget:
- src: "app.py"
  example_title: "Brain Tumor Segmentation Example"
  inputs:
  - filename: "image.png"
    image: https://huggingface.co/datasets/huggingface/sample-images/resolve/main/segmentation_image_input.png
    input_boxes: [[100, 100, 200, 200]]

--- # Optional section for specific endpoint settings

parameters:
  do_normalize: false     # assuming inputs are already normalized to 0-1
  do_rescale: false       # assuming inputs are already scaled correctly
  multimask_output: false # return the single best mask

# Example of specific hardware/software config for advanced users
# inference:
#   accelerator: cuda
#   container: pytorch_latest
#   hardware: gpu_small
#   task: image-segmentation
```

**Note**: The example image and `input_boxes` in the YAML configuration are placeholders. For a real medical imaging endpoint, provide a relevant example image and a bounding box corresponding to a tumor within that image.

## Limitations

- The model was fine-tuned on a simulated dataset. Its performance on real, diverse clinical data may vary and needs further rigorous validation.
- The model relies on a bounding box prompt; its accuracy depends heavily on the quality and precision of the provided box.
- The model currently handles 2D slices. Full 3D volume segmentation would require further development.

## Future Work

- Evaluate and fine-tune the model on large, real-world medical imaging datasets (e.g., BraTS, TCIA).
- Explore methods for automatic bounding box generation for tumor regions.
- Extend the model to handle 3D medical images directly.
- Implement quantitative metrics (e.g., IoU, Hausdorff distance) during evaluation with real data.
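For reference, deriving a prompt box from a ground-truth mask (as was done during training) can be sketched as below. The `margin` padding and the helper name are assumptions; the same idea could seed the automatic box generation mentioned under Future Work.

```python
import numpy as np

def box_from_mask(mask, margin=5):
    """Derive an (x_min, y_min, x_max, y_max) prompt box from a binary 2D mask.

    The margin adds slack around the tumor, clipped to the image bounds.
    Returns None when the mask contains no foreground pixels.
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None  # no tumor in this slice: no box prompt
    h, w = mask.shape
    x_min = max(int(xs.min()) - margin, 0)
    y_min = max(int(ys.min()) - margin, 0)
    x_max = min(int(xs.max()) + margin, w - 1)
    y_max = min(int(ys.max()) + margin, h - 1)
    return [x_min, y_min, x_max, y_max]
```

The resulting box can be used directly as one entry of `input_boxes` in the usage example above.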
inference.py
ADDED
@@ -0,0 +1,82 @@
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
import io
import base64
import json


class InferenceHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SamModel.from_pretrained("./sam_brain_tumor_model").to(self.device)
        self.processor = SamProcessor.from_pretrained("./sam_brain_tumor_model")

    def preprocess(self, request_body):
        # Expect request_body to be a JSON string with 'image' (base64-encoded PNG)
        # and 'boxes' (list of [x_min, y_min, x_max, y_max] coordinates).
        data = json.loads(request_body)

        # Decode the image from base64
        image_bytes = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Coerce box coordinates to floats; the processor expects one list of
        # boxes per image, so wrap them in an extra batch dimension.
        boxes = data.get('boxes') or None
        input_boxes = None
        if boxes is not None:
            input_boxes = [[[float(coord) for coord in box] for box in boxes]]

        # Prepare inputs; the decoded PNG is uint8 0-255, so let the processor
        # apply its default rescaling and normalization.
        inputs = self.processor(image, input_boxes=input_boxes, return_tensors="pt").to(self.device)
        return inputs

    def inference(self, inputs):
        with torch.no_grad():
            outputs = self.model(**inputs, multimask_output=False)
        return outputs

    def postprocess(self, outputs, inputs):
        # Rescale predicted masks to the original image size.
        # post_process_masks returns one boolean tensor per image,
        # shaped (num_boxes, num_masks, H, W).
        masks = self.processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores.cpu()

        # Convert each mask to a 0/255 PNG and base64-encode it for the JSON response
        results = []
        for image_masks, image_scores in zip(masks, scores):
            for mask, score in zip(image_masks, image_scores):
                mask_np = mask.squeeze().numpy().astype(np.uint8) * 255
                buffered = io.BytesIO()
                Image.fromarray(mask_np).save(buffered, format="PNG")
                encoded_mask = base64.b64encode(buffered.getvalue()).decode('utf-8')
                results.append({"mask": encoded_mask, "score": float(score.max())})
        return json.dumps(results)


# Example of how to use the handler locally (for testing)
if __name__ == '__main__':
    handler = InferenceHandler()

    # Create a dummy grayscale image
    dummy_image_size = (256, 256)
    dummy_image_np = np.random.randint(0, 256, dummy_image_size, dtype=np.uint8)
    image = Image.fromarray(dummy_image_np)

    # Encode the dummy image to base64
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    # Example bounding box
    example_boxes = [[50, 50, 200, 200]]

    # Build a dummy request body
    dummy_request_body = json.dumps({"image": encoded_image, "boxes": example_boxes})

    print("\n--- Testing InferenceHandler locally ---")
    inputs = handler.preprocess(dummy_request_body)
    outputs = handler.inference(inputs)
    processed_response = handler.postprocess(outputs, inputs)
    print("Local test successful. Response structure (truncated):", processed_response[:200], "...")
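On the client side, the JSON produced by `postprocess` can be decoded back into boolean masks. A minimal sketch, assuming the base64-PNG response format above (the function name is ours):

```python
import base64
import io
import json
import numpy as np
from PIL import Image

def decode_response(response_json):
    """Turn the handler's JSON response back into boolean numpy masks."""
    masks = []
    for item in json.loads(response_json):
        png_bytes = base64.b64decode(item["mask"])
        mask_img = Image.open(io.BytesIO(png_bytes))
        masks.append(np.array(mask_img) > 127)  # 0/255 PNG -> boolean
    return masks
```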
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:2395d9f09a56238ae54dcb447b7c5230aedf8c7c46fff0644c28666901c6bc11
 size 374979376
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch==2.8.0+cpu
transformers==4.57.1
huggingface_hub==0.36.0
nibabel==5.3.2
numpy==2.0.2
Pillow==10.3.0
tqdm==4.67.1