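"""Acoustic model training task for DiffSinger.

Defines ``AcousticDataset`` (padding and collation of mel, f0, alignment and
optional variance/embedding inputs) and ``AcousticTask`` (training/validation
with a DDPM or rectified-flow diffusion loss, optional shallow diffusion, and
vocoder-rendered validation plots).
"""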
import matplotlib
import torch
import torch.distributions
import torch.optim
import torch.utils.data

import utils
import utils.infer_utils
from basics.base_dataset import BaseDataset
from basics.base_task import BaseTask
from basics.base_vocoder import BaseVocoder
from modules.aux_decoder import build_aux_loss
from modules.losses import DiffusionLoss, RectifiedFlowLoss
from modules.toplevel import DiffSingerAcoustic, ShallowDiffusionOutput
from modules.vocoders.registry import get_vocoder_cls
from utils.hparams import hparams
from utils.plot import spec_to_figure

matplotlib.use('Agg')
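

# Dataset of pre-extracted acoustic features. Besides the common fields from
# BaseDataset, each sample carries phoneme tokens, a phoneme-to-frame
# alignment (mel2ph), f0 and the target mel-spectrogram; optional variance
# curves and embedding scalars are included according to hparams switches.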
class AcousticDataset(BaseDataset):
    def __init__(self, prefix, preload=False):
        super(AcousticDataset, self).__init__(prefix, hparams['dataset_size_key'], preload)
        self.required_variances = {}  # key: variance name, value: padding value
        if hparams['use_energy_embed']:
            self.required_variances['energy'] = 0.0
        if hparams['use_breathiness_embed']:
            self.required_variances['breathiness'] = 0.0
        if hparams['use_voicing_embed']:
            self.required_variances['voicing'] = 0.0
        if hparams['use_tension_embed']:
            self.required_variances['tension'] = 0.0

        self.need_key_shift = hparams['use_key_shift_embed']
        self.need_speed = hparams['use_speed_embed']
        self.need_spk_id = hparams['use_spk_id']
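
    # Pads and stacks a list of samples into one batch dict. utils.collate_nd
    # is assumed to pad each variable-length tensor to the longest item in
    # the batch with the given pad value before stacking, roughly:
    #   collate_nd([tensor([1, 2]), tensor([3])], 0) -> tensor([[1, 2], [3, 0]])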
    def collater(self, samples):
        batch = super().collater(samples)
        if batch['size'] == 0:
            return batch

        tokens = utils.collate_nd([s['tokens'] for s in samples], 0)
        f0 = utils.collate_nd([s['f0'] for s in samples], 0.0)
        mel2ph = utils.collate_nd([s['mel2ph'] for s in samples], 0)
        mel = utils.collate_nd([s['mel'] for s in samples], 0.0)
        batch.update({
            'tokens': tokens,
            'mel2ph': mel2ph,
            'mel': mel,
            'f0': f0,
        })
        for v_name, v_pad in self.required_variances.items():
            batch[v_name] = utils.collate_nd([s[v_name] for s in samples], v_pad)
        if self.need_key_shift:
            batch['key_shift'] = torch.FloatTensor([s['key_shift'] for s in samples])[:, None]
        if self.need_speed:
            batch['speed'] = torch.FloatTensor([s['speed'] for s in samples])[:, None]
        if self.need_spk_id:
            spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
            batch['spk_ids'] = spk_ids
        return batch
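

# Training/validation task for the acoustic model. Supports plain DDPM and
# rectified-flow ('reflow') diffusion, plus optional shallow diffusion, in
# which an auxiliary decoder predicts a coarse mel-spectrogram that the
# diffusion stage then refines.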
class AcousticTask(BaseTask):
    def __init__(self):
        super().__init__()
        self.dataset_cls = AcousticDataset
        self.diffusion_type = hparams['diffusion_type']
        assert self.diffusion_type in ['ddpm', 'reflow'], f"Unknown diffusion type: {self.diffusion_type}"
        self.use_shallow_diffusion = hparams['use_shallow_diffusion']
        if self.use_shallow_diffusion:
            self.shallow_args = hparams['shallow_diffusion_args']
            self.train_aux_decoder = self.shallow_args['train_aux_decoder']
            self.train_diffusion = self.shallow_args['train_diffusion']

        self.use_vocoder = hparams['infer'] or hparams['val_with_vocoder']
        if self.use_vocoder:
            self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        self.logged_gt_wav = set()

        self.required_variances = []
        if hparams['use_energy_embed']:
            self.required_variances.append('energy')
        if hparams['use_breathiness_embed']:
            self.required_variances.append('breathiness')
        if hparams['use_voicing_embed']:
            self.required_variances.append('voicing')
        if hparams['use_tension_embed']:
            self.required_variances.append('tension')

        super()._finish_init()
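
    # Builds the top-level network that maps phoneme tokens (conditioned on
    # f0, variances and optional speaker/transposition inputs) to mels.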
    def _build_model(self):
        return DiffSingerAcoustic(
            vocab_size=len(self.phone_encoder),
            out_dims=hparams['audio_num_mel_bins']
        )
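
    # Loss setup: an optional auxiliary mel loss for the shallow-diffusion
    # decoder, and the main objective -- a DDPM loss or a rectified-flow
    # velocity-matching loss, selected by diffusion_type.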
    # noinspection PyAttributeOutsideInit
    def build_losses_and_metrics(self):
        if self.use_shallow_diffusion:
            self.aux_mel_loss = build_aux_loss(self.shallow_args['aux_decoder_arch'])
            self.lambda_aux_mel_loss = hparams['lambda_aux_mel_loss']
            self.register_validation_loss('aux_mel_loss')
        if self.diffusion_type == 'ddpm':
            self.mel_loss = DiffusionLoss(loss_type=hparams['main_loss_type'])
        elif self.diffusion_type == 'reflow':
            self.mel_loss = RectifiedFlowLoss(
                loss_type=hparams['main_loss_type'], log_norm=hparams['main_loss_log_norm']
            )
        else:
            raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")
        self.register_validation_loss('mel_loss')
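
    # Single forward pass. With infer=False, returns a dict of training
    # losses computed from the model's ShallowDiffusionOutput; with
    # infer=True, returns the ShallowDiffusionOutput itself, carrying the
    # aux_out and/or diff_out mel predictions.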
    def run_model(self, sample, infer=False):
        txt_tokens = sample['tokens']  # [B, T_ph]
        target = sample['mel']  # [B, T_s, M]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        variances = {
            v_name: sample[v_name]
            for v_name in self.required_variances
        }
        key_shift = sample.get('key_shift')
        speed = sample.get('speed')
        if hparams['use_spk_id']:
            spk_embed_id = sample['spk_ids']
        else:
            spk_embed_id = None
        output: ShallowDiffusionOutput = self.model(
            txt_tokens, mel2ph=mel2ph, f0=f0, **variances,
            key_shift=key_shift, speed=speed, spk_embed_id=spk_embed_id,
            gt_mel=target, infer=infer
        )

        if infer:
            return output
        else:
            losses = {}

            if output.aux_out is not None:
                aux_out = output.aux_out
                norm_gt = self.model.aux_decoder.norm_spec(target)
                aux_mel_loss = self.lambda_aux_mel_loss * self.aux_mel_loss(aux_out, norm_gt)
                losses['aux_mel_loss'] = aux_mel_loss

            non_padding = (mel2ph > 0).unsqueeze(-1).float()
            if output.diff_out is not None:
                if self.diffusion_type == 'ddpm':
                    x_recon, x_noise = output.diff_out
                    mel_loss = self.mel_loss(x_recon, x_noise, non_padding=non_padding)
                elif self.diffusion_type == 'reflow':
                    v_pred, v_gt, t = output.diff_out
                    mel_loss = self.mel_loss(v_pred, v_gt, t=t, non_padding=non_padding)
                else:
                    raise ValueError(f"Unknown diffusion type: {self.diffusion_type}")
                losses['mel_loss'] = mel_loss

            return losses
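
    # The vocoder is not registered as a submodule of the task, so it is
    # moved to the task's device manually before training and validation.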
    def on_train_start(self):
        if self.use_vocoder and self.vocoder.get_device() != self.device:
            self.vocoder.to_device(self.device)

    def _on_validation_start(self):
        if self.use_vocoder and self.vocoder.get_device() != self.device:
            self.vocoder.to_device(self.device)
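
    # Computes validation losses for every batch; for the first
    # num_valid_plots data items, additionally runs full inference and logs
    # rendered waveforms and mel comparison figures.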
    def _validation_step(self, sample, batch_idx):
        losses = self.run_model(sample, infer=False)

        if sample['size'] > 0 and min(sample['indices']) < hparams['num_valid_plots']:
            mel_out: ShallowDiffusionOutput = self.run_model(sample, infer=True)
            for i in range(len(sample['indices'])):
                data_idx = sample['indices'][i].item()
                if data_idx < hparams['num_valid_plots']:
                    if self.use_vocoder:
                        self.plot_wav(
                            data_idx, sample['mel'][i],
                            mel_out.aux_out[i] if mel_out.aux_out is not None else None,
                            mel_out.diff_out[i] if mel_out.diff_out is not None else None,
                            sample['f0'][i]
                        )
                    if mel_out.aux_out is not None:
                        self.plot_mel(data_idx, sample['mel'][i], mel_out.aux_out[i], 'auxmel')
                    if mel_out.diff_out is not None:
                        self.plot_mel(data_idx, sample['mel'][i], mel_out.diff_out[i], 'diffmel')

        return losses, sample['size']
    ############
    # validation plots
    ############
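
    # Renders the ground-truth, auxiliary and diffusion mels through the
    # vocoder and logs them as audio. The ground-truth waveform is only
    # logged once per data item (tracked in self.logged_gt_wav).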
    def plot_wav(self, data_idx, gt_mel, aux_mel, diff_mel, f0):
        f0_len = self.valid_dataset.metadata['f0'][data_idx]
        mel_len = self.valid_dataset.metadata['mel'][data_idx]
        gt_mel = gt_mel[:mel_len].unsqueeze(0)
        if aux_mel is not None:
            aux_mel = aux_mel[:mel_len].unsqueeze(0)
        if diff_mel is not None:
            diff_mel = diff_mel[:mel_len].unsqueeze(0)
        f0 = f0[:f0_len].unsqueeze(0)
        if data_idx not in self.logged_gt_wav:
            gt_wav = self.vocoder.spec2wav_torch(gt_mel, f0=f0)
            self.logger.all_rank_experiment.add_audio(
                f'gt_{data_idx}', gt_wav,
                sample_rate=hparams['audio_sample_rate'],
                global_step=self.global_step
            )
            self.logged_gt_wav.add(data_idx)
        if aux_mel is not None:
            aux_wav = self.vocoder.spec2wav_torch(aux_mel, f0=f0)
            self.logger.all_rank_experiment.add_audio(
                f'aux_{data_idx}', aux_wav,
                sample_rate=hparams['audio_sample_rate'],
                global_step=self.global_step
            )
        if diff_mel is not None:
            diff_wav = self.vocoder.spec2wav_torch(diff_mel, f0=f0)
            self.logger.all_rank_experiment.add_audio(
                f'diff_{data_idx}', diff_wav,
                sample_rate=hparams['audio_sample_rate'],
                global_step=self.global_step
            )
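
    # Logs a figure comparing a predicted mel against the ground truth. The
    # three panels concatenated along the mel-bin axis are |pred - GT|
    # (shifted by vmin so zero error maps to the colormap floor), the ground
    # truth, and the prediction.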
    def plot_mel(self, data_idx, gt_spec, out_spec, name_prefix='mel'):
        vmin = hparams['mel_vmin']
        vmax = hparams['mel_vmax']
        mel_len = self.valid_dataset.metadata['mel'][data_idx]
        spec_cat = torch.cat([(out_spec - gt_spec).abs() + vmin, gt_spec, out_spec], -1)
        title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}"
        self.logger.all_rank_experiment.add_figure(f'{name_prefix}_{data_idx}', spec_to_figure(
            spec_cat[:mel_len], vmin, vmax, title_text
        ), global_step=self.global_step)