Spaces:
Sleeping
Sleeping
import functools

from utils.hparams import hparams
class BaseAugmentation:
    """
    Common base for data augmentation strategies.

    Every method defined on this class (and on subclasses) is expected to be
    thread-safe.

    Subclasses provide:
        *process_item*: augment a single piece of data and return the result.
    """

    def __init__(self, data_dirs: list, augmentation_args: dict):
        # Keep the constructor inputs around for subclasses to consult.
        self.raw_data_dirs = data_dirs
        self.augmentation_args = augmentation_args
        # hop_size / audio_sample_rate, read from the global hparams.
        # Presumably the duration of one frame hop in seconds -- TODO confirm.
        hop = hparams['hop_size']
        rate = hparams['audio_sample_rate']
        self.timestep = hop / rate

    def process_item(self, item: dict, **kwargs) -> dict:
        """Augment one data item and return it; subclasses must override this."""
        raise NotImplementedError()
def require_same_keys(func):
    """
    Decorate an augmentation method so its returned dict must have exactly
    the same keys as the item it received.

    The decorated callable is expected to be a method called as
    ``(self, item, ...)``. The input item is taken from the second positional
    argument, or from the ``item`` keyword argument when it is passed by name.

    :param func: callable that receives an item dict and returns a dict.
    :return: a wrapper with the same call signature and metadata
        (``__name__``, ``__doc__``, ...) as ``func``.
    :raises AssertionError: if the key sets of the input item and the returned
        dict differ (checked via ``assert``, so skipped under ``python -O``).
    """
    @functools.wraps(func)
    def run(*args, **kwargs):
        # Positional form is (self, item, ...); fall back to item=... keyword.
        item: dict = args[1] if len(args) > 1 else kwargs['item']
        # Snapshot keys BEFORE calling func: an augmentation that mutates
        # `item` in place and returns it would otherwise always pass the check.
        keys_before = set(item.keys())
        res: dict = func(*args, **kwargs)
        assert keys_before == set(res.keys()), 'Item keys mismatch after augmentation.\n' \
                                               f'Before: {sorted(keys_before)}\n' \
                                               f'After: {sorted(res.keys())}'
        return res
    return run