Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import copy | |
| import numpy as np | |
| import torch | |
| from mmdet.utils.util_mixins import NiceRepr | |
class GeneralData(NiceRepr):
    """A general data structure of OpenMMlab.

    A data structure that stores the meta information,
    the annotations of the images or the model predictions,
    which can be used in communication between components.

    The attributes in `GeneralData` are divided into two parts,
    the `meta_info_fields` and the `data_fields` respectively.

        - `meta_info_fields`: Usually contains the
          information about the image such as filename,
          image_shape, pad_shape, etc. All attributes in
          it are immutable once set,
          but the user can add new meta information with
          `set_meta_info` function, all information can be accessed
          with methods `meta_info_keys`, `meta_info_values`,
          `meta_info_items`.

        - `data_fields`: Annotations or model predictions are
          stored. The attributes can be accessed or modified by
          dict-like or object-like operations, such as
          `.` , `[]`, `in`, `del`, `pop(str)`, `get(str)`, `keys()`,
          `values()`, `items()`. Users can also apply tensor-like methods
          to all :obj:`torch.Tensor` in the `data_fields`,
          such as `.cuda()`, `.cpu()`, `.npu()`, `.mlu()`, `.numpy()`,
          `.to()` and `.detach()`. Each of them returns a new instance
          and leaves the original container unchanged.

    Args:
        meta_info (dict, optional): A dict contains the meta information
            of single image. such as `img_shape`, `scale_factor`, etc.
            Default: None.
        data (dict, optional): A dict contains annotations of single image or
            model predictions. Default: None.

    Examples:
        >>> from mmdet.core import GeneralData
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = GeneralData(meta_info=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.01, 0.1, 0.2, 0.3])
        >>> print(instance_data)
        <GeneralData(
          META INFORMATION
        img_shape: (800, 1196, 3)
        pad_shape: (800, 1216, 3)
          DATA FIELDS
        shape of det_labels: torch.Size([4])
        shape of det_scores: torch.Size([4])
        ) at 0x7f84acd10f90>
        >>> instance_data.det_scores
        tensor([0.0100, 0.1000, 0.2000, 0.3000])
        >>> instance_data.det_labels
        tensor([0, 1, 2, 3])
        >>> instance_data['det_labels']
        tensor([0, 1, 2, 3])
        >>> 'det_labels' in instance_data
        True
        >>> instance_data.img_shape
        (800, 1196, 3)
        >>> 'det_scores' in instance_data
        True
        >>> del instance_data.det_scores
        >>> 'det_scores' in instance_data
        False
        >>> det_labels = instance_data.pop('det_labels', None)
        >>> det_labels
        tensor([0, 1, 2, 3])
        >>> 'det_labels' in instance_data
        False
    """

    def __init__(self, meta_info=None, data=None):
        # The bookkeeping sets must exist before anything else is assigned:
        # ``__setattr__`` consults them on every later attribute write.
        self._meta_info_fields = set()
        self._data_fields = set()

        if meta_info is not None:
            self.set_meta_info(meta_info=meta_info)
        if data is not None:
            self.set_data(data)

    @staticmethod
    def _is_same_meta_value(ori_value, new_value):
        """Return whether ``new_value`` is consistent with ``ori_value``.

        Tensors and arrays are compared element-wise. If the comparison
        itself fails (e.g. non-broadcastable shapes or an incomparable
        type), the values are treated as different so that the caller
        reports the documented "immutable" ``KeyError`` instead of an
        incidental broadcasting error.
        """
        if isinstance(ori_value, (torch.Tensor, np.ndarray)):
            try:
                return bool((ori_value == new_value).all())
            except (AttributeError, RuntimeError, TypeError, ValueError):
                return False
        return ori_value == new_value

    def set_meta_info(self, meta_info):
        """Add meta information.

        The given dict is deep-copied before being stored. A key that is
        already present may only be set again with an equal value.

        Args:
            meta_info (dict): A dict contains the meta information
                of image. such as `img_shape`, `scale_factor`, etc.

        Raises:
            AttributeError: If a key collides with one of the private
                bookkeeping attributes.
            KeyError: If a key is already set with a different value.
        """
        assert isinstance(meta_info,
                          dict), f'meta should be a `dict` but get {meta_info}'
        meta = copy.deepcopy(meta_info)
        for k, v in meta.items():
            if k in ('_meta_info_fields', '_data_fields'):
                # Writing these through ``__dict__`` below would silently
                # replace the bookkeeping sets and corrupt the object.
                raise AttributeError(
                    f'{k} has been used as a '
                    f'private attribute, which is immutable. ')
            if k in self._meta_info_fields:
                # Meta information is immutable: re-setting a key is only
                # allowed when it is consistent with the stored value.
                if self._is_same_meta_value(getattr(self, k), v):
                    continue
                raise KeyError(
                    f'img_meta_info {k} has been set as '
                    f'{getattr(self, k)} before, which is immutable ')
            # NOTE(review): a key that already is a data field is not
            # rejected here and ends up in both field sets — confirm intent.
            self._meta_info_fields.add(k)
            # Write through ``__dict__`` to bypass ``__setattr__``, which
            # would register the key as a data field instead.
            self.__dict__[k] = v

    def set_data(self, data):
        """Update a dict to `data_fields`.

        Every item is assigned through ``__setattr__``, so keys that are
        already meta information are rejected with ``AttributeError``.

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        assert isinstance(data,
                          dict), f'data should be a `dict` but get {data}'
        for k, v in data.items():
            setattr(self, k, v)

    def new(self, meta_info=None, data=None):
        """Return a new results with same image meta information.

        Only the meta information of ``self`` is copied (deep-copied by
        ``set_meta_info``); data fields are not carried over.

        Args:
            meta_info (dict, optional): A dict contains the meta information
                of image. such as `img_shape`, `scale_factor`, etc.
                Default: None.
            data (dict, optional): A dict contains annotations of image or
                model predictions. Default: None.

        Returns:
            GeneralData: A new instance of the same class.
        """
        new_data = self.__class__()
        new_data.set_meta_info(dict(self.meta_info_items()))
        if meta_info is not None:
            new_data.set_meta_info(meta_info)
        if data is not None:
            new_data.set_data(data)
        return new_data

    def keys(self):
        """
        Returns:
            list: Contains all keys in data_fields.
        """
        return list(self._data_fields)

    def meta_info_keys(self):
        """
        Returns:
            list: Contains all keys in meta_info_fields.
        """
        return list(self._meta_info_fields)

    def values(self):
        """
        Returns:
            list: Contains all values in data_fields.
        """
        return [getattr(self, k) for k in self.keys()]

    def meta_info_values(self):
        """
        Returns:
            list: Contains all values in meta_info_fields.
        """
        return [getattr(self, k) for k in self.meta_info_keys()]

    def items(self):
        """Yield ``(key, value)`` pairs of the data fields."""
        for k in self.keys():
            yield (k, getattr(self, k))

    def meta_info_items(self):
        """Yield ``(key, value)`` pairs of the meta information."""
        for k in self.meta_info_keys():
            yield (k, getattr(self, k))

    def __setattr__(self, name, val):
        """Assign ``val`` and register ``name`` as a data field.

        Raises:
            AttributeError: If ``name`` is a private bookkeeping attribute
                that already exists, or is used in the meta information.
        """
        if name in ('_meta_info_fields', '_data_fields'):
            # The bookkeeping sets may be created once (in ``__init__``)
            # but never rebound afterwards.
            if not hasattr(self, name):
                super().__setattr__(name, val)
            else:
                raise AttributeError(
                    f'{name} has been used as a '
                    f'private attribute, which is immutable. ')
        else:
            if name in self._meta_info_fields:
                raise AttributeError(f'`{name}` is used in meta information,'
                                     f'which is immutable')
            self._data_fields.add(name)
            super().__setattr__(name, val)

    def __delattr__(self, item):
        """Delete a data field.

        Raises:
            AttributeError: If ``item`` is a private bookkeeping attribute,
                or does not exist.
            KeyError: If ``item`` is used in the meta information.
        """
        if item in ('_meta_info_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 f'private attribute, which is immutable. ')
        if item in self._meta_info_fields:
            raise KeyError(f'{item} is used in meta information, '
                           f'which is immutable.')
        super().__delattr__(item)
        if item in self._data_fields:
            self._data_fields.remove(item)

    # dict-like methods
    __setitem__ = __setattr__
    __delitem__ = __delattr__

    def __getitem__(self, name):
        # Delegates to attribute lookup, so a missing key raises
        # AttributeError rather than KeyError.
        return getattr(self, name)

    def get(self, *args):
        """Dict-like ``get(name[, default])`` over the instance ``__dict__``.

        Both data fields and meta information live in ``__dict__``, so
        either kind can be read this way.
        """
        assert len(args) < 3, '`get` get more than 2 arguments'
        return self.__dict__.get(*args)

    def pop(self, *args):
        """Dict-like ``pop(name[, default])`` restricted to data fields.

        Raises:
            KeyError: If ``name`` is meta information, or if it is not a
                data field and no default was given.
        """
        assert len(args) < 3, '`pop` get more than 2 arguments'
        name = args[0]
        if name in self._meta_info_fields:
            raise KeyError(f'{name} is a key in meta information, '
                           f'which is immutable')

        if name in self._data_fields:
            self._data_fields.remove(name)
            return self.__dict__.pop(*args)
        # with default value
        if len(args) == 2:
            return args[1]
        raise KeyError(f'{args[0]}')

    def __contains__(self, item):
        return item in self._data_fields or \
            item in self._meta_info_fields

    def _apply_to_tensors(self, fn):
        """Return a copy whose tensor data fields are mapped through ``fn``.

        Non-tensor data values are copied over unchanged; meta information
        is carried over by ``new()``.
        """
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                v = fn(v)
            new_data[k] = v
        return new_data

    # Tensor-like methods
    def to(self, *args, **kwargs):
        """Apply same name function to all tensors in data_fields.

        Any data value exposing a ``to`` method (not only tensors) is
        converted with the given arguments.
        """
        new_data = self.new()
        for k, v in self.items():
            if hasattr(v, 'to'):
                v = v.to(*args, **kwargs)
            new_data[k] = v
        return new_data

    # Tensor-like methods
    def cpu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.cpu())

    # Tensor-like methods
    def npu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.npu())

    # Tensor-like methods
    def mlu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.mlu())

    # Tensor-like methods
    def cuda(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.cuda())

    # Tensor-like methods
    def detach(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.detach())

    # Tensor-like methods
    def numpy(self):
        """Convert all tensors in data_fields to numpy arrays.

        Tensors are detached and moved to CPU first, since ``numpy()``
        requires both.
        """
        return self._apply_to_tensors(lambda t: t.detach().cpu().numpy())

    def __nice__(self):
        """Build the body shown inside ``repr``: meta info, then data fields.

        Tensors and arrays are summarised by their shape instead of their
        full contents.
        """
        parts = ['\n \n META INFORMATION \n']
        parts.extend(f'{k}: {v} \n' for k, v in self.meta_info_items())
        parts.append('\n DATA FIELDS \n')
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                parts.append(f'shape of {k}: {v.shape} \n')
            else:
                parts.append(f'{k}: {v} \n')
        parts.append('\n')
        return ''.join(parts)