# Copyright (c) OpenMMLab. All rights reserved.
import copy

import mmcv
import numpy as np

from mmdet.core import INSTANCE_OFFSET, bbox2result
from mmdet.core.visualization import imshow_det_bboxes
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_."""

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 panoptic_fusion_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)

        panoptic_head_ = copy.deepcopy(panoptic_head)
        panoptic_head_.update(train_cfg=train_cfg)
        panoptic_head_.update(test_cfg=test_cfg)
        self.panoptic_head = build_head(panoptic_head_)

        panoptic_fusion_head_ = copy.deepcopy(panoptic_fusion_head)
        panoptic_fusion_head_.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = build_head(panoptic_fusion_head_)

        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # BaseDetector.show_result default for instance segmentation
        if self.num_stuff_classes > 0:
            self.show_result = self._show_pan_result
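
    # Hedged illustration (not part of the original file): `panoptic_head` and
    # `panoptic_fusion_head` above are registry config dicts resolved through
    # `build_head`. Under the usual COCO panoptic setup they are expected to
    # look roughly like the sketch below; the exact keys depend on the head
    # implementation chosen in the config.
    #
    #   panoptic_head=dict(
    #       type='MaskFormerHead',
    #       num_things_classes=80,
    #       num_stuff_classes=53,
    #       ...),
    #   panoptic_fusion_head=dict(
    #       type='MaskFormerFusionHead',
    #       num_things_classes=80,
    #       num_stuff_classes=53,
    #       ...)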

    def forward_dummy(self, img, img_metas):
        """Used for computing network flops. See
        `mmdetection/tools/analysis_tools/get_flops.py`

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
        """
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        outs = self.panoptic_head(x, img_metas)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_masks (list[BitmapMasks]): true segmentation masks for each box
                used if the architecture supports a segmentation task.
            gt_semantic_seg (list[tensor]): semantic segmentation mask for
                images for panoptic segmentation.
                Defaults to None for instance segmentation.
            gt_bboxes_ignore (list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
                Defaults to None.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # add batch_input_shape in img_metas
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
                                                  gt_labels, gt_masks,
                                                  gt_semantic_seg,
                                                  gt_bboxes_ignore)
        return losses
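
    # Note (hedged): `losses` is a dict of named loss tensors produced by the
    # panoptic head (e.g. classification, mask and dice terms, typically one
    # set per decoder layer); the exact keys depend on the head config, and
    # the runner sums all entries into the total training loss.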

    def simple_test(self, imgs, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            imgs (Tensor): A batch of images.
            img_metas (list[dict]): List of image information.

        Returns:
            list[dict[str, np.array | tuple[list]] | tuple[list]]:
                Semantic segmentation results and panoptic segmentation \
                results of each image for panoptic segmentation, or formatted \
                bbox and mask results of each image for instance segmentation.

                .. code-block:: none

                    [
                        # panoptic segmentation
                        {
                            'pan_results': np.array, # shape = [h, w]
                            'ins_results': tuple[list],
                            # semantic segmentation results are not supported yet
                            'sem_results': np.array
                        },
                        ...
                    ]

                or

                .. code-block:: none

                    [
                        # instance segmentation
                        (
                            bboxes, # list[np.array]
                            masks # list[list[np.array]]
                        ),
                        ...
                    ]
        """
        feats = self.extract_feat(imgs)
        mask_cls_results, mask_pred_results = self.panoptic_head.simple_test(
            feats, img_metas, **kwargs)
        results = self.panoptic_fusion_head.simple_test(
            mask_cls_results, mask_pred_results, img_metas, **kwargs)
        for i in range(len(results)):
            if 'pan_results' in results[i]:
                results[i]['pan_results'] = results[i]['pan_results'].detach(
                ).cpu().numpy()

            if 'ins_results' in results[i]:
                labels_per_image, bboxes, mask_pred_binary = results[i][
                    'ins_results']
                bbox_results = bbox2result(bboxes, labels_per_image,
                                           self.num_things_classes)
                mask_results = [[] for _ in range(self.num_things_classes)]
                for j, label in enumerate(labels_per_image):
                    mask = mask_pred_binary[j].detach().cpu().numpy()
                    mask_results[label].append(mask)
                results[i]['ins_results'] = bbox_results, mask_results

            assert 'sem_results' not in results[i], 'semantic segmentation '\
                'results are not supported yet.'

        if self.num_stuff_classes == 0:
            results = [res['ins_results'] for res in results]

        return results
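
    # Illustrative note (hedged): each value in a `pan_results` map packs a
    # class label and an instance index into one integer, so it can be decoded
    # with the same INSTANCE_OFFSET arithmetic used in `_show_pan_result`
    # below, e.g.
    #
    #   label = pan_id % INSTANCE_OFFSET
    #   instance_id = pan_id // INSTANCE_OFFSET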

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        raise NotImplementedError

    def _show_pan_result(self,
                         img,
                         result,
                         score_thr=0.3,
                         bbox_color=(72, 101, 241),
                         text_color=(72, 101, 241),
                         mask_color=None,
                         thickness=2,
                         font_size=13,
                         win_name='',
                         show=False,
                         wait_time=0,
                         out_file=None):
        """Draw `panoptic result` over `img`.

        Args:
            img (str or ndarray): The image to be displayed.
            result (dict): The results.
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox
                lines. The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            text_color (str or tuple(int) or :obj:`Color`): Color of texts.
                The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            mask_color (None or str or tuple(int) or :obj:`Color`):
                Color of masks. The tuple of color should be in BGR order.
                Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            show (bool): Whether to show the image.
                Default: False.
            wait_time (float): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            img (ndarray): The drawn image, returned only when neither `show`
                nor `out_file` is set.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']
        # keep objects ahead
        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.num_classes  # for VOID label
        ids = ids[legal_indices]
        labels = np.array([id % INSTANCE_OFFSET for id in ids],
                          dtype=np.int64)
        segms = (pan_results[None] == ids[:, None, None])

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img
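

# Hedged usage sketch (not part of the upstream module): a quick smoke test
# showing how this detector is typically driven through the high-level mmdet
# APIs. The config/checkpoint/image paths are placeholders; substitute files
# from the MMDetection model zoo before running.
if __name__ == '__main__':
    from mmdet.apis import inference_detector, init_detector

    config_file = 'configs/maskformer/maskformer_r50_mstrain_16x1_75e_coco.py'
    checkpoint_file = 'checkpoints/maskformer_r50.pth'  # placeholder path

    model = init_detector(config_file, checkpoint_file, device='cuda:0')
    result = inference_detector(model, 'demo/demo.jpg')
    # For panoptic configs (num_stuff_classes > 0), `result` is a dict with a
    # 'pan_results' array and `show_result` is bound to `_show_pan_result`.
    model.show_result('demo/demo.jpg', result, out_file='demo_out.jpg')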