Spaces:
Sleeping
Sleeping
| from copy import deepcopy | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from basics.base_augmentation import BaseAugmentation, require_same_keys | |
| from basics.base_pe import BasePE | |
| from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST | |
| from modules.fastspeech.tts_modules import LengthRegulator | |
| from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch | |
| from utils.hparams import hparams | |
| from utils.infer_utils import resample_align_curve | |
| class SpectrogramStretchAugmentation(BaseAugmentation): | |
| """ | |
| This class contains methods for frequency-domain and time-domain stretching augmentation. | |
| """ | |
| def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None): | |
| super().__init__(data_dirs, augmentation_args) | |
| self.device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| self.lr = LengthRegulator().to(self.device) | |
| self.pe = pe | |
| def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict: | |
| aug_item = deepcopy(item) | |
| waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True) | |
| mel = get_mel_torch( | |
| waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'], | |
| hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'], | |
| fmin=hparams['fmin'], fmax=hparams['fmax'], | |
| keyshift=key_shift, speed=speed, device=self.device | |
| ) | |
| aug_item['mel'] = mel | |
| if speed != 1. or hparams['use_speed_embed']: | |
| aug_item['length'] = mel.shape[0] | |
| aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size'] # real speed | |
| aug_item['seconds'] /= aug_item['speed'] | |
| aug_item['ph_dur'] /= aug_item['speed'] | |
| aug_item['mel2ph'] = get_mel2ph_torch( | |
| self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device | |
| ).cpu().numpy() | |
| f0, _ = self.pe.get_pitch( | |
| waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'], | |
| hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'], | |
| speed=speed, interp_uv=True | |
| ) | |
| aug_item['f0'] = f0.astype(np.float32) | |
| # NOTE: variance curves are directly resampled according to speed, | |
| # despite how frequency-domain features change after the augmentation. | |
| # For acoustic models, this can bring more (but not much) difficulty | |
| # to learn how variance curves affect the mel spectrograms, since | |
| # they must realize how the augmentation causes the mismatch. | |
| # | |
| # This is a simple way to combine augmentation and variances. However, | |
| # dealing variance curves like this will decrease the accuracy of | |
| # variance controls. In most situations, not being ~100% accurate | |
| # will not ruin the user experience. For example, it does not matter | |
| # if the energy does not exactly equal the RMS; it is just fine | |
| # as long as higher energy can bring higher loudness and strength. | |
| # The neural networks itself cannot be 100% accurate, though. | |
| # | |
| # There are yet other choices to simulate variance curves: | |
| # 1. Re-extract the features from resampled waveforms; | |
| # 2. Re-extract the features from re-constructed waveforms using | |
| # the transformed mel spectrograms through the vocoder. | |
| # But there are actually no perfect ways to make them all accurate | |
| # and stable. | |
| for v_name in VARIANCE_CHECKLIST: | |
| if v_name in item: | |
| aug_item[v_name] = resample_align_curve( | |
| aug_item[v_name], | |
| original_timestep=self.timestep, | |
| target_timestep=self.timestep * aug_item['speed'], | |
| align_length=aug_item['length'] | |
| ) | |
| if key_shift != 0. or hparams['use_key_shift_embed']: | |
| if replace_spk_id is None: | |
| aug_item['key_shift'] = key_shift | |
| else: | |
| aug_item['spk_id'] = replace_spk_id | |
| aug_item['f0'] *= 2 ** (key_shift / 12) | |
| return aug_item | |