import json
import pathlib
from collections import OrderedDict
from typing import Dict

import numpy as np
import torch
import tqdm

from basics.base_svs_infer import BaseSVSInfer
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
from modules.fastspeech.tts_modules import LengthRegulator
from modules.toplevel import DiffSingerAcoustic, ShallowDiffusionOutput
from modules.vocoders.registry import VOCODERS
from utils import load_ckpt
from utils.hparams import hparams
from utils.infer_utils import cross_fade, resample_align_curve, save_wav
from utils.phoneme_utils import build_phoneme_list
from utils.text_encoder import TokenTextEncoder

class DiffSingerAcousticInfer(BaseSVSInfer):
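    """
    Acoustic inference wrapper: turns one parsed .ds segment into a batch of
    model inputs, predicts a mel-spectrogram with DiffSingerAcoustic, and
    renders audio through the configured vocoder.
    """
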
    def __init__(self, device=None, load_model=True, load_vocoder=True, ckpt_steps=None):
        super().__init__(device=device)
        if load_model:
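            # Only variances enabled at training time are embedded; these names
            # decide which curves preprocess_input() reads from the .ds segment.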
            self.variance_checklist = []

            self.variances_to_embed = set()

            if hparams.get('use_energy_embed', False):
                self.variances_to_embed.add('energy')
            if hparams.get('use_breathiness_embed', False):
                self.variances_to_embed.add('breathiness')
            if hparams.get('use_voicing_embed', False):
                self.variances_to_embed.add('voicing')
            if hparams.get('use_tension_embed', False):
                self.variances_to_embed.add('tension')

            self.ph_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
            if hparams['use_spk_id']:
                with open(pathlib.Path(hparams['work_dir']) / 'spk_map.json', 'r', encoding='utf8') as f:
                    self.spk_map = json.load(f)
                assert isinstance(self.spk_map, dict) and len(self.spk_map) > 0, 'Invalid or empty speaker map!'
                assert len(self.spk_map) == len(set(self.spk_map.values())), 'Duplicate speaker id in speaker map!'
            self.model = self.build_model(ckpt_steps=ckpt_steps)
            self.lr = LengthRegulator().to(self.device)
        if load_vocoder:
            self.vocoder = self.build_vocoder()

    def build_model(self, ckpt_steps=None):
        model = DiffSingerAcoustic(
            vocab_size=len(self.ph_encoder),
            out_dims=hparams['audio_num_mel_bins']
        ).eval().to(self.device)
        load_ckpt(model, hparams['work_dir'], ckpt_steps=ckpt_steps,
                  prefix_in_ckpt='model', strict=True, device=self.device)
        return model

    def build_vocoder(self):
        if hparams['vocoder'] in VOCODERS:
            vocoder = VOCODERS[hparams['vocoder']]()
        else:
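            # The registry is keyed by bare class names, so a dotted name in
            # hparams (e.g. 'some.module.NsfHifiGAN') is reduced to its last component.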
            vocoder = VOCODERS[hparams['vocoder'].split('.')[-1]]()
        vocoder.to_device(self.device)
        return vocoder

    def preprocess_input(self, param, idx=0):
        """
        :param param: one segment in the .ds file
        :param idx: index of the segment
        :return: batch of the model inputs
        """
        batch = {}
        summary = OrderedDict()

        txt_tokens = torch.LongTensor([self.ph_encoder.encode(param['ph_seq'])]).to(self.device)  # => [B, T_txt]
        batch['tokens'] = txt_tokens
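
        # Convert per-phoneme durations (seconds) to integer frame counts:
        # rounding the *cumulative* boundaries and then taking differences keeps
        # the total length exact instead of accumulating per-phoneme rounding error.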
        ph_dur = torch.from_numpy(np.array(param['ph_dur'].split(), np.float32)).to(self.device)
        ph_acc = torch.round(torch.cumsum(ph_dur, dim=0) / self.timestep + 0.5).long()
        durations = torch.diff(ph_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))[None]  # => [B=1, T_txt]
        mel2ph = self.lr(durations, txt_tokens == 0)  # => [B=1, T]
        batch['mel2ph'] = mel2ph
        length = mel2ph.size(1)  # => T

        summary['tokens'] = txt_tokens.size(1)
        summary['frames'] = length
        summary['seconds'] = '%.2f' % (length * self.timestep)
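
        # Resolve the (possibly time-varying) speaker mixture into per-frame
        # speaker ids and mixture weights.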
        if hparams['use_spk_id']:
            spk_mix_id, spk_mix_value = self.load_speaker_mix(
                param_src=param, summary_dst=summary, mix_mode='frame', mix_length=length
            )
            batch['spk_mix_id'] = spk_mix_id
            batch['spk_mix_value'] = spk_mix_value
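
        # Resample the f0 curve from its own timestep onto the mel frame grid.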
        batch['f0'] = torch.from_numpy(resample_align_curve(
            np.array(param['f0_seq'].split(), np.float32),
            original_timestep=float(param['f0_timestep']),
            target_timestep=self.timestep,
            align_length=length
        )).to(self.device)[None]
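
        # Attach every variance curve the model embeds (energy, breathiness,
        # voicing, tension), resampled to the frame grid exactly like f0.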
        for v_name in VARIANCE_CHECKLIST:
            if v_name in self.variances_to_embed:
                batch[v_name] = torch.from_numpy(resample_align_curve(
                    np.array(param[v_name].split(), np.float32),
                    original_timestep=float(param[f'{v_name}_timestep']),
                    target_timestep=self.timestep,
                    align_length=length
                )).to(self.device)[None]
                summary[v_name] = 'manual'
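
        # 'gender' in [-1, 1] maps to a key shift inside the pitch-shifting
        # augmentation range: positive values scale toward shift_max, negative
        # values toward shift_min (assumed <= 0, as in random pitch shifting).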
        if hparams['use_key_shift_embed']:
            shift_min, shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
            gender = param.get('gender')
            if gender is None:
                gender = 0.
            if isinstance(gender, (int, float, bool)):  # static gender value
                summary['gender'] = f'static({gender:.3f})'
                key_shift_value = gender * shift_max if gender >= 0 else gender * abs(shift_min)
                batch['key_shift'] = torch.FloatTensor([key_shift_value]).to(self.device)[:, None]  # => [B=1, T=1]
            else:
                summary['gender'] = 'dynamic'
                gender_seq = resample_align_curve(
                    np.array(gender.split(), np.float32),
                    original_timestep=float(param['gender_timestep']),
                    target_timestep=self.timestep,
                    align_length=length
                )
                gender_mask = gender_seq >= 0
                key_shift_seq = gender_seq * (gender_mask * shift_max + (1 - gender_mask) * abs(shift_min))
                batch['key_shift'] = torch.clip(
                    torch.from_numpy(key_shift_seq.astype(np.float32)).to(self.device)[None],  # => [B=1, T]
                    min=shift_min, max=shift_max
                )
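
        # 'velocity' follows the same pattern: a constant default of 1.0, or a
        # resampled curve clipped to the time-stretching augmentation range.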
        if hparams['use_speed_embed']:
            if param.get('velocity') is None:
                summary['velocity'] = 'default'
                batch['speed'] = torch.FloatTensor([1.]).to(self.device)[:, None]  # => [B=1, T=1]
            else:
                summary['velocity'] = 'manual'
                speed_min, speed_max = hparams['augmentation_args']['random_time_stretching']['range']
                speed_seq = resample_align_curve(
                    np.array(param['velocity'].split(), np.float32),
                    original_timestep=float(param['velocity_timestep']),
                    target_timestep=self.timestep,
                    align_length=length
                )
                batch['speed'] = torch.clip(
                    torch.from_numpy(speed_seq.astype(np.float32)).to(self.device)[None],  # => [B=1, T]
                    min=speed_min, max=speed_max
                )

        print(f'[{idx}]\t' + ', '.join(f'{k}: {v}' for k, v in summary.items()))
        return batch

    @torch.no_grad()
    def forward_model(self, sample):
        txt_tokens = sample['tokens']
        variances = {
            v_name: sample.get(v_name)
            for v_name in self.variances_to_embed
        }
        if hparams['use_spk_id']:
            spk_mix_id = sample['spk_mix_id']
            spk_mix_value = sample['spk_mix_value']
            # perform mixing on spk embed
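            # spk_mix_id => [B, T, N] speaker ids, spk_mix_value => [B, T, N]
            # mixture weights; the weighted sum over N collapses the candidate
            # speaker embeddings into a single embedding per frame.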
            spk_mix_embed = torch.sum(
                self.model.fs2.spk_embed(spk_mix_id) * spk_mix_value.unsqueeze(3),  # => [B, T, N, H]
                dim=2, keepdim=False
            )  # => [B, T, H]
        else:
            spk_mix_embed = None
        mel_pred: ShallowDiffusionOutput = self.model(
            txt_tokens, mel2ph=sample['mel2ph'], f0=sample['f0'], **variances,
            key_shift=sample.get('key_shift'), speed=sample.get('speed'),
            spk_mix_embed=spk_mix_embed, infer=True
        )
        return mel_pred.diff_out

    @torch.no_grad()
    def run_vocoder(self, spec, **kwargs):
        y = self.vocoder.spec2wav_torch(spec, **kwargs)
        return y[None]

    def run_inference(
            self, params,
            out_dir: pathlib.Path = None,
            title: str = None,
            num_runs: int = 1,
            spk_mix: Dict[str, float] = None,
            seed: int = -1,
            save_mel: bool = False
    ):
        batches = [self.preprocess_input(param, idx=i) for i, param in enumerate(params)]

        out_dir.mkdir(parents=True, exist_ok=True)
        suffix = '.wav' if not save_mel else '.mel.pt'
        for i in range(num_runs):
            if save_mel:
                result = []
            else:
                result = np.zeros(0)
            current_length = 0

            for param, batch in tqdm.tqdm(
                    zip(params, batches), desc='infer segments', total=len(params)
            ):
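                # A per-segment 'seed' overrides the global seed; the mask
                # keeps the value within unsigned 32-bit range before seeding
                # the CPU and CUDA RNGs.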
                if 'seed' in param:
                    torch.manual_seed(param['seed'] & 0xffff_ffff)
                    torch.cuda.manual_seed_all(param['seed'] & 0xffff_ffff)
                elif seed >= 0:
                    torch.manual_seed(seed & 0xffff_ffff)
                    torch.cuda.manual_seed_all(seed & 0xffff_ffff)

                mel_pred = self.forward_model(batch)
                if save_mel:
                    result.append({
                        'offset': param.get('offset', 0.),
                        'mel': mel_pred.cpu(),
                        'f0': batch['f0'].cpu()
                    })
                else:
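                    # Stitch segments onto a shared timeline: pad with silence
                    # up to this segment's offset, or cross-fade where the new
                    # segment overlaps the audio rendered so far.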
                    waveform_pred = self.run_vocoder(mel_pred, f0=batch['f0'])[0].cpu().numpy()
                    silent_length = round(param.get('offset', 0) * hparams['audio_sample_rate']) - current_length
                    if silent_length >= 0:
                        result = np.append(result, np.zeros(silent_length))
                        result = np.append(result, waveform_pred)
                    else:
                        result = cross_fade(result, waveform_pred, current_length + silent_length)
                    current_length = current_length + silent_length + waveform_pred.shape[0]

            if num_runs > 1:
                filename = f'{title}-{str(i).zfill(3)}{suffix}'
            else:
                filename = title + suffix
            save_path = out_dir / filename
            if save_mel:
                print(f'| save mel: {save_path}')
                torch.save(result, save_path)
            else:
                print(f'| save audio: {save_path}')
                save_wav(result, save_path, hparams['audio_sample_rate'])