Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from ..builder import DETECTORS | |
| from .faster_rcnn import FasterRCNN | |
@DETECTORS.register_module()
class TridentFasterRCNN(FasterRCNN):
    """Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_

    A ``FasterRCNN`` whose backbone and RoI head both expose ``num_branch``
    and ``test_branch_idx``. Image metas (and, during training, ground
    truths) are replicated once per branch before being passed to the
    RPN / RoI heads.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the detector and check that branch settings agree.

        All arguments are forwarded unchanged to ``FasterRCNN.__init__``.

        Raises:
            AssertionError: If the built backbone and RoI head disagree on
                ``num_branch`` or ``test_branch_idx``.
        """
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # Backbone and RoI head must be configured for the same branches,
        # otherwise replicated inputs would not line up with branch outputs.
        assert self.backbone.num_branch == self.roi_head.num_branch
        assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
        self.num_branch = self.backbone.num_branch
        self.test_branch_idx = self.backbone.test_branch_idx

    def _test_num_branch(self):
        """Return how many branches are active at test time.

        A ``test_branch_idx`` of -1 means every branch is used; any other
        value means a single branch is used.
        """
        return self.num_branch if self.test_branch_idx == -1 else 1

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without augmentation.

        Args:
            img: Input images, passed to ``self.extract_feat``.
            img_metas: Sequence of per-image meta information; it is
                replicated once per active test branch.
            proposals: Optional precomputed proposals. When given, the RPN
                is skipped and these are passed straight to the RoI head.
            rescale (bool): Forwarded to ``self.roi_head.simple_test``.

        Returns:
            The result of ``self.roi_head.simple_test``.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(img)
        # The RoI head always receives the branch-replicated metas, so they
        # must be built on both paths. Building them only when proposals is
        # None made the ``proposals`` argument raise NameError.
        trident_img_metas = img_metas * self._test_num_branch()
        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(
                x, trident_img_metas)
        else:
            # NOTE(review): presumably supplied proposals need one entry per
            # element of trident_img_metas - confirm against the RoI head.
            proposal_list = proposals
        return self.roi_head.simple_test(
            x, proposal_list, trident_img_metas, rescale=rescale)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].

        Args:
            imgs: Augmented inputs, passed to ``self.extract_feats``.
            img_metas: Iterable of meta sequences (presumably one per
                augmentation); each element is replicated once per active
                test branch for the RPN.
            rescale (bool): Forwarded to ``self.roi_head.aug_test``.

        Returns:
            The result of ``self.roi_head.aug_test``.
        """
        x = self.extract_feats(imgs)
        num_branch = self._test_num_branch()
        trident_img_metas = [
            aug_metas * num_branch for aug_metas in img_metas
        ]
        proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
        # NOTE(review): unlike simple_test, the RoI head receives the
        # original, un-replicated metas here - confirm this is intended.
        return self.roi_head.aug_test(
            x, proposal_list, img_metas, rescale=rescale)

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Make copies of img metas and gts to fit multi-branch training.

        ``img_metas``, ``gt_bboxes`` and ``gt_labels`` are each repeated
        ``self.num_branch`` times and converted to tuples before calling
        ``FasterRCNN.forward_train``.

        NOTE(review): ``**kwargs`` are accepted but NOT forwarded to the
        parent. Any extras (e.g. ignore boxes) are silently dropped; they
        would also need per-branch replication to be forwarded safely.
        """
        trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
        trident_gt_labels = tuple(gt_labels * self.num_branch)
        trident_img_metas = tuple(img_metas * self.num_branch)
        return super().forward_train(img, trident_img_metas,
                                     trident_gt_bboxes, trident_gt_labels)