Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch.nn as nn | |
| from mmcv.utils import Registry, build_from_cfg | |
| TRANSFORMER = Registry('Transformer') | |
| LINEAR_LAYERS = Registry('linear layers') | |
| def build_transformer(cfg, default_args=None): | |
| """Builder for Transformer.""" | |
| return build_from_cfg(cfg, TRANSFORMER, default_args) | |
| LINEAR_LAYERS.register_module('Linear', module=nn.Linear) | |
| def build_linear_layer(cfg, *args, **kwargs): | |
| """Build linear layer. | |
| Args: | |
| cfg (None or dict): The linear layer config, which should contain: | |
| - type (str): Layer type. | |
| - layer args: Args needed to instantiate an linear layer. | |
| args (argument list): Arguments passed to the `__init__` | |
| method of the corresponding linear layer. | |
| kwargs (keyword arguments): Keyword arguments passed to the `__init__` | |
| method of the corresponding linear layer. | |
| Returns: | |
| nn.Module: Created linear layer. | |
| """ | |
| if cfg is None: | |
| cfg_ = dict(type='Linear') | |
| else: | |
| if not isinstance(cfg, dict): | |
| raise TypeError('cfg must be a dict') | |
| if 'type' not in cfg: | |
| raise KeyError('the cfg dict must contain the key "type"') | |
| cfg_ = cfg.copy() | |
| layer_type = cfg_.pop('type') | |
| if layer_type not in LINEAR_LAYERS: | |
| raise KeyError(f'Unrecognized linear type {layer_type}') | |
| else: | |
| linear_layer = LINEAR_LAYERS.get(layer_type) | |
| layer = linear_layer(*args, **kwargs, **cfg_) | |
| return layer | |