| from typing import Dict, List, Optional, Tuple, Type, Any |
| from pathlib import Path |
| import uuid |
| import tempfile |
|
|
| import numpy as np |
| import torch |
| import torchvision |
| import torchxrayvision as xrv |
| import matplotlib.pyplot as plt |
| import skimage.io |
| import skimage.measure |
| import skimage.transform |
| import traceback |
|
|
| from pydantic import BaseModel, Field |
| from langchain_core.callbacks import ( |
| AsyncCallbackManagerForToolRun, |
| CallbackManagerForToolRun, |
| ) |
| from langchain_core.tools import BaseTool |
|
|
|
|
class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-ray Segmentation Tool."""

    # Location of the image on disk; the field is required.
    image_path: str = Field(
        default=...,
        description="Path to the chest X-ray image file to be segmented",
    )

    # Optional subset of structures to segment; the default of None selects
    # every structure the tool knows about.
    organs: Optional[List[str]] = Field(
        default=None,
        description=(
            "List of organs to segment. If None, all available organs will be segmented. "
            "Available organs: Left/Right Clavicle, Left/Right Scapula, Left/Right Lung, "
            "Left/Right Hilus Pulmonis, Heart, Aorta, Facies Diaphragmatica, "
            "Mediastinum, Weasand, Spine"
        ),
    )
|
|
|
|
class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ."""

    # --- Basic metrics -----------------------------------------------------
    area_pixels: int = Field(default=..., description="Area in pixels")
    area_cm2: float = Field(default=..., description="Approximate area in cm²")
    centroid: Tuple[float, float] = Field(
        default=...,
        description="(y, x) coordinates of centroid",
    )
    bbox: Tuple[int, int, int, int] = Field(
        default=...,
        description="Bounding box coordinates (min_y, min_x, max_y, max_x)",
    )

    # --- Size metrics ------------------------------------------------------
    width: int = Field(default=..., description="Width of the organ in pixels")
    height: int = Field(default=..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(default=..., description="Height/width ratio")

    # --- Position metrics --------------------------------------------------
    relative_position: Dict[str, float] = Field(
        default=...,
        description="Position relative to image boundaries (0-1 scale)",
    )

    # --- Intensity / confidence metrics ------------------------------------
    mean_intensity: float = Field(
        default=...,
        description="Mean pixel intensity in the organ region",
    )
    std_intensity: float = Field(
        default=...,
        description="Standard deviation of pixel intensity",
    )
    confidence_score: float = Field(
        default=...,
        description="Model confidence score for this organ",
    )
|
|
|
|
class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images."""

    # Pipeline implemented by this class:
    #   1. read the image and reduce it to a single channel,
    #   2. normalise / center-crop / resize it to 512x512 for the model,
    #   3. run the torchxrayvision ``chestx_det`` PSPNet and threshold at 0.5,
    #   4. save a colour overlay of the requested organs to ``temp_dir``,
    #   5. compute per-organ metrics in the coordinate frame of the original image.

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images to specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns segmentation visualization and comprehensive metrics. "
        "Let the user know the area is not accurate unless input has been DICOM."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput

    model: Any = None
    # Holds a ``torch.device`` once ``__init__`` has run, hence ``Any``.
    device: Any = "cuda"
    transform: Any = None
    # Fixed physical pixel size used for the cm² estimate; it is not read from
    # the image, which is why the reported area is only approximate.
    pixel_spacing_mm: float = 0.2
    temp_dir: Path = Path("temp")
    # Organ name -> output channel index of the model; filled in ``__init__``.
    organ_map: Optional[Dict[str, int]] = None

    def __init__(self, device: Optional[str] = "cuda", temp_dir: Optional[Path] = Path("temp")):
        """Initialize the segmentation tool with model and temporary directory.

        Args:
            device: Torch device string. A falsy value falls back to ``"cuda"``.
            temp_dir: Directory where overlay images are written. ``None``
                falls back to ``Path("temp")``. Missing parent directories
                are created.
        """
        super().__init__()

        # Always store a torch.device so downstream code sees one type.
        self.device = torch.device(device if device else "cuda")
        self.model = xrv.baseline_models.chestx_det.PSPNet().to(self.device)
        self.model.eval()

        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )

        self.temp_dir = Path(temp_dir) if temp_dir is not None else Path("temp")
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Channel order of the model output.
        self.organ_map = {
            "Left Clavicle": 0,
            "Right Clavicle": 1,
            "Left Scapula": 2,
            "Right Scapula": 3,
            "Left Lung": 4,
            "Right Lung": 5,
            "Left Hilus Pulmonis": 6,
            "Right Hilus Pulmonis": 7,
            "Heart": 8,
            "Aorta": 9,
            "Facies Diaphragmatica": 10,
            "Mediastinum": 11,
            "Weasand": 12,
            "Spine": 13,
        }

    def _align_mask_to_original(
        self, mask: np.ndarray, original_shape: Tuple[int, int]
    ) -> np.ndarray:
        """Map a mask from model space back onto the full original image.

        Assumes the preprocessing center-cropped the image to a square of side
        ``min(height, width)`` and then resized that square to the mask's size.

        Args:
            mask: 2-D mask in model (cropped and resized) space.
            original_shape: ``(height, width)`` of the original image.

        Returns:
            Array of shape ``original_shape`` that is zero outside the cropped
            square and holds the nearest-neighbour upscaled mask inside it.
        """
        orig_h, orig_w = original_shape
        crop_size = min(orig_h, orig_w)

        # Same offset arithmetic as torchxrayvision's XRayCenterCrop
        # (dim // 2 - crop_size // 2). This differs from (dim - crop_size) // 2
        # by one pixel when ``dim`` is even and ``crop_size`` is odd.
        crop_top = orig_h // 2 - crop_size // 2
        crop_left = orig_w // 2 - crop_size // 2

        # order=0 (nearest neighbour) keeps the mask binary after resizing.
        resized_mask = skimage.transform.resize(
            mask, (crop_size, crop_size), order=0, preserve_range=True, anti_aliasing=False
        )
        full_mask = np.zeros(original_shape)
        full_mask[crop_top : crop_top + crop_size, crop_left : crop_left + crop_size] = resized_mask
        return full_mask

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute comprehensive metrics for a single organ mask.

        Args:
            mask: Binary mask, either already in original-image space or in
                model space (it is then aligned to ``original_img``).
            original_img: 2-D original image used for intensity statistics.
            confidence: Confidence value stored verbatim in the result.

        Returns:
            The populated ``OrganMetrics``, or ``None`` when the mask contains
            no foreground region.
        """
        if mask.shape != original_img.shape:
            mask = self._align_mask_to_original(mask, original_img.shape)

        regions = skimage.measure.regionprops(mask.astype(int))
        if not regions:
            return None

        # The mask is binary, so all foreground pixels share label 1 and are
        # reported as a single region even when they are not connected.
        region = regions[0]
        min_y, min_x, max_y, max_x = region.bbox
        box_height = max_y - min_y
        box_width = max_x - min_x

        pixel_count = mask.sum()
        # mm -> cm per pixel side, squared to get cm² per pixel.
        area_cm2 = pixel_count * (self.pixel_spacing_mm / 10) ** 2

        img_height, img_width = mask.shape
        cy, cx = region.centroid
        rel_y = cy / img_height
        rel_x = cx / img_width
        relative_pos = {
            "top": rel_y,
            "left": rel_x,
            "center_dist": np.sqrt((rel_y - 0.5) ** 2 + (rel_x - 0.5) ** 2),
        }

        organ_pixels = original_img[mask > 0]
        has_pixels = len(organ_pixels) > 0
        mean_intensity = organ_pixels.mean() if has_pixels else 0
        std_intensity = organ_pixels.std() if has_pixels else 0

        return OrganMetrics(
            area_pixels=int(pixel_count),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=(int(min_y), int(min_x), int(max_y), int(max_x)),
            width=int(box_width),
            height=int(box_height),
            # max(1, ...) guards against division by zero for degenerate boxes.
            aspect_ratio=float(box_height / max(1, box_width)),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(
        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
    ) -> str:
        """Save visualization of original image with segmentation masks overlaid.

        Args:
            original_img: 2-D original image shown as the grayscale background.
            pred_masks: Thresholded model output indexed as ``[0, organ_idx]``.
            organ_indices: Model channel indices of the organs to draw.

        Returns:
            Path (as a string) of the PNG written into ``self.temp_dir``.
        """
        # Build the reverse lookup once instead of scanning two lists per organ.
        index_to_name = {index: name for name, index in self.organ_map.items()}
        extent = [0, original_img.shape[1], original_img.shape[0], 0]

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(original_img, cmap="gray", extent=extent)

            # One distinct colour per requested organ.
            colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))

            for organ_idx, color in zip(organ_indices, colors):
                mask = pred_masks[0, organ_idx].cpu().numpy()
                if mask.sum() <= 0:
                    continue  # nothing predicted for this organ

                if mask.shape != original_img.shape:
                    mask = self._align_mask_to_original(mask, original_img.shape)

                # RGBA overlay: transparent everywhere except the organ pixels.
                colored_mask = np.zeros((*original_img.shape, 4))
                colored_mask[mask > 0] = (*color[:3], 0.3)
                ax.imshow(colored_mask, extent=extent)

                # Empty line that only exists to create a legend entry.
                ax.plot([], [], color=color, label=index_to_name[organ_idx], linewidth=3)

            ax.set_title("Segmentation Overlay")
            # Skip the legend when no organ was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            ax.axis("off")

            save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
            fig.savefig(save_path, bbox_inches="tight", dpi=300)
        finally:
            # Close the figure even if drawing or saving failed, so pyplot
            # does not keep it alive.
            plt.close(fig)

        return str(save_path)

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for specified organs.

        Args:
            image_path: Path of the chest X-ray image to analyse.
            organs: Organ names to segment (matched case-insensitively after
                stripping whitespace). ``None`` or an empty list selects all
                organs in ``self.organ_map``.
            run_manager: Unused; accepted for the tool interface.

        Returns:
            ``(output, metadata)``. On success ``output`` holds the overlay
            path and a metrics dict per detected organ. On any failure
            ``output`` is ``{"error": ...}`` and ``metadata`` carries the
            traceback; exceptions are not propagated to the caller.
        """
        try:
            if organs:
                # Case-insensitive match, mapped back to the canonical names
                # so result keys and metadata stay consistent.
                canonical = {name.lower(): name for name in self.organ_map}
                cleaned = [o.strip() for o in organs]
                invalid_organs = [o for o in cleaned if o.lower() not in canonical]
                if invalid_organs:
                    raise ValueError(f"Invalid organs specified: {invalid_organs}")
                organs = [canonical[o.lower()] for o in cleaned]
                organ_indices = [self.organ_map[o] for o in organs]
            else:
                organ_indices = list(self.organ_map.values())
                organs = list(self.organ_map.keys())

            original_img = skimage.io.imread(image_path)
            if len(original_img.shape) > 2:
                # Multi-channel input: keep only the first channel.
                original_img = original_img[:, :, 0]

            # 255 is passed as the maximum pixel value, i.e. 8-bit input is assumed.
            img = xrv.datasets.normalize(original_img, 255)
            img = img[None, ...]
            img = self.transform(img)
            img = torch.from_numpy(img)
            img = img.to(self.device)

            with torch.no_grad():
                pred = self.model(img)
                pred_probs = torch.sigmoid(pred)
                pred_masks = (pred_probs > 0.5).float()

            # Copy the masks to host memory once; they are read per organ by
            # both the visualization and the metrics loop below.
            pred_masks = pred_masks.cpu()

            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)

            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                mask = pred_masks[0, idx].numpy()
                if mask.sum() <= 0:
                    continue
                # NOTE(review): confidence is the mean probability over the
                # whole channel, not just the predicted region, so small
                # organs score low regardless of certainty. Kept as-is to
                # avoid changing reported values.
                confidence = float(pred_probs[0, idx].mean().cpu())
                metrics = self._compute_organ_metrics(mask, original_img, confidence)
                if metrics:
                    results[organ_name] = metrics

            output = {
                "segmentation_image_path": viz_path,
                "metrics": {
                    # pydantic v2 exposes model_dump(); v1 only has dict().
                    organ: (
                        metrics.model_dump() if hasattr(metrics, "model_dump") else metrics.dict()
                    )
                    for organ, metrics in results.items()
                },
            }

            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }

            return output, metadata

        except Exception as e:
            # Tool boundary: report the failure as data instead of raising.
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run.

        Delegates directly to the synchronous ``_run``; the call is not
        offloaded, so it blocks the event loop while inference runs.
        """
        return self._run(image_path, organs)
|
|