Spaces:
Build error
Build error
| # Import required libraries | |
| import os | |
| import io | |
| import torch | |
| import pydicom | |
| import numpy as np | |
| import streamlit as st | |
| # Import utility and custom functions | |
| from PIL import Image | |
| from Util.DICOM import DICOM_Utils | |
| from Util.Custom_Model import Build_Custom_Model, reshape_transform | |
| # Import additional MONAI and PyTorch Grad-CAM utilities | |
| from monai.utils import set_determinism | |
| from monai.networks.nets import SEResNet50 | |
| from monai.transforms import ( | |
| Activations, | |
| EnsureChannelFirst, | |
| AsDiscrete, | |
| Compose, | |
| RandFlip, | |
| RandRotate, | |
| RandZoom, | |
| ScaleIntensity, | |
| AsChannelFirst, | |
| AddChannel, | |
| RandSpatialCrop, | |
| ScaleIntensityRangePercentiles, | |
| Resize, | |
| ) | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
# ---------------------------------------------------------------------------
# Reproducibility / model configuration
# ---------------------------------------------------------------------------
SEED = 0  # (Int) seed handed to MONAI set_determinism and torch.manual_seed
NUM_CLASSES = 1  # (Int) number of model outputs (read back as a single sigmoid score)
CUSTOM_MODEL_FLAG = True  # (Boolean) build via Build_Custom_Model instead of MONAI SEResNet50
LIST_MODEL_MODULES = False  # (Boolean) print every sub-module name after the models load
SPATIAL_SIZE = [224, 224]  # (List[int]) spatial size images are resized to before inference

# ---------------------------------------------------------------------------
# Checkpoint locations and architectures (CT and MRI share file/arch names)
# ---------------------------------------------------------------------------
CT_MODEL_DIRECTORY = "models/CLOTS/CT"  # (String)
MRI_MODEL_DIRECTORY = "models/CLOTS/MRI"  # (String)
CT_MODEL_FILE_NAME = MRI_MODEL_FILE_NAME = "best_metric_model.pth"  # (String)
CT_MODEL_NAME = MRI_MODEL_NAME = "swin_base_patch4_window7_224"  # (String)

# ---------------------------------------------------------------------------
# Inference / visualisation
# ---------------------------------------------------------------------------
CT_INFERENCE_THRESHOLD = MRI_INFERENCE_THRESHOLD = 0.5  # (Float) positive if prob >= threshold
CAM_CLASS_ID = 0  # (Int) output index explained by Grad-CAM

# ---------------------------------------------------------------------------
# Display windowing: per-modality defaults and the sidebar input limits
# ---------------------------------------------------------------------------
DEFAULT_CT_WINDOW_CENTER, DEFAULT_CT_WINDOW_WIDTH = 40, 100  # (Int, Int)
DEFAULT_MRI_WINDOW_CENTER, DEFAULT_MRI_WINDOW_WIDTH = 400, 1000  # (Int, Int)
WINDOW_CENTER_MIN, WINDOW_CENTER_MAX = -600, 1000  # (Int, Int)
WINDOW_WIDTH_MIN, WINDOW_WIDTH_MAX = 1, 3000  # (Int, Int)
# Preprocessing pipelines. The upload handlers feed them float32 arrays of
# shape (H, W, 1) (a channel axis is appended to the 2-D pixel data).
#
# `AsChannelFirst` was deprecated in MONAI and later removed; the supported
# replacement is `EnsureChannelFirst(channel_dim=-1)`, which moves the last
# axis to the front even when the input carries no metadata.

# Evaluation Transforms: channel-first, percentile-based intensity rescaling
# (see MONAI ScaleIntensityRangePercentiles, relative=True), then resize to the
# model's input size.
eval_transforms = Compose(
    [
        EnsureChannelFirst(channel_dim=-1),
        ScaleIntensityRangePercentiles(lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)
# CAM Transforms: same geometry as the model input but raw intensities, so the
# image can be windowed for display and used as the CAM background.
cam_transforms = Compose(
    [
        EnsureChannelFirst(channel_dim=-1),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)
# Original Transforms: channel-first only (full resolution, raw intensities),
# used for the downloadable image.
original_transforms = Compose(
    [
        EnsureChannelFirst(channel_dim=-1),
    ]
)
def image_to_bytes(image):
    """Encode *image* as PNG and return the raw bytes.

    *image* is expected to expose a PIL-style ``save(fp, format=...)`` method.
    The returned bytes are what the download buttons in this script pass as
    ``data``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        encoded = png_buffer.getvalue()
    return encoded
def bytes_to_megabytes(file_size_bytes):
    """Format a byte count as a megabyte string such as ``"1.5 MB"``.

    Uses binary megabytes (1 MB = 1024 ** 2 bytes) and rounds to two decimal
    places for readability.
    """
    megabytes = file_size_bytes / 1024 ** 2
    return f"{round(megabytes, 2)} MB"
def meta_tensor_to_numpy(meta_tensor):
    """
    Convert a PyTorch (Meta)Tensor to a float32 NumPy array.

    The tensor is moved to the CPU, cast to float32 and detached from the
    autograd graph before ``.numpy()`` is called. When no move or cast is
    needed, the result may share memory with the input tensor.
    """
    return (
        meta_tensor
        .cpu()
        .to(dtype=torch.float32)
        .detach()
        .numpy()
    )
# Make MONAI and PyTorch randomness reproducible.
set_determinism(seed=SEED)
torch.manual_seed(SEED)
# Parameters: run on the GPU when one is visible, otherwise fall back to CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
def load_model(root_dir, model_name, model_file_name):
    """Build a classifier, load its weights and return it in eval mode.

    When ``CUSTOM_MODEL_FLAG`` is true the network comes from
    ``Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False)``.
    Otherwise a 2-D, single-channel MONAI ``SEResNet50`` is built and
    *model_name* is unused. Weights are read from
    ``os.path.join(root_dir, model_file_name)`` onto ``device``.
    """
    if CUSTOM_MODEL_FLAG:
        network = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False)
    else:
        network = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES)
    network = network.to(device)

    checkpoint_path = os.path.join(root_dir, model_file_name)
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints (weights_only=True is an option on newer PyTorch releases).
    state_dict = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(state_dict)

    network.eval()
    return network
# Load both classifiers up front (CT first, then MRI).
ct_model = load_model(CT_MODEL_DIRECTORY, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
mri_model = load_model(MRI_MODEL_DIRECTORY, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)

# Optional debugging aid: print every sub-module name of each model (CT first).
if LIST_MODEL_MODULES:
    for loaded_model in (ct_model, mri_model):
        for module_name, _module in loaded_model.named_modules():
            print(module_name)
# Page title plus sidebar number inputs controlling display windowing.
# Each input is bounded by the WINDOW_* limits and starts at the modality default.
st.title("Analyze")
st.sidebar.header("Windowing Parameters for DICOM")
MRI_WINDOW_CENTER = st.sidebar.number_input(
    "MRI Window Center",
    min_value=WINDOW_CENTER_MIN,
    max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_MRI_WINDOW_CENTER,
    step=1,
)
MRI_WINDOW_WIDTH = st.sidebar.number_input(
    "MRI Window Width",
    min_value=WINDOW_WIDTH_MIN,
    max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_MRI_WINDOW_WIDTH,
    step=1,
)
CT_WINDOW_CENTER = st.sidebar.number_input(
    "CT Window Center",
    min_value=WINDOW_CENTER_MIN,
    max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_CT_WINDOW_CENTER,
    step=1,
)
CT_WINDOW_WIDTH = st.sidebar.number_input(
    "CT Window Width",
    min_value=WINDOW_WIDTH_MIN,
    max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_CT_WINDOW_WIDTH,
    step=1,
)
def _analyze_uploaded_dicom(uploaded_file, model, modality, threshold, window_center, window_width):
    """Classify one uploaded DICOM slice and render the results with Streamlit.

    Renders the file details, the classification and confidence, a download
    button for the full-resolution windowed image, the resized windowed image
    and a Grad-CAM overlay.

    Parameters
    ----------
    uploaded_file : Streamlit uploaded file for a single ``.dcm`` file.
    model : classifier returned by ``load_model``; Grad-CAM targets
        ``model.model.norm``, so the custom model is assumed.
    modality : ``"MRI"`` or ``"CT"``; used in headings, captions and file names.
    threshold : sigmoid probability at or above which the slice is flagged.
    window_center, window_width : display windowing passed to
        ``DICOM_Utils.apply_windowing``.
    """
    # Read DICOM file into NumPy array
    try:
        dicom_data = pydicom.dcmread(uploaded_file)
    except pydicom.errors.InvalidDicomError as err:
        st.error(f"Could not read the uploaded {modality} file as DICOM: {err}")
        return
    pixel_data = dicom_data.pixel_array
    # The transforms below expect a single 2-D slice; multi-frame or colour
    # pixel data would otherwise fail with an obscure shape error.
    if pixel_data.ndim != 2:
        st.error(f"Expected a single 2-D {modality} slice, got pixel data with shape {pixel_data.shape}.")
        return
    # NOTE(review): raw stored pixel values are used; RescaleSlope/RescaleIntercept
    # (e.g. CT Hounsfield units) are not applied. Confirm this matches the
    # intended window defaults.
    # Convert to float32 and add a trailing channel axis: (H, W) -> (H, W, 1).
    dicom_array = pixel_data.astype(np.float32)[:, :, np.newaxis]

    # To check file details
    file_details = {
        "File_Name": uploaded_file.name,
        "File_Type": uploaded_file.type,
        "File_Size": bytes_to_megabytes(uploaded_file.size),
        "File_Dimension": str((dicom_array.shape[0], dicom_array.shape[1])),
    }
    st.write(file_details)

    # Predict: add a batch axis and run the model without tracking gradients.
    image_tensor = eval_transforms(dicom_array).clone().detach().unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image_tensor).sigmoid().to("cpu").numpy()
    prob = outputs[0][0]
    clots_classification = bool(prob >= threshold)
    st.header(f"{modality} Classification")
    st.subheader(f"Ischaemic Stroke : {clots_classification}")
    st.subheader(f"Confidence : {prob * 100:.1f}%")

    # Full-resolution windowed image for download. It is only needed on the
    # CPU, so there is no round-trip through `device`.
    download_image_numpy = meta_tensor_to_numpy(original_transforms(dicom_array).squeeze())
    windowed_download_image = DICOM_Utils.apply_windowing(download_image_numpy, window_center, window_width)
    st.download_button(
        label=f"Download {modality} Image",
        data=image_to_bytes(Image.fromarray(windowed_download_image)),
        file_name=f"downloaded_{modality.lower()}_image.png",
        mime="image/png",
    )

    # Resized windowed image: shown directly and reused as the CAM background.
    display_image_numpy = meta_tensor_to_numpy(cam_transforms(dicom_array).squeeze())
    windowed_image = DICOM_Utils.apply_windowing(display_image_numpy, window_center, window_width)
    st.image(Image.fromarray(windowed_image), caption=f"Original {modality} Visualization", use_column_width=True)

    # show_cam_on_image needs a 3-channel float image in [0, 1].
    # NOTE(review): dividing by 255 assumes apply_windowing returns 0-255 values; confirm.
    rgb_image = np.tile(np.expand_dims(windowed_image, axis=2), [1, 1, 3]).astype(np.float32) / 255

    # Build the CAM (Class Activation Map). `use_cuda` is not passed: it was
    # removed from pytorch-grad-cam's constructor (passing it raises TypeError
    # on current releases), and both the model and `image_tensor` already live
    # on `device`.
    cam = GradCAM(model=model, target_layers=[model.model.norm], reshape_transform=reshape_transform)
    grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
    st.image(Image.fromarray(visualization), caption=f"CAM {modality} Visualization", use_column_width=True)


uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
if uploaded_mri_file is not None:
    _analyze_uploaded_dicom(
        uploaded_mri_file, mri_model, "MRI", MRI_INFERENCE_THRESHOLD, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH
    )

uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
if uploaded_ct_file is not None:
    _analyze_uploaded_dicom(
        uploaded_ct_file, ct_model, "CT", CT_INFERENCE_THRESHOLD, CT_WINDOW_CENTER, CT_WINDOW_WIDTH
    )