Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| # Copyright (c) 2019 Western Digital Corporation or its affiliates. | |
| import torch | |
| from ..builder import DETECTORS | |
| from .single_stage import SingleStageDetector | |
class YOLOV3(SingleStageDetector):
    """YOLOv3 detector expressed as a ``SingleStageDetector``.

    All model construction is handed to ``SingleStageDetector``; this class
    only adds an ONNX export entry point on top of it.
    """

    # NOTE(review): ``DETECTORS`` is imported by this module but no
    # ``@DETECTORS.register_module()`` decorator is visible on this class —
    # confirm whether registration was intended (it may have been lost).

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # Hand every argument to the parent positionally, preserving order.
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    def onnx_export(self, img, img_metas):
        """Run the detector for ONNX export (no test-time augmentation).

        Side effect: stores the input's trailing dimensions under the key
        ``'img_shape_for_onnx'`` in ``img_metas[0]`` before delegating to
        ``self.bbox_head.onnx_export``, which receives ``img_metas``.

        Args:
            img (torch.Tensor): input images.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        feats = self.extract_feat(img)
        head_outputs = self.bbox_head.forward(feats)

        # Read the shape through a tensor op rather than Python ints —
        # presumably so the exported graph keeps it dynamic. Slicing from
        # index 2 drops the first two dims (likely N, C for NCHW input;
        # confirm against callers).
        spatial_shape = torch._shape_as_tensor(img)[2:]
        img_metas[0]['img_shape_for_onnx'] = spatial_shape

        # Unpack explicitly: the result must have exactly two elements and
        # is returned as a fresh 2-tuple.
        dets, labels = self.bbox_head.onnx_export(*head_outputs, img_metas)
        return dets, labels