from __future__ import annotations

from typing import List, Tuple

import torch
from torch import Tensor

from modules.core import (
    GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion
)
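

# `extract` gathers the value of a 1-D schedule buffer (e.g. sqrt_alphas_cumprod) at timestep
# index `t` and reshapes it so it broadcasts over the [1, F, M, T] tensors used below.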
def extract(a, t):
    return a[t].reshape((1, 1, 1, 1))


# noinspection PyMethodOverriding
class GaussianDiffusionONNX(GaussianDiffusion):
    @property
    def backbone(self):
        return self.denoise_fn

    # We give up the setter of the property `backbone` because it would cause TorchScript to fail
    # @backbone.setter
    def set_backbone(self, value):
        self.denoise_fn = value
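
    # Closed-form forward diffusion q(x_t | x_0):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.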
    def q_sample(self, x_start, t, noise):
        return (
                extract(self.sqrt_alphas_cumprod, t) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod, t) * noise
        )
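
    # One DDPM ancestral sampling step: predict the noise, reconstruct x_0, take the posterior
    # mean, and add Gaussian noise scaled by the posterior variance (no noise when t == 0).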
    def p_sample(self, x, t, cond):
        x_pred = self.denoise_fn(x, t, cond)
        x_recon = (
                extract(self.sqrt_recip_alphas_cumprod, t) * x -
                extract(self.sqrt_recipm1_alphas_cumprod, t) * x_pred
        )
        # This was inherited from the original DiffSinger repository
        # and is disabled due to loudness issues when speedup = 1.
        # x_recon = torch.clamp(x_recon, min=-1., max=1.)

        model_mean = (
                extract(self.posterior_mean_coef1, t) * x_recon +
                extract(self.posterior_mean_coef2, t) * x
        )
        model_log_variance = extract(self.posterior_log_variance_clipped, t)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = ((t > 0).float()).reshape(1, 1, 1, 1)
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
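
    # Deterministic DDIM update that jumps `interval` timesteps at once (t -> t - interval,
    # clamped at 0), using the predicted noise to move x along the DDIM trajectory.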
    def p_sample_ddim(self, x, t, interval: int, cond):
        a_t = extract(self.alphas_cumprod, t)
        t_prev = t - interval
        a_prev = extract(self.alphas_cumprod, t_prev * (t_prev > 0))

        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_prev = a_prev.sqrt() * (
                x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt() - ((1 - a_t) / a_t).sqrt()) * noise_pred
        )
        return x_prev
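
    # PLMS/PNDM transfer function: given a noise estimate `noise_t`, move x from timestep t
    # to t_prev along the noise schedule.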
    def plms_get_x_pred(self, x, noise_t, t, t_prev):
        a_t = extract(self.alphas_cumprod, t)
        a_prev = extract(self.alphas_cumprod, t_prev)
        a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

        x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (
                a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
        x_pred = x + x_delta

        return x_pred
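
    # One PLMS (pseudo linear multi-step) step: stage 0 bootstraps by averaging the noise
    # predicted at t with the noise predicted at the estimated previous state; later stages
    # combine the most recent noise predictions with Adams-Bashforth-style coefficients
    # before applying the transfer function above.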
    def p_sample_plms(self, x_prev, t, interval: int, cond, noise_list: List[Tensor], stage: int):
        noise_pred = self.denoise_fn(x_prev, t, cond)
        t_prev = t - interval
        t_prev = t_prev * (t_prev > 0)

        if stage == 0:
            x_pred = self.plms_get_x_pred(x_prev, noise_pred, t, t_prev)
            noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2.
        elif stage == 1:
            noise_pred_prime = (3. * noise_pred - noise_list[-1]) / 2.
        elif stage == 2:
            noise_pred_prime = (23. * noise_pred - 16. * noise_list[-1] + 5. * noise_list[-2]) / 12.
        else:
            noise_pred_prime = (55. * noise_pred - 59. * noise_list[-1] + 37.
                                * noise_list[-2] - 9. * noise_list[-3]) / 24.

        x_prev = self.plms_get_x_pred(x_prev, noise_pred_prime, t, t_prev)
        return noise_pred, x_prev
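
    # Linear mapping between the spectrogram range [spec_min, spec_max] and the normalized
    # range [-1, 1] that the diffusion model operates in.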
    def norm_spec(self, x):
        k = (self.spec_max - self.spec_min) / 2.
        b = (self.spec_max + self.spec_min) / 2.
        return (x - b) / k

    def denorm_spec(self, x):
        k = (self.spec_max - self.spec_min) / 2.
        b = (self.spec_max + self.spec_min) / 2.
        return x * k + b
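
    # ONNX-friendly sampling entry point. With no x_start, sampling starts from pure noise over
    # self.k_step steps and the requested speedup is snapped down to the largest entry of
    # self.timestep_factors not exceeding it. With x_start and depth, the input is noised to the
    # requested depth via q_sample (shallow diffusion) and denoised from there, with the depth
    # rounded down to a multiple of the speedup. DDIM is used when speedup > 1, plain DDPM otherwise.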
    def forward(self, condition, x_start=None, depth=None, steps: int = 10):
        condition = condition.transpose(1, 2)  # [1, T, H] => [1, H, T]
        device = condition.device
        n_frames = condition.shape[2]

        noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
        if x_start is None:
            speedup = max(1, self.timesteps // steps)
            speedup = self.timestep_factors[torch.sum(self.timestep_factors <= speedup) - 1]
            step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
            x = noise
        else:
            depth_int64 = min(torch.round(depth * self.timesteps).long(), self.k_step)
            speedup = max(1, depth_int64 // steps)
            depth_int64 = depth_int64 // speedup * speedup  # make depth_int64 a multiple of speedup
            step_range = torch.arange(0, depth_int64, speedup, dtype=torch.long, device=device).flip(0)[:, None]
            x_start = self.norm_spec(x_start).transpose(-2, -1)
            if self.num_feats == 1:
                x_start = x_start[:, None, :, :]
            if depth_int64 >= self.timesteps:
                x = noise
            elif depth_int64 > 0:
                x = self.q_sample(
                    x_start, torch.full((1,), depth_int64 - 1, device=device, dtype=torch.long), noise
                )
            else:
                x = x_start

        if speedup > 1:
            for t in step_range:
                x = self.p_sample_ddim(x, t, interval=speedup, cond=condition)
            # plms_noise_stage: int = 0
            # noise_list: List[Tensor] = []
            # for t in step_range:
            #     noise_pred, x = self.p_sample_plms(
            #         x, t, interval=speedup, cond=condition,
            #         noise_list=noise_list, stage=plms_noise_stage
            #     )
            #     if plms_noise_stage == 0:
            #         noise_list = [noise_pred]
            #         plms_noise_stage = plms_noise_stage + 1
            #     else:
            #         if plms_noise_stage >= 3:
            #             noise_list.pop(0)
            #         else:
            #             plms_noise_stage = plms_noise_stage + 1
            #         noise_list.append(noise_pred)
        else:
            for t in step_range:
                x = self.p_sample(x, t, cond=condition)

        if self.num_feats == 1:
            x = x.squeeze(1).permute(0, 2, 1)  # [B, 1, M, T] => [B, T, M]
        else:
            x = x.permute(0, 1, 3, 2)  # [B, F, M, T] => [B, F, T, M]
        x = self.denorm_spec(x)
        return x
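

# Illustrative call only; `n_frames` and `hidden` below are placeholders, and the diffusion
# object itself is constructed elsewhere from the model configuration:
#     condition = torch.randn(1, n_frames, hidden)   # [1, T, H] conditioning sequence
#     mel = diffusion(condition, steps=10)           # [1, T, M] when num_feats == 1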


class PitchDiffusionONNX(GaussianDiffusionONNX, PitchDiffusion):
    def __init__(self, vmin: float, vmax: float,
                 cmin: float, cmax: float, repeat_bins,
                 timesteps=1000, k_step=1000,
                 backbone_type=None, backbone_args=None,
                 betas=None):
        self.vmin = vmin
        self.vmax = vmax
        self.cmin = cmin
        self.cmax = cmax
        # super(PitchDiffusion, self) deliberately skips PitchDiffusion.__init__
        # and initializes the next class in the MRO.
        super(PitchDiffusion, self).__init__(
            vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
            timesteps=timesteps, k_step=k_step,
            backbone_type=backbone_type, backbone_args=backbone_args,
            betas=betas
        )

    def clamp_spec(self, x):
        return x.clamp(min=self.cmin, max=self.cmax)

    def denorm_spec(self, x):
        d = (self.spec_max - self.spec_min) / 2.
        m = (self.spec_max + self.spec_min) / 2.
        x = x * d + m
        x = x.mean(dim=-1)
        return x
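

# ONNX counterpart of MultiVarianceDiffusion: accepts one (min, max) range and an optional
# clamp per variance parameter, and otherwise mirrors PitchDiffusionONNX.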
class MultiVarianceDiffusionONNX(GaussianDiffusionONNX, MultiVarianceDiffusion):
    def __init__(
            self, ranges: List[Tuple[float, float]],
            clamps: List[Tuple[float | None, float | None] | None],
            repeat_bins, timesteps=1000, k_step=1000,
            backbone_type=None, backbone_args=None,
            betas=None
    ):
        assert len(ranges) == len(clamps)
        self.clamps = clamps
        vmin = [r[0] for r in ranges]
        vmax = [r[1] for r in ranges]
        if len(vmin) == 1:
            vmin = vmin[0]
        if len(vmax) == 1:
            vmax = vmax[0]
        # As above, skip MultiVarianceDiffusion.__init__ and initialize the next class in the MRO.
        super(MultiVarianceDiffusion, self).__init__(
            vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
            timesteps=timesteps, k_step=k_step,
            backbone_type=backbone_type, backbone_args=backbone_args,
            betas=betas
        )

    def denorm_spec(self, x):
        d = (self.spec_max - self.spec_min) / 2.
        m = (self.spec_max + self.spec_min) / 2.
        x = x * d + m
        x = x.mean(dim=-1)
        return x