Spaces:
Runtime error
Runtime error
| """ | |
| aerial-segmentation | |
| Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees and buildings. | |
| """ | |
| import os | |
| import gradio as gr | |
| import cv2 | |
| os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'") | |
| from transformers import DetrFeatureExtractor, DetrForSegmentation | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import detectron2 | |
| import json | |
| # import some common detectron2 utilities | |
| import itertools | |
| import seaborn as sns | |
| from detectron2 import model_zoo | |
| from detectron2.engine import DefaultPredictor | |
| from detectron2.config import get_cfg | |
| from detectron2.utils.visualizer import Visualizer | |
| from detectron2.utils.visualizer import ColorMode | |
| from detectron2.data import MetadataCatalog, DatasetCatalog | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| from detectron2.utils.visualizer import ColorMode | |
| from detectron2.structures import Instances | |
def list_pth_files_in_directory(directory, version="v1"):
    """Return the ``.pth`` weight files in *directory* that match *version*, sorted by name.

    The version tag is the text after the first ``"v"`` in *version*
    (``"v1"`` -> ``"1"``). A file matches when that tag occurs anywhere in
    its name and the name ends with ``.pth``.

    Args:
        directory: Folder to scan (non-recursive).
        version: Version label containing a ``"v"``. Values like
            ``"treev1"`` are presumably used by the UI; confirm against
            the dropdown choices.

    Returns:
        Sorted list of matching file names, without directory prefix.
        Sorting makes the dropdown order deterministic, since
        ``os.listdir`` order is arbitrary.

    Raises:
        IndexError: if *version* contains no ``"v"``.
    """
    version_tag = version.split("v")[1]
    return sorted(
        name
        for name in os.listdir(directory)
        if version_tag in name and name.endswith(".pth")
    )
def get_version_cfg_yml(path):
    """Return the YAML config path for the version encoded in *path*.

    *path* has the form ``"<directory>/<version>"``. Only the first two
    ``/``-separated components are used. The function looks in
    ``<directory>`` for files ending in ``.yml`` or ``.yaml`` whose name
    contains ``<version>``.

    Args:
        path: ``"<directory>/<version>"`` string.

    Returns:
        ``"<directory>/<config file name>"``. If several files match, the
        alphabetically first is used so the choice is deterministic.

    Raises:
        FileNotFoundError: if no matching config file exists in the directory.
        IndexError: if *path* contains no ``"/"``.
    """
    parts = path.split("/")
    directory, version = parts[0], parts[1]
    candidates = sorted(
        name
        for name in os.listdir(directory)
        if name.endswith((".yml", ".yaml")) and version in name
    )
    if not candidates:
        raise FileNotFoundError(
            f"No .yml/.yaml config containing {version!r} found in {directory!r}"
        )
    return directory + "/" + candidates[0]
def update_row_visibility(mode):
    """Return ``(tree_row, building_row)`` Gradio rows with visibility set for *mode*.

    The tree row is visible for "Trees" or "Both".
    The building row is visible for "Buildings" or "Both".
    """
    show_trees = mode in ["Trees", "Both"]
    show_buildings = mode in ["Buildings", "Both"]
    return gr.Row(visible=show_trees), gr.Row(visible=show_buildings)
def update_path_options(version):
    """Return a dropdown listing the ``.pth`` weight files for the chosen *version*.

    Versions containing "tree" are looked up in ``tree_model_weights``.
    All other versions are looked up in ``building_model_weight``.
    The label uses the text before the first "v" in *version*.
    """
    weights_dir = "tree_model_weights" if "tree" in version else "building_model_weight"
    model_kind = version.split('v')[0]
    return gr.Dropdown(
        choices=list_pth_files_in_directory(weights_dir, version),
        label=f"Select a {model_kind} model file",
        visible=True,
        interactive=True,
    )
def tree_model(tree_version_dropdown, tree_pth_dropdown, tree_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the tree segmentation model.

    Args:
        tree_version_dropdown: Version label used to find the YAML config in
            ``tree_model_weights``.
        tree_pth_dropdown: Weight file name inside ``tree_model_weights``.
        tree_threshold: Value for ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value for ``MODEL.DEVICE`` (default ``"cpu"``).
    """
    cfg = get_cfg()
    cfg.merge_from_file(get_version_cfg_yml(f"tree_model_weights/{tree_version_dropdown}"))
    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}"
    # TODO: the class count is hard-coded; derive it from the trained model/dataset.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = tree_threshold
    return DefaultPredictor(cfg)
def building_model(building_version_dropdown, building_pth_dropdown, building_threshold, device="cpu"):
    """Build a detectron2 ``DefaultPredictor`` for the building segmentation model.

    Args:
        building_version_dropdown: Version label used to find the YAML config
            in ``building_model_weight``.
        building_pth_dropdown: Weight file name inside ``building_model_weight``.
        building_threshold: Value for ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        device: Value for ``MODEL.DEVICE`` (default ``"cpu"``).
    """
    cfg = get_cfg()
    cfg.merge_from_file(get_version_cfg_yml(f"building_model_weight/{building_version_dropdown}"))
    cfg.MODEL.DEVICE = device
    cfg.MODEL.WEIGHTS = f"building_model_weight/{building_pth_dropdown}"
    # TODO: the class count is hard-coded; derive it from the trained model/dataset.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 8
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = building_threshold
    return DefaultPredictor(cfg)
def segment_building(im, building_predictor):
    """Run *building_predictor* on *im* and return its ``"instances"`` moved to CPU.

    The confidence threshold is not a parameter here; it is whatever the
    predictor was configured with.
    """
    prediction = building_predictor(im)
    return prediction["instances"].to("cpu")
def segment_tree(im, tree_predictor):
    """Run *tree_predictor* on *im* and return its ``"instances"`` moved to CPU.

    The confidence threshold is not a parameter here; it is whatever the
    predictor was configured with.
    """
    prediction = tree_predictor(im)
    return prediction["instances"].to("cpu")
def map_color_mode(color_mode):
    """Translate a UI colour-mode label into a detectron2 ``ColorMode``.

    Args:
        color_mode: "Black/white", "Random", "Segmentation", or ``None``.
            ``None`` is treated the same as "Segmentation".

    Returns:
        ``ColorMode.IMAGE_BW`` for "Black/white", ``ColorMode.IMAGE`` for
        "Random", and ``ColorMode.SEGMENTATION`` for "Segmentation" or ``None``.

    Raises:
        ValueError: for any other label. Unknown labels used to fall through
            and return ``None``, which was then handed to ``Visualizer``.
    """
    if color_mode is None:
        return ColorMode.SEGMENTATION
    modes = {
        "Black/white": ColorMode.IMAGE_BW,
        "Random": ColorMode.IMAGE,
        "Segmentation": ColorMode.SEGMENTATION,
    }
    try:
        return modes[color_mode]
    except KeyError:
        raise ValueError(f"Unknown color mode: {color_mode!r}") from None
def load_predictor(model, version, pth, threshold):
    """Create a predictor by calling the factory *model* with ``(version, pth, threshold)``."""
    factory_args = (version, pth, threshold)
    return model(*factory_args)
def load_instances(image, predictor, segment_function):
    """Return the result of ``segment_function(image, predictor)``."""
    instances = segment_function(image, predictor)
    return instances
def combine_instances(tree_instances, building_instances):
    """Concatenate tree and building detections into one ``Instances`` (trees first).

    Concatenation is delegated to detectron2's ``Instances.cat``. Per its
    docs, both inputs must share the same image size and fields.
    """
    parts = [tree_instances, building_instances]
    merged = Instances.cat(parts)
    return merged
def get_metadata(dataset_name, coco_file):
    """Attach class names from a COCO annotation file to a detectron2 metadata entry.

    Args:
        dataset_name: Key passed to ``MetadataCatalog.get``.
        coco_file: Path to a COCO-format JSON file with a ``"categories"``
            list whose items have a ``"name"`` key.

    Returns:
        The metadata object, with ``thing_classes`` set to the category names
        in file order.

    Raises:
        KeyError: if the JSON lacks ``"categories"`` or a category lacks ``"name"``.
    """
    metadata = MetadataCatalog.get(dataset_name)
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(coco_file, "r", encoding="utf-8") as f:
        coco = json.load(f)
    metadata.thing_classes = [category["name"] for category in coco["categories"]]
    return metadata
def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth):
    """Run the selected segmentation model(s) on *im* and return an annotated image.

    Args:
        im: Input image, a PIL image or array-like from the UI. Presumably
            RGB/RGBA (PIL and Gradio's default numpy output are); confirm
            against the ``gr.Image`` configuration.
        mode: "Trees", "Buildings" or "Both".
        tree_threshold, building_threshold: Score thresholds for each model.
        color_mode: Label understood by ``map_color_mode``.
        tree_version, tree_pth: Version label and weight file for the tree model.
        building_version, building_pth: Version label and weight file for the
            building model.

    Returns:
        A PIL image with the predicted instances drawn, at half the input size
        (``scale=0.5``).

    Raises:
        ValueError: if *im* is ``None`` or *mode* is not a supported value.
    """
    if im is None:
        raise ValueError("No input image provided")
    if mode not in {"Trees", "Buildings", "Both"}:
        raise ValueError(f"Unknown mode: {mode!r}")

    # Normalise to 3-channel RGB so RGBA (e.g. PNG) and greyscale uploads work.
    if not isinstance(im, Image.Image):
        im = Image.fromarray(np.asarray(im))
    rgb = np.asarray(im.convert("RGB"))
    # detectron2's DefaultPredictor expects BGR input, while Visualizer expects RGB.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    instance_mode = map_color_mode(color_mode)

    instances = None
    if mode in {"Trees", "Both"}:
        tree_predictor = load_predictor(tree_model, tree_version, tree_pth, tree_threshold)
        instances = load_instances(bgr, tree_predictor, segment_tree)
    if mode in {"Buildings", "Both"}:
        building_predictor = load_predictor(building_model, building_version, building_pth, building_threshold)
        building_instances = load_instances(bgr, building_predictor, segment_building)
        instances = building_instances if instances is None else combine_instances(instances, building_instances)

    # NOTE(review): class names always come from the building annotations, so
    # tree detections (own class ids) are labelled with building class names.
    # Confirm whether a separate tree metadata/annotation file should be used.
    metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json")
    visualizer = Visualizer(rgb, metadata=metadata, scale=0.5, instance_mode=instance_mode)
    output_image = visualizer.draw_instance_predictions(instances)
    # get_image() returns RGB, which is what Image.fromarray expects.
    return Image.fromarray(output_image.get_image())