import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from deployment.modules.diffusion import (
    GaussianDiffusionONNX, PitchDiffusionONNX, MultiVarianceDiffusionONNX
)
from deployment.modules.rectified_flow import (
    RectifiedFlowONNX, PitchRectifiedFlowONNX, MultiVarianceRectifiedFlowONNX
)
from deployment.modules.fastspeech2 import FastSpeech2AcousticONNX, FastSpeech2VarianceONNX
from modules.toplevel import DiffSingerAcoustic, DiffSingerVariance
from utils.hparams import hparams


class DiffSingerAcousticONNX(DiffSingerAcoustic):
    def __init__(self, vocab_size, out_dims):
        super().__init__(vocab_size, out_dims)
        del self.fs2
        del self.diffusion
        self.fs2 = FastSpeech2AcousticONNX(
            vocab_size=vocab_size
        )
        if self.diffusion_type == 'ddpm':
            self.diffusion = GaussianDiffusionONNX(
                out_dims=out_dims,
                num_feats=1,
                timesteps=hparams['timesteps'],
                k_step=hparams['K_step'],
                backbone_type=self.backbone_type,
                backbone_args=self.backbone_args,
                spec_min=hparams['spec_min'],
                spec_max=hparams['spec_max']
            )
        elif self.diffusion_type == 'reflow':
            self.diffusion = RectifiedFlowONNX(
                out_dims=out_dims,
                num_feats=1,
                t_start=hparams['T_start'],
                time_scale_factor=hparams['time_scale_factor'],
                backbone_type=self.backbone_type,
                backbone_args=self.backbone_args,
                spec_min=hparams['spec_min'],
                spec_max=hparams['spec_max']
            )
        else:
            raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        self.mel_base = hparams.get('mel_base', '10')

    def ensure_mel_base(self, mel):
        if self.mel_base != 'e':
            # Convert log10-mel to natural-log mel: multiply by ln(10) ≈ 2.30259.
            mel = mel * 2.30259
        return mel
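
    # Worked example (illustrative numbers): a mel bin stored as -4.0 in
    # log10 base becomes -4.0 * ln(10) ≈ -9.21 after this conversion, so the
    # exported graph always emits natural-log mel regardless of the base the
    # checkpoint was trained with.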

    def forward_fs2_aux(
            self,
            tokens: Tensor,
            durations: Tensor,
            f0: Tensor,
            variances: dict,
            gender: Tensor = None,
            velocity: Tensor = None,
            spk_embed: Tensor = None
    ):
        condition = self.fs2(
            tokens, durations, f0, variances=variances,
            gender=gender, velocity=velocity, spk_embed=spk_embed
        )
        if self.use_shallow_diffusion:
            aux_mel_pred = self.aux_decoder(condition, infer=True)
            return condition, aux_mel_pred
        else:
            return condition

    def forward_shallow_diffusion(
            self, condition: Tensor, x_start: Tensor,
            depth, steps: int
    ) -> Tensor:
        mel_pred = self.diffusion(condition, x_start=x_start, depth=depth, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def forward_diffusion(self, condition: Tensor, steps: int):
        mel_pred = self.diffusion(condition, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def forward_shallow_reflow(
            self, condition: Tensor, x_end: Tensor,
            depth, steps: int
    ):
        mel_pred = self.diffusion(condition, x_end=x_end, depth=depth, steps=steps)
        return self.ensure_mel_base(mel_pred)

    def forward_reflow(self, condition: Tensor, steps: int):
        mel_pred = self.diffusion(condition, steps=steps)
        return self.ensure_mel_base(mel_pred)
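
    # Note on the four forward variants above: the "shallow" pair starts
    # sampling from the auxiliary decoder's coarse mel (x_start / x_end)
    # instead of from pure noise, with `depth` controlling how far along the
    # noising trajectory sampling begins; the plain pair samples the full
    # trajectory from scratch.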

    def view_as_fs2_aux(self) -> nn.Module:
        model = copy.deepcopy(self)
        del model.diffusion
        model.forward = model.forward_fs2_aux
        return model

    def view_as_diffusion(self) -> nn.Module:
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_shallow_diffusion:
            del model.aux_decoder
            model.forward = model.forward_shallow_diffusion
        else:
            model.forward = model.forward_diffusion
        return model

    def view_as_reflow(self) -> nn.Module:
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_shallow_diffusion:
            del model.aux_decoder
            model.forward = model.forward_shallow_reflow
        else:
            model.forward = model.forward_reflow
        return model
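
    # The view_as_* methods implement a simple export pattern: deep-copy the
    # full model, delete the submodules the target subgraph does not need,
    # and rebind `forward` to the matching forward_* method, so each stage
    # can be traced and exported (e.g. to ONNX) as an independent graph.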


class DiffSingerVarianceONNX(DiffSingerVariance):
    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        del self.fs2
        self.fs2 = FastSpeech2VarianceONNX(
            vocab_size=vocab_size
        )
        self.hidden_size = hparams['hidden_size']
        if self.predict_pitch:
            del self.pitch_predictor
            self.smooth: nn.Conv1d = None
            pitch_hparams = hparams['pitch_prediction_args']
            if self.diffusion_type == 'ddpm':
                self.pitch_predictor = PitchDiffusionONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    timesteps=hparams['timesteps'],
                    k_step=hparams['K_step'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            elif self.diffusion_type == 'reflow':
                self.pitch_predictor = PitchRectifiedFlowONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    time_scale_factor=hparams['time_scale_factor'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            else:
                raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        if self.predict_variances:
            del self.variance_predictor
            if self.diffusion_type == 'ddpm':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusionONNX)
            elif self.diffusion_type == 'reflow':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlowONNX)
            else:
                raise NotImplementedError(self.diffusion_type)

    def build_smooth_op(self, device):
        smooth_kernel_size = round(hparams['midi_smooth_width'] * hparams['audio_sample_rate'] / hparams['hop_size'])
        smooth = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=smooth_kernel_size,
            bias=False,
            padding='same',
            padding_mode='replicate'
        ).eval()
        smooth_kernel = torch.sin(torch.from_numpy(
            np.linspace(0, 1, smooth_kernel_size).astype(np.float32) * np.pi
        ))
        smooth_kernel /= smooth_kernel.sum()
        smooth.weight.data = smooth_kernel[None, None]
        self.smooth = smooth.to(device)
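
    # The kernel is half a sine period normalized to sum to 1, so the conv is
    # a weighted moving average over the frame-level MIDI curve. Worked sizing
    # example with illustrative hparams (not defaults from any particular
    # config): midi_smooth_width=0.06 s, audio_sample_rate=44100, hop_size=512
    # gives round(0.06 * 44100 / 512) = round(5.17) = 5 frames.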

    def embed_frozen_spk(self, encoder_out):
        if hparams['use_spk_id'] and hasattr(self, 'frozen_spk_embed'):
            encoder_out += self.frozen_spk_embed
        return encoder_out

    def forward_linguistic_encoder_word(self, tokens, word_div, word_dur):
        encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_linguistic_encoder_phoneme(self, tokens, ph_dur):
        encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)

    def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
        mel2x = self.lr(x_dur)
        if x_dim is not None:
            x_src = F.pad(x_src, [0, 0, 1, 0])
            mel2x = mel2x[..., None].repeat([1, 1, x_dim])
        else:
            x_src = F.pad(x_src, [1, 0])
        x_cond = torch.gather(x_src, 1, mel2x)
        return x_cond
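
    # How the gather works: self.lr (the length regulator) expands per-item
    # durations into a frame-to-item index map whose valid entries start at 1;
    # padding one zero row/element onto the front of x_src means index 0
    # (frames covered by no item) gathers an all-zero vector. Illustrative
    # shapes: x_dur = [[2, 3]] yields mel2x = [[1, 1, 2, 2, 2]], so frames 0-1
    # copy item 1's features and frames 2-4 copy item 2's.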

    def forward_pitch_preprocess(
            self, encoder_out, ph_dur,
            note_midi=None, note_rest=None, note_dur=None, note_glide=None,
            pitch=None, expr=None, retake=None, spk_embed=None
    ):
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        if self.use_melody_encoder:
            if self.melody_encoder.use_glide_embed and note_glide is None:
                note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
            melody_encoder_out = self.melody_encoder(
                note_midi, note_rest, note_dur,
                glide=note_glide
            )
            melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size)
            condition += melody_encoder_out
        if expr is None:
            retake_embed = self.pitch_retake_embed(retake.long())
        else:
            retake_true_embed = self.pitch_retake_embed(
                torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            retake_false_embed = self.pitch_retake_embed(
                torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            expr = (expr * retake)[:, :, None]  # [B, T, 1]
            retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
        pitch_cond = condition + retake_embed
        frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
        base_pitch = self.smooth(frame_midi_pitch)
        if self.use_melody_encoder:
            delta_pitch = (pitch - base_pitch) * ~retake
            pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
        else:
            base_pitch = base_pitch * retake + pitch * ~retake
            pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
        if hparams['use_spk_id'] and spk_embed is not None:
            pitch_cond += spk_embed
        return pitch_cond, base_pitch
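
    # The expressiveness path above is a per-frame linear interpolation
    # between the two retake embeddings: retake_embed = expr * E(1) +
    # (1 - expr) * E(0). With an illustrative expr of 0.3 on a retaken frame,
    # the condition receives 30% of the "retake" embedding and 70% of the
    # "keep" embedding, a continuous control replacing the hard boolean
    # switch used when expr is None.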

    def forward_pitch_reflow(
            self, pitch_cond, steps: int = 10
    ):
        x_pred = self.pitch_predictor(pitch_cond, steps=steps)
        return x_pred

    def forward_pitch_postprocess(self, x_pred, base_pitch):
        pitch_pred = self.pitch_predictor.clamp_spec(x_pred) + base_pitch
        return pitch_pred

    def forward_variance_preprocess(
            self, encoder_out, ph_dur, pitch,
            variances: dict = None, retake=None, spk_embed=None
    ):
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        variance_cond = condition + self.pitch_embed(pitch[:, :, None])
        non_retake_masks = [
            v_retake.float()  # [B, T, 1]
            for v_retake in (~retake).split(1, dim=2)
        ]
        variance_embeds = [
            self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
            for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
        ]
        variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
        if hparams['use_spk_id'] and spk_embed is not None:
            variance_cond += spk_embed
        return variance_cond
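
    # Masking logic above: `retake` is a [B, T, num_variances] boolean tensor
    # with one channel per predicted variance. Inverting and splitting it
    # yields one [B, T, 1] keep-mask per variance, so each ground-truth
    # variance embedding is zeroed exactly on the frames where that variance
    # is being re-predicted and kept where existing values should constrain
    # the prediction.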

    def forward_variance_reflow(self, variance_cond, steps: int = 10):
        xs_pred = self.variance_predictor(variance_cond, steps=steps)
        return xs_pred

    def forward_variance_postprocess(self, xs_pred):
        if self.variance_predictor.num_feats == 1:
            xs_pred = [xs_pred]
        else:
            xs_pred = xs_pred.unbind(dim=1)
        variance_pred = self.variance_predictor.clamp_spec(xs_pred)
        return tuple(variance_pred)

    def view_as_linguistic_encoder(self):
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_encoder()
        if self.predict_dur:
            model.forward = model.forward_linguistic_encoder_word
        else:
            model.forward = model.forward_linguistic_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        assert self.predict_dur
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_dur_predictor()
        model.forward = model.forward_dur_predictor
        return model

    def view_as_pitch_preprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_preprocess
        return model

    def view_as_pitch_predictor(self):
        assert self.predict_pitch
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_reflow
        return model

    def view_as_pitch_postprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_postprocess
        return model

    def view_as_variance_preprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_variance_preprocess
        return model

    def view_as_variance_predictor(self):
        assert self.predict_variances
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_reflow
        return model

    def view_as_variance_postprocess(self):
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_postprocess
        return model
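

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of this module): the view_as_* methods
# let each stage be exported as a separate ONNX graph. Assuming hparams has
# been populated and a trained checkpoint loaded into `model`, an export flow
# might look like the following; all input tensors, sizes and file names are
# illustrative placeholders.
#
#     model = DiffSingerVarianceONNX(vocab_size=61).eval()
#     model.build_smooth_op(torch.device('cpu'))
#     encoder = model.view_as_linguistic_encoder()
#     torch.onnx.export(encoder, (tokens, ph_dur), 'variance.encoder.onnx')
#     predictor = model.view_as_pitch_predictor()
#     torch.onnx.export(predictor, (pitch_cond, steps), 'variance.pitch.onnx')
# ---------------------------------------------------------------------------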