Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
def preprocess_panoptic_gt(gt_labels, gt_masks, gt_semantic_seg, num_things,
                           num_stuff, img_metas):
    """Preprocess the ground truth for an image.

    Instance ("things") masks are padded to the image's padded shape and
    converted to a tensor. When a semantic segmentation map is given, one
    extra mask is appended for every stuff class present in that map.

    Args:
        gt_labels (Tensor): Ground truth labels of each bbox,
            with shape (num_gts, ).
        gt_masks (BitmapMasks): Ground truth masks of each instance
            of an image, shape (num_gts, h, w).
        gt_semantic_seg (Tensor | None): Ground truth of semantic
            segmentation with the shape (1, h, w).
            [0, num_thing_class - 1] means things,
            [num_thing_class, num_class-1] means stuff,
            255 means VOID. It's None when training instance segmentation.
        num_things (int): Number of thing classes. Semantic labels below
            this value are not turned into stuff masks.
        num_stuff (int): Number of stuff classes. Semantic labels greater
            than or equal to ``num_things + num_stuff`` are ignored.
        img_metas (dict): Image meta information. Only ``'pad_shape'`` is
            read; its first two entries are the size the instance masks
            are padded to.

    Returns:
        tuple: a tuple containing the following targets.

            - labels (Tensor): Ground truth class indices for an
              image, with shape (n, ), n is the sum of number
              of stuff type and number of instance in an image.
            - masks (Tensor): Ground truth mask for an image, with
              shape (n, h, w). Contains stuff and things when training
              panoptic segmentation, and things only when training
              instance segmentation.
    """
    num_classes = num_things + num_stuff

    padded_masks = gt_masks.pad(img_metas['pad_shape'][:2], pad_val=0)
    things_masks = padded_masks.to_tensor(
        dtype=torch.bool, device=gt_labels.device)

    # Instance segmentation: there is no semantic map, so things only.
    if gt_semantic_seg is None:
        return gt_labels, things_masks.long()

    gt_semantic_seg = gt_semantic_seg.squeeze(0)
    semantic_labels = torch.unique(
        gt_semantic_seg,
        sorted=False,
        return_inverse=False,
        return_counts=False)

    # Keep only labels in [num_things, num_classes); this drops thing
    # classes and anything outside the class range (such as 255).
    is_stuff = (semantic_labels >= num_things) & (semantic_labels <
                                                  num_classes)
    stuff_labels = semantic_labels[is_stuff]

    # No stuff class present in this image: return the things as they are.
    if stuff_labels.numel() == 0:
        return gt_labels, things_masks.long()

    # Compare the map against all kept labels at once. Each label gets one
    # trailing singleton axis per map dimension so it broadcasts over the
    # whole map, giving one boolean mask per label along dim 0.
    label_shape = (-1, ) + (1, ) * gt_semantic_seg.dim()
    stuff_masks = gt_semantic_seg.unsqueeze(0) == stuff_labels.reshape(
        label_shape)

    labels = torch.cat([gt_labels, stuff_labels], dim=0)
    masks = torch.cat([things_masks, stuff_masks], dim=0)
    return labels, masks.long()