# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# The following is largely based on code from https://github.com/lucidrains/stylegan2-pytorch
from random import random, randrange
from typing import List, Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from omegaconf import DictConfig
from torch.utils.tensorboard.writer import SummaryWriter

from nemo.collections.common.parts.utils import mask_sequence_tensor
from nemo.collections.tts.losses.spectrogram_enhancer_losses import (
    ConsistencyLoss,
    GeneratorLoss,
    GradientPenaltyLoss,
    HingeLoss,
)
from nemo.collections.tts.parts.utils.helpers import to_device_recursive
from nemo.core import Exportable, ModelPT, PretrainedModelInfo, typecheck
from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType
from nemo.core.neural_types.elements import BoolType
from nemo.utils import logging

try:
    import torchvision

    TORCHVISION_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    TORCHVISION_AVAILABLE = False


class SpectrogramEnhancerModel(ModelPT, Exportable):
    """
    GAN-based model to add details to blurry spectrograms from TTS models like Tacotron or FastPitch.
    Based on StyleGAN 2 [1].

    [1] Karras et al. - Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958)
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None:
        self.spectrogram_model = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.generator = instantiate(cfg.generator)
        self.discriminator = instantiate(cfg.discriminator)

        self.generator_loss = GeneratorLoss()
        self.discriminator_loss = HingeLoss()
        self.consistency_loss = ConsistencyLoss(cfg.consistency_loss_weight)
        self.gradient_penalty_loss = GradientPenaltyLoss(cfg.gradient_penalty_loss_weight)

    def move_to_correct_device(self, e):
        return to_device_recursive(e, next(iter(self.generator.parameters())).device)

    def normalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        spectrogram = spectrogram - self._cfg.spectrogram_min_value
        spectrogram = spectrogram / (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value)
        return mask_sequence_tensor(spectrogram, lengths)

    def unnormalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        spectrogram = spectrogram * (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value)
        spectrogram = spectrogram + self._cfg.spectrogram_min_value
        return mask_sequence_tensor(spectrogram, lengths)
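
    # Illustrative example of the min-max normalization above (the config values here are hypothetical):
    # with spectrogram_min_value=-11.0 and spectrogram_max_value=2.0, a log-mel value of -4.5 maps to
    # (-4.5 - (-11.0)) / (2.0 - (-11.0)) = 0.5. Frames beyond each sequence length are zeroed by
    # mask_sequence_tensor in both directions of the mapping.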

    def generate_zs(self, batch_size: int = 1, mixing: bool = False):
        if mixing and self._cfg.mixed_prob < random():
            mixing_point = randrange(1, self.generator.num_layers)
            first_part = [torch.randn(batch_size, self._cfg.latent_dim)] * mixing_point
            second_part = [torch.randn(batch_size, self._cfg.latent_dim)] * (self.generator.num_layers - mixing_point)
            zs = [*first_part, *second_part]
        else:
            zs = [torch.randn(batch_size, self._cfg.latent_dim)] * self.generator.num_layers

        return self.move_to_correct_device(zs)

    def generate_noise(self, batch_size: int = 1) -> torch.Tensor:
        noise = torch.rand(batch_size, self._cfg.n_bands, 4096, 1)
        return self.move_to_correct_device(noise)

    def pad_spectrograms(self, spectrograms):
        multiplier = self.generator.upsample_factor
        *_, max_length = spectrograms.shape
        return F.pad(spectrograms, (0, multiplier - max_length % multiplier))
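
    # Worked example of the padding arithmetic above (upsample_factor value is hypothetical):
    # with upsample_factor=4 and max_length=250, the time axis is right-padded by 4 - 250 % 4 = 2
    # frames to 252, so the generator's up/downsampling stack sees a length it can divide evenly.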

    def forward(
        self,
        *,
        input_spectrograms: torch.Tensor,
        lengths: torch.Tensor,
        mixing: bool = False,
        normalize: bool = True,
    ):
        """
        Generator forward pass. Noise inputs will be generated.

        input_spectrograms: batch of spectrograms, typically synthetic
        lengths: length for every spectrogram in the batch
        mixing: style mixing, usually True during training
        normalize: normalize spectrogram range to ~[0, 1], True for normal use

        returns: batch of enhanced spectrograms

        For an explanation of style mixing refer to [1]

        [1] Karras et al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948)
        """
        return self.forward_with_custom_noise(
            input_spectrograms=input_spectrograms,
            lengths=lengths,
            mixing=mixing,
            normalize=normalize,
            zs=None,
            ws=None,
            noise=None,
        )
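
    # Usage sketch (illustrative; `fastpitch` and its output shapes are assumptions, not part of this module):
    #
    #   spectrograms, lengths = fastpitch.generate_spectrograms(batch)  # (B, n_bands, T), (B,)
    #   enhanced = enhancer.forward(input_spectrograms=spectrograms, lengths=lengths)
    #   # `enhanced` keeps the (B, n_bands, T) shape of the input batch, with added detail.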

    def forward_with_custom_noise(
        self,
        input_spectrograms: torch.Tensor,
        lengths: torch.Tensor,
        zs: Optional[List[torch.Tensor]] = None,
        ws: Optional[List[torch.Tensor]] = None,
        noise: Optional[torch.Tensor] = None,
        mixing: bool = False,
        normalize: bool = True,
    ):
        """
        Generator forward pass. Noise inputs will be generated if None.

        input_spectrograms: batch of spectrograms, typically synthetic
        lengths: length for every spectrogram in the batch
        zs: latent noise inputs on the unit sphere (either this or ws or neither)
        ws: latent noise inputs in the style space (either this or zs or neither)
        noise: per-pixel independent gaussian noise
        mixing: style mixing, usually True during training
        normalize: normalize spectrogram range to ~[0, 1], True for normal use

        returns: batch of enhanced spectrograms

        For an explanation of style mixing refer to [1], for definitions of z and w see [2]

        [1] Karras et al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948)
        [2] Karras et al. - Analyzing and Improving the Image Quality of StyleGAN, 2019 (https://arxiv.org/abs/1912.04958)
        """
        batch_size, *_, max_length = input_spectrograms.shape

        # generate noise
        if zs is not None and ws is not None:
            raise ValueError(
                "Please specify either zs or ws or neither, but not both. It is not clear which one to use."
            )

        if zs is None:
            zs = self.generate_zs(batch_size, mixing)
        if ws is None:
            ws = [self.generator.style_mapping(z) for z in zs]
        if noise is None:
            noise = self.generate_noise(batch_size)

        input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")

        # normalize if needed, mask and pad appropriately
        if normalize:
            input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths)
        input_spectrograms = self.pad_spectrograms(input_spectrograms)

        # the main call
        enhanced_spectrograms = self.generator(input_spectrograms, lengths, ws, noise)

        # denormalize if needed, mask and remove padding
        if normalize:
            enhanced_spectrograms = self.unnormalize_spectrograms(enhanced_spectrograms, lengths)
        enhanced_spectrograms = enhanced_spectrograms[:, :, :, :max_length]
        enhanced_spectrograms = rearrange(enhanced_spectrograms, "b 1 c l -> b c l")

        return enhanced_spectrograms
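
    # Sketch of deterministic enhancement (illustrative, not part of the original module): fixing
    # zs and noise makes repeated calls on the same input return the same output.
    #
    #   zs = enhancer.generate_zs(batch_size=spectrograms.shape[0])
    #   noise = enhancer.generate_noise(batch_size=spectrograms.shape[0])
    #   enhanced = enhancer.forward_with_custom_noise(
    #       input_spectrograms=spectrograms, lengths=lengths, zs=zs, noise=noise
    #   )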

    def training_step(self, batch, batch_idx, optimizer_idx):
        input_spectrograms, target_spectrograms, lengths = batch

        with torch.no_grad():
            input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths)
            target_spectrograms = self.normalize_spectrograms(target_spectrograms, lengths)

        # train discriminator
        if optimizer_idx == 0:
            enhanced_spectrograms = self.forward(
                input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False
            )
            enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l")
            fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths)

            target_spectrograms_ = rearrange(target_spectrograms, "b c l -> b 1 c l").requires_grad_()
            real_logits = self.discriminator(target_spectrograms_, input_spectrograms, lengths)

            d_loss = self.discriminator_loss(real_logits, fake_logits)
            self.log("d_loss", d_loss, prog_bar=True)

            if batch_idx % self._cfg.gradient_penalty_loss_every_n_steps == 0:
                gp_loss = self.gradient_penalty_loss(target_spectrograms_, real_logits)
                self.log("d_loss_gp", gp_loss, prog_bar=True)
                return d_loss + gp_loss

            return d_loss

        # train generator
        if optimizer_idx == 1:
            enhanced_spectrograms = self.forward(
                input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False
            )
            input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")
            enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l")
            fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths)
            g_loss = self.generator_loss(fake_logits)
            c_loss = self.consistency_loss(input_spectrograms, enhanced_spectrograms, lengths)

            self.log("g_loss", g_loss, prog_bar=True)
            self.log("c_loss", c_loss, prog_bar=True)

            with torch.no_grad():
                target_spectrograms = rearrange(target_spectrograms, "b c l -> b 1 c l")
                self.log_illustration(target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths)

            return g_loss + c_loss

    def configure_optimizers(self):
        generator_opt = instantiate(
            self._cfg.generator_opt,
            params=self.generator.parameters(),
        )
        discriminator_opt = instantiate(self._cfg.discriminator_opt, params=self.discriminator.parameters())
        return [discriminator_opt, generator_opt], []
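
    # Note (descriptive comment, not in the original): the discriminator optimizer comes first in the
    # returned list, so it corresponds to optimizer_idx == 0 in training_step and the generator to
    # optimizer_idx == 1; Lightning alternates between them, which is what drives the GAN loop above.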

    def setup_training_data(self, train_data_config):
        dataset = instantiate(train_data_config.dataset)
        self._train_dl = torch.utils.data.DataLoader(
            dataset, collate_fn=dataset.collate_fn, **train_data_config.dataloader_params
        )

    def setup_validation_data(self, val_data_config):
        """
        There is no validation step for this model.
        It is not clear whether any of the losses used here is a sensible metric for choosing between two checkpoints.
        This might change in the future.
        """
        pass

    @classmethod
    def list_available_models(cls):
        list_of_models = []

        # en, multi speaker, LibriTTS, 16000 Hz
        # stft 25ms 10ms matching ASR params
        # for use during English ASR training/adaptation
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_spectrogram_enhancer_for_asr_finetuning",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_spectrogram_enhancer_for_asr_finetuning.nemo",
            description="This model is trained to add details to synthetic spectrograms."
            " It was trained on pairs of real-synthesized spectrograms generated by FastPitch."
            " STFT parameters follow ASR with 25 ms window and 10 ms hop."
            " It is supposed to be used in conjunction with that model for ASR training/adaptation.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths):
        if self.global_rank != 0:
            return

        if not self.loggers:
            return

        step = self.trainer.global_step // 2  # because of G/D training

        if step % self.trainer.log_every_n_steps != 0:
            return

        idx = 0
        length = int(lengths.flatten()[idx].item())
        tensor = torch.stack(
            [
                enhanced_spectrograms - input_spectrograms,
                input_spectrograms,
                enhanced_spectrograms,
                target_spectrograms,
            ],
            dim=0,
        ).cpu()[:, idx, :, :, :length]

        assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required."
        grid = torchvision.utils.make_grid(tensor, nrow=1).clamp(0.0, 1.0)

        for logger in self.loggers:
            if isinstance(logger, TensorBoardLogger):
                writer: SummaryWriter = logger.experiment
                writer.add_image("spectrograms", grid, global_step=step)
                writer.flush()
            elif isinstance(logger, WandbLogger):
                logger.log_image("spectrograms", [grid], caption=["residual, input, output, ground truth"], step=step)
            else:
                logging.warning("Unsupported logger type: %s", str(type(logger)))
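

# Minimal end-to-end usage sketch (illustrative, not part of the original module). It assumes the
# pretrained checkpoint listed in list_available_models() above is reachable on NGC, and it stands in
# random tensors for FastPitch/Tacotron outputs; treat it as a smoke test, not a reference pipeline.
if __name__ == "__main__":
    enhancer = SpectrogramEnhancerModel.from_pretrained("tts_en_spectrogram_enhancer_for_asr_finetuning")
    enhancer.eval()

    # Fake "blurry" synthetic spectrograms: (batch, n_bands, time) plus per-item lengths.
    batch_size, max_length = 2, 256
    spectrograms = torch.randn(batch_size, enhancer._cfg.n_bands, max_length, device=enhancer.device)
    lengths = torch.tensor([256, 200], device=enhancer.device)

    with torch.no_grad():
        enhanced = enhancer.forward(input_spectrograms=spectrograms, lengths=lengths)

    print(enhanced.shape)  # same (batch, n_bands, time) shape as the input batch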