from __future__ import annotations

from typing import List, Tuple

import torch
import torch.nn as nn
from tqdm import tqdm

from modules.backbones import build_backbone
from utils.hparams import hparams
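

# Rectified flow learns a velocity field v(x_t, t) along the straight path
# x_t = (1 - t) * noise + t * data; sampling integrates the ODE
# dx/dt = v(x, t) from t = 0 (pure noise) to t = 1 (data).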
class RectifiedFlow(nn.Module):
    def __init__(self, out_dims, num_feats=1, t_start=0., time_scale_factor=1000,
                 backbone_type=None, backbone_args=None,
                 spec_min=None, spec_max=None):
        super().__init__()
        self.velocity_fn: nn.Module = build_backbone(out_dims, num_feats, backbone_type, backbone_args)
        self.out_dims = out_dims
        self.num_feats = num_feats

        self.use_shallow_diffusion = hparams.get('use_shallow_diffusion', False)
        if self.use_shallow_diffusion:
            assert 0. <= t_start <= 1., 'T_start should be in [0, 1].'
        else:
            t_start = 0.
        self.t_start = t_start
        self.time_scale_factor = time_scale_factor

        # spec: [B, T, M] or [B, F, T, M]
        # spec_min and spec_max: [1, 1, M] or [1, 1, F, M] => transpose(-3, -2) => [1, 1, M] or [1, F, 1, M]
        spec_min = torch.FloatTensor(spec_min)[None, None, :out_dims].transpose(-3, -2)
        spec_max = torch.FloatTensor(spec_max)[None, None, :out_dims].transpose(-3, -2)
        self.register_buffer('spec_min', spec_min, persistent=False)
        self.register_buffer('spec_max', spec_max, persistent=False)
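
    # Training step: sample Gaussian noise x_start, form the linear interpolant
    # x_t = x_start + t * (x_end - x_start), and regress the backbone onto the
    # constant target velocity v = x_end - x_start.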
    def p_losses(self, x_end, t, cond):
        x_start = torch.randn_like(x_end)
        x_t = x_start + t[:, None, None, None] * (x_end - x_start)
        v_pred = self.velocity_fn(x_t, t * self.time_scale_factor, cond)
        return v_pred, x_end - x_start
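
    # Training (infer=False) returns (v_pred, v_gt, t) for the regression loss;
    # inference (infer=True) integrates the ODE and returns a denormalized spectrogram.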
    def forward(self, condition, gt_spec=None, src_spec=None, infer=True):
        cond = condition.transpose(1, 2)
        b, device = condition.shape[0], condition.device

        if not infer:
            # gt_spec: [B, T, M] or [B, F, T, M]
            spec = self.norm_spec(gt_spec).transpose(-2, -1)  # [B, M, T] or [B, F, M, T]
            if self.num_feats == 1:
                spec = spec[:, None, :, :]  # [B, F=1, M, T]
            t = self.t_start + (1.0 - self.t_start) * torch.rand((b,), device=device)
            v_pred, v_gt = self.p_losses(spec, t, cond=cond)
            return v_pred, v_gt, t
        else:
            # src_spec: [B, T, M] or [B, F, T, M]
            if src_spec is not None:
                spec = self.norm_spec(src_spec).transpose(-2, -1)
                if self.num_feats == 1:
                    spec = spec[:, None, :, :]
            else:
                spec = None
            x = self.inference(cond, b=b, x_end=spec, device=device)
            return self.denorm_spec(x)
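
    # Fixed-step ODE solvers for dx/dt = v(x, t). Each performs one step of
    # size dt and returns the updated (x, t); t is rescaled by
    # time_scale_factor before being fed to the backbone.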
    def sample_euler(self, x, t, dt, cond):
        x += self.velocity_fn(x, self.time_scale_factor * t, cond) * dt
        t += dt
        return x, t

    def sample_rk2(self, x, t, dt, cond):
        k_1 = self.velocity_fn(x, self.time_scale_factor * t, cond)
        k_2 = self.velocity_fn(x + 0.5 * k_1 * dt, self.time_scale_factor * (t + 0.5 * dt), cond)
        x += k_2 * dt
        t += dt
        return x, t

    def sample_rk4(self, x, t, dt, cond):
        k_1 = self.velocity_fn(x, self.time_scale_factor * t, cond)
        k_2 = self.velocity_fn(x + 0.5 * k_1 * dt, self.time_scale_factor * (t + 0.5 * dt), cond)
        k_3 = self.velocity_fn(x + 0.5 * k_2 * dt, self.time_scale_factor * (t + 0.5 * dt), cond)
        k_4 = self.velocity_fn(x + k_3 * dt, self.time_scale_factor * (t + dt), cond)
        x += (k_1 + 2 * k_2 + 2 * k_3 + k_4) * dt / 6
        t += dt
        return x, t

    def sample_rk5(self, x, t, dt, cond):
        k_1 = self.velocity_fn(x, self.time_scale_factor * t, cond)
        k_2 = self.velocity_fn(x + 0.25 * k_1 * dt, self.time_scale_factor * (t + 0.25 * dt), cond)
        k_3 = self.velocity_fn(x + 0.125 * (k_2 + k_1) * dt, self.time_scale_factor * (t + 0.25 * dt), cond)
        k_4 = self.velocity_fn(x + 0.5 * (-k_2 + 2 * k_3) * dt, self.time_scale_factor * (t + 0.5 * dt), cond)
        k_5 = self.velocity_fn(x + 0.0625 * (3 * k_1 + 9 * k_4) * dt, self.time_scale_factor * (t + 0.75 * dt), cond)
        k_6 = self.velocity_fn(x + (-3 * k_1 + 2 * k_2 + 12 * k_3 - 12 * k_4 + 8 * k_5) * dt / 7,
                               self.time_scale_factor * (t + dt),
                               cond)
        x += (7 * k_1 + 32 * k_3 + 12 * k_4 + 32 * k_5 + 7 * k_6) * dt / 90
        t += dt
        return x, t
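
    # Inference starts from pure noise (t_start = 0) or, with shallow diffusion,
    # from a blend of the source spectrogram and noise at t_start, then
    # integrates the velocity field to t = 1 with the configured solver.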
    def inference(self, cond, b=1, x_end=None, device=None):
        noise = torch.randn(b, self.num_feats, self.out_dims, cond.shape[2], device=device)
        t_start = hparams.get('T_start_infer', self.t_start)
        if self.use_shallow_diffusion and t_start > 0:
            assert x_end is not None, 'Missing shallow diffusion source.'
            if t_start >= 1.:
                t_start = 1.
                x = x_end
            else:
                x = t_start * x_end + (1 - t_start) * noise
        else:
            t_start = 0.
            x = noise

        algorithm = hparams['sampling_algorithm']
        infer_step = hparams['sampling_steps']
        if t_start < 1:
            dt = (1.0 - t_start) / max(1, infer_step)
            algorithm_fn = {
                'euler': self.sample_euler,
                'rk2': self.sample_rk2,
                'rk4': self.sample_rk4,
                'rk5': self.sample_rk5,
            }.get(algorithm)
            if algorithm_fn is None:
                raise ValueError(f'Unsupported algorithm for Rectified Flow: {algorithm}.')
            dts = torch.tensor([dt]).to(x)
            for i in tqdm(range(infer_step), desc='sample time step', total=infer_step,
                          disable=not hparams['infer'], leave=False):
                x, _ = algorithm_fn(x, t_start + i * dts, dt, cond)
            x = x.float()
        x = x.transpose(2, 3).squeeze(1)  # [B, F, M, T] => [B, T, M] or [B, F, T, M]
        return x
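
    # norm_spec maps values linearly from [spec_min, spec_max] to [-1, 1];
    # denorm_spec inverts the mapping.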
    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
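

# RepetitiveRectifiedFlow reuses the spectrogram backbone for 1-D parameter
# curves: each scalar value is repeated across `repeat_bins` bins before
# normalization, and the bins are averaged back after denormalization.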
class RepetitiveRectifiedFlow(RectifiedFlow):
    def __init__(self, vmin: float | int | list, vmax: float | int | list,
                 repeat_bins: int, time_scale_factor=1000,
                 backbone_type=None, backbone_args=None):
        assert (isinstance(vmin, (float, int)) and isinstance(vmax, (float, int))) or len(vmin) == len(vmax)
        num_feats = 1 if isinstance(vmin, (float, int)) else len(vmin)
        spec_min = [vmin] if num_feats == 1 else [[v] for v in vmin]
        spec_max = [vmax] if num_feats == 1 else [[v] for v in vmax]
        self.repeat_bins = repeat_bins
        super().__init__(
            out_dims=repeat_bins, num_feats=num_feats,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type, backbone_args=backbone_args,
            spec_min=spec_min, spec_max=spec_max
        )

    def norm_spec(self, x):
        """
        :param x: [B, T] or [B, F, T]
        :return: [B, T, R] or [B, F, T, R]
        """
        if self.num_feats == 1:
            repeats = [1, 1, self.repeat_bins]
        else:
            repeats = [1, 1, 1, self.repeat_bins]
        return super().norm_spec(x.unsqueeze(-1).repeat(repeats))

    def denorm_spec(self, x):
        """
        :param x: [B, T, R] or [B, F, T, R]
        :return: [B, T] or [B, F, T]
        """
        return super().denorm_spec(x).mean(dim=-1)
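

# PitchRectifiedFlow models a single pitch curve; values are clamped to
# [cmin, cmax] before normalization and after denormalization.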
class PitchRectifiedFlow(RepetitiveRectifiedFlow):
    def __init__(self, vmin: float, vmax: float,
                 cmin: float, cmax: float, repeat_bins,
                 time_scale_factor=1000,
                 backbone_type=None, backbone_args=None):
        self.vmin = vmin  # norm min
        self.vmax = vmax  # norm max
        self.cmin = cmin  # clip min
        self.cmax = cmax  # clip max
        super().__init__(
            vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type, backbone_args=backbone_args
        )

    def norm_spec(self, x):
        return super().norm_spec(x.clamp(min=self.cmin, max=self.cmax))

    def denorm_spec(self, x):
        return super().denorm_spec(x).clamp(min=self.cmin, max=self.cmax)
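

# MultiVarianceRectifiedFlow models several variance curves jointly; each
# curve has its own (min, max) normalization range and an optional
# (clamp_min, clamp_max) pair.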
class MultiVarianceRectifiedFlow(RepetitiveRectifiedFlow):
    def __init__(
            self, ranges: List[Tuple[float, float]],
            clamps: List[Tuple[float | None, float | None] | None],
            repeat_bins, time_scale_factor=1000,
            backbone_type=None, backbone_args=None
    ):
        assert len(ranges) == len(clamps)
        self.clamps = clamps
        vmin = [r[0] for r in ranges]
        vmax = [r[1] for r in ranges]
        if len(vmin) == 1:
            vmin = vmin[0]
        if len(vmax) == 1:
            vmax = vmax[0]
        super().__init__(
            vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type, backbone_args=backbone_args
        )

    def clamp_spec(self, xs: list | tuple):
        clamped = []
        for x, c in zip(xs, self.clamps):
            if c is None:
                clamped.append(x)
                continue
            clamped.append(x.clamp(min=c[0], max=c[1]))
        return clamped

    def norm_spec(self, xs: list | tuple):
        """
        :param xs: sequence of [B, T]
        :return: [B, T, R] or [B, F, T, R]
        """
        assert len(xs) == self.num_feats
        clamped = self.clamp_spec(xs)
        xs = torch.stack(clamped, dim=1)  # [B, F, T]
        if self.num_feats == 1:
            xs = xs.squeeze(1)  # [B, T]
        return super().norm_spec(xs)

    def denorm_spec(self, xs):
        """
        :param xs: [B, T, R] or [B, F, T, R]
        :return: sequence of [B, T]
        """
        xs = super().denorm_spec(xs)  # [B, T] or [B, F, T]
        if self.num_feats == 1:
            xs = [xs]
        else:
            xs = xs.unbind(dim=1)
        assert len(xs) == self.num_feats
        return self.clamp_spec(xs)
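

# --- Minimal usage sketch (illustrative, not part of this module) ---
# The backbone name 'wavenet' and the mel statistics below are assumptions;
# in this repo, hparams and backbone registrations are normally populated by
# its config system rather than by hand.
#
#     hparams.update({'sampling_algorithm': 'euler', 'sampling_steps': 20, 'infer': True})
#     model = RectifiedFlow(
#         out_dims=128, num_feats=1, time_scale_factor=1000,
#         backbone_type='wavenet', backbone_args={},
#         spec_min=[-5.0], spec_max=[0.0],  # broadcast over all mel bins
#     )
#     condition = torch.randn(1, 500, 256)  # [B, T, H] encoder output
#     mel = model(condition, infer=True)    # [B, T, M] denormalized mel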