Spaces:
Sleeping
Sleeping
| def get_backbone_type(root_config: dict, nested_config: dict = None): | |
| if nested_config is None: | |
| nested_config = root_config | |
| return nested_config.get( | |
| 'backbone_type', | |
| root_config.get( | |
| 'backbone_type', | |
| root_config.get('diff_decoder_type', 'wavenet') | |
| ) | |
| ) | |
| def get_backbone_args(config: dict, backbone_type: str): | |
| args = config.get('backbone_args') | |
| if args is not None: | |
| return args | |
| elif backbone_type == 'wavenet': | |
| return { | |
| 'num_layers': config.get('residual_layers'), | |
| 'num_channels': config.get('residual_channels'), | |
| 'dilation_cycle_length': config.get('dilation_cycle_length'), | |
| } | |
| else: | |
| return None | |