# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # MIT License # # Copyright (c) 2020 Phil Wang # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # The following is largely based on code from https://github.com/lucidrains/stylegan2-pytorch from random import random, randrange from typing import List, Optional import torch import torch.nn.functional as F from einops import rearrange from hydra.utils import instantiate from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from omegaconf import DictConfig from torch.utils.tensorboard.writer import SummaryWriter from nemo.collections.common.parts.utils import mask_sequence_tensor from nemo.collections.tts.losses.spectrogram_enhancer_losses import ( ConsistencyLoss, GeneratorLoss, GradientPenaltyLoss, HingeLoss, ) from nemo.collections.tts.parts.utils.helpers import to_device_recursive from nemo.core import Exportable, ModelPT, PretrainedModelInfo, typecheck from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType from nemo.core.neural_types.elements import BoolType from nemo.utils import logging try: import torchvision TORCHVISION_AVAILABLE = True except (ImportError, ModuleNotFoundError): TORCHVISION_AVAILABLE = False class SpectrogramEnhancerModel(ModelPT, Exportable): """ GAN-based model to add details to blurry spectrograms from TTS models like Tacotron or FastPitch. Based on StyleGAN 2 [1] [1] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) """ def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None: self.spectrogram_model = None super().__init__(cfg=cfg, trainer=trainer) self.generator = instantiate(cfg.generator) self.discriminator = instantiate(cfg.discriminator) self.generator_loss = GeneratorLoss() self.discriminator_loss = HingeLoss() self.consistency_loss = ConsistencyLoss(cfg.consistency_loss_weight) self.gradient_penalty_loss = GradientPenaltyLoss(cfg.gradient_penalty_loss_weight) def move_to_correct_device(self, e): return to_device_recursive(e, next(iter(self.generator.parameters())).device) def normalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: spectrogram = spectrogram - self._cfg.spectrogram_min_value spectrogram = spectrogram / (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value) return mask_sequence_tensor(spectrogram, lengths) def unnormalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: spectrogram = spectrogram * (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value) spectrogram = spectrogram + self._cfg.spectrogram_min_value return mask_sequence_tensor(spectrogram, lengths) def generate_zs(self, batch_size: int = 1, mixing: bool = False): if mixing and self._cfg.mixed_prob < random(): mixing_point = randrange(1, self.generator.num_layers) first_part = [torch.randn(batch_size, self._cfg.latent_dim)] * mixing_point second_part = [torch.randn(batch_size, self._cfg.latent_dim)] * (self.generator.num_layers - mixing_point) zs = [*first_part, *second_part] else: zs = [torch.randn(batch_size, self._cfg.latent_dim)] * self.generator.num_layers return self.move_to_correct_device(zs) def generate_noise(self, batch_size: int = 1) -> torch.Tensor: noise = torch.rand(batch_size, self._cfg.n_bands, 4096, 1) return self.move_to_correct_device(noise) def pad_spectrograms(self, spectrograms): multiplier = self.generator.upsample_factor *_, max_length = spectrograms.shape return F.pad(spectrograms, (0, multiplier - max_length % multiplier)) @typecheck( input_types={ "input_spectrograms": NeuralType(("B", "D", "T_spec"), MelSpectrogramType()), "lengths": NeuralType(("B",), LengthsType()), "mixing": NeuralType(None, BoolType(), optional=True), "normalize": NeuralType(None, BoolType(), optional=True), } ) def forward( self, *, input_spectrograms: torch.Tensor, lengths: torch.Tensor, mixing: bool = False, normalize: bool = True, ): """ Generator forward pass. Noise inputs will be generated. input_spectrograms: batch of spectrograms, typically synthetic lengths: length for every spectrogam in the batch mixing: style mixing, usually True during training normalize: normalize spectrogram range to ~[0, 1], True for normal use returns: batch of enhanced spectrograms For explanation of style mixing refer to [1] [1] Karras et. al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948) """ return self.forward_with_custom_noise( input_spectrograms=input_spectrograms, lengths=lengths, mixing=mixing, normalize=normalize, zs=None, ws=None, noise=None, ) def forward_with_custom_noise( self, input_spectrograms: torch.Tensor, lengths: torch.Tensor, zs: Optional[List[torch.Tensor]] = None, ws: Optional[List[torch.Tensor]] = None, noise: Optional[torch.Tensor] = None, mixing: bool = False, normalize: bool = True, ): """ Generator forward pass. Noise inputs will be generated if None. input_spectrograms: batch of spectrograms, typically synthetic lenghts: length for every spectrogam in the batch zs: latent noise inputs on the unit sphere (either this or ws or neither) ws: latent noise inputs in the style space (either this or zs or neither) noise: per-pixel indepentent gaussian noise mixing: style mixing, usually True during training normalize: normalize spectrogram range to ~[0, 1], True for normal use returns: batch of enhanced spectrograms For explanation of style mixing refer to [1] For definititions of z, w [2] [1] Karras et. al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948) [2] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN, 2019 (https://arxiv.org/abs/1912.04958) """ batch_size, *_, max_length = input_spectrograms.shape # generate noise if zs is not None and ws is not None: raise ValueError( "Please specify either zs or ws or neither, but not both. It is not clear which one to use." ) if zs is None: zs = self.generate_zs(batch_size, mixing) if ws is None: ws = [self.generator.style_mapping(z) for z in zs] if noise is None: noise = self.generate_noise(batch_size) input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l") # normalize if needed, mask and pad appropriately if normalize: input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths) input_spectrograms = self.pad_spectrograms(input_spectrograms) # the main call enhanced_spectrograms = self.generator(input_spectrograms, lengths, ws, noise) # denormalize if needed, mask and remove padding if normalize: enhanced_spectrograms = self.unnormalize_spectrograms(enhanced_spectrograms, lengths) enhanced_spectrograms = enhanced_spectrograms[:, :, :, :max_length] enhanced_spectrograms = rearrange(enhanced_spectrograms, "b 1 c l -> b c l") return enhanced_spectrograms def training_step(self, batch, batch_idx, optimizer_idx): input_spectrograms, target_spectrograms, lengths = batch with torch.no_grad(): input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths) target_spectrograms = self.normalize_spectrograms(target_spectrograms, lengths) # train discriminator if optimizer_idx == 0: enhanced_spectrograms = self.forward( input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False ) enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l") fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths) target_spectrograms_ = rearrange(target_spectrograms, "b c l -> b 1 c l").requires_grad_() real_logits = self.discriminator(target_spectrograms_, input_spectrograms, lengths) d_loss = self.discriminator_loss(real_logits, fake_logits) self.log("d_loss", d_loss, prog_bar=True) if batch_idx % self._cfg.gradient_penalty_loss_every_n_steps == 0: gp_loss = self.gradient_penalty_loss(target_spectrograms_, real_logits) self.log("d_loss_gp", gp_loss, prog_bar=True) return d_loss + gp_loss return d_loss # train generator if optimizer_idx == 1: enhanced_spectrograms = self.forward( input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False ) input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l") enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l") fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths) g_loss = self.generator_loss(fake_logits) c_loss = self.consistency_loss(input_spectrograms, enhanced_spectrograms, lengths) self.log("g_loss", g_loss, prog_bar=True) self.log("c_loss", c_loss, prog_bar=True) with torch.no_grad(): target_spectrograms = rearrange(target_spectrograms, "b c l -> b 1 c l") self.log_illustration(target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths) return g_loss + c_loss def configure_optimizers(self): generator_opt = instantiate( self._cfg.generator_opt, params=self.generator.parameters(), ) discriminator_opt = instantiate(self._cfg.discriminator_opt, params=self.discriminator.parameters()) return [discriminator_opt, generator_opt], [] def setup_training_data(self, train_data_config): dataset = instantiate(train_data_config.dataset) self._train_dl = torch.utils.data.DataLoader( dataset, collate_fn=dataset.collate_fn, **train_data_config.dataloader_params ) def setup_validation_data(self, val_data_config): """ There is no validation step for this model. It is not clear whether any of used losses is a sensible metric for choosing between two models. This might change in the future. """ pass @classmethod def list_available_models(cls): list_of_models = [] # en, multi speaker, LibriTTS, 16000 Hz # stft 25ms 10ms matching ASR params # for use during Enhlish ASR training/adaptation model = PretrainedModelInfo( pretrained_model_name="tts_en_spectrogram_enhancer_for_asr_finetuning", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_spectrogram_enhancer_for_asr_finetuning.nemo", description="This model is trained to add details to synthetic spectrograms." " It was trained on pairs of real-synthesized spectrograms generated by FastPitch." " STFT parameters follow ASR with 25 ms window and 10 ms hop." " It is supposed to be used in conjunction with that model for ASR training/adaptation.", class_=cls, ) list_of_models.append(model) return list_of_models def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths): if self.global_rank != 0: return if not self.loggers: return step = self.trainer.global_step // 2 # because of G/D training if step % self.trainer.log_every_n_steps != 0: return idx = 0 length = int(lengths.flatten()[idx].item()) tensor = torch.stack( [ enhanced_spectrograms - input_spectrograms, input_spectrograms, enhanced_spectrograms, target_spectrograms, ], dim=0, ).cpu()[:, idx, :, :, :length] assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." grid = torchvision.utils.make_grid(tensor, nrow=1).clamp(0.0, 1.0) for logger in self.loggers: if isinstance(logger, TensorBoardLogger): writer: SummaryWriter = logger.experiment writer.add_image("spectrograms", grid, global_step=step) writer.flush() elif isinstance(logger, WandbLogger): logger.log_image("spectrograms", [grid], caption=["residual, input, output, ground truth"], step=step) else: logging.warning("Unsupported logger type: %s", str(type(logger)))