subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# The following is largely based on code from https://github.com/lucidrains/stylegan2-pytorch
from random import random, randrange
from typing import List, Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from omegaconf import DictConfig
from torch.utils.tensorboard.writer import SummaryWriter
from nemo.collections.common.parts.utils import mask_sequence_tensor
from nemo.collections.tts.losses.spectrogram_enhancer_losses import (
ConsistencyLoss,
GeneratorLoss,
GradientPenaltyLoss,
HingeLoss,
)
from nemo.collections.tts.parts.utils.helpers import to_device_recursive
from nemo.core import Exportable, ModelPT, PretrainedModelInfo, typecheck
from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType
from nemo.core.neural_types.elements import BoolType
from nemo.utils import logging
# torchvision is only needed to draw spectrogram grids in log_illustration; record whether it can be imported.
try:
    import torchvision
except (ImportError, ModuleNotFoundError):
    TORCHVISION_AVAILABLE = False
else:
    TORCHVISION_AVAILABLE = True
class SpectrogramEnhancerModel(ModelPT, Exportable):
    """
    GAN-based model to add details to blurry spectrograms from TTS models like Tacotron or FastPitch. Based on StyleGAN 2 [1]

    [1] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958)
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None:
        # Assigned before ModelPT.__init__; not read anywhere in this class -- presumably kept for
        # external callers. TODO confirm.
        self.spectrogram_model = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.generator = instantiate(cfg.generator)
        self.discriminator = instantiate(cfg.discriminator)

        self.generator_loss = GeneratorLoss()
        self.discriminator_loss = HingeLoss()
        self.consistency_loss = ConsistencyLoss(cfg.consistency_loss_weight)
        self.gradient_penalty_loss = GradientPenaltyLoss(cfg.gradient_penalty_loss_weight)

    def move_to_correct_device(self, e):
        """Move `e` (a tensor or a list of tensors, see generate_zs) to the device of the generator's parameters."""
        return to_device_recursive(e, next(self.generator.parameters()).device)

    def normalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Linearly map values from [spectrogram_min_value, spectrogram_max_value] (from the config) to [0, 1],
        then mask the result with `mask_sequence_tensor` according to `lengths`.
        """
        spectrogram = spectrogram - self._cfg.spectrogram_min_value
        spectrogram = spectrogram / (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value)
        return mask_sequence_tensor(spectrogram, lengths)

    def unnormalize_spectrograms(self, spectrogram: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Inverse of normalize_spectrograms: map [0, 1] back to the configured value range and re-apply the mask."""
        spectrogram = spectrogram * (self._cfg.spectrogram_max_value - self._cfg.spectrogram_min_value)
        spectrogram = spectrogram + self._cfg.spectrogram_min_value
        return mask_sequence_tensor(spectrogram, lengths)

    def generate_zs(self, batch_size: int = 1, mixing: bool = False):
        """
        Sample one latent vector z per generator layer.

        Without mixing, a single (batch_size, latent_dim) standard-normal sample is shared by all layers.
        With mixing (applied with probability `mixed_prob` from the config, and only when the generator has
        at least two layers), two independent samples are drawn: the first is used for layers
        [0, mixing_point) and the second for the remaining layers.

        Returns a list of `generator.num_layers` tensors on the generator's device.
        """
        num_layers = self.generator.num_layers
        # `random() < p` is true with probability p, so `mixed_prob` is the probability of mixing.
        # A split point needs at least two layers (randrange(1, 1) would raise).
        if mixing and num_layers > 1 and random() < self._cfg.mixed_prob:
            mixing_point = randrange(1, num_layers)
            first_z = torch.randn(batch_size, self._cfg.latent_dim)
            second_z = torch.randn(batch_size, self._cfg.latent_dim)
            zs = [first_z] * mixing_point + [second_z] * (num_layers - mixing_point)
        else:
            zs = [torch.randn(batch_size, self._cfg.latent_dim)] * num_layers

        return self.move_to_correct_device(zs)

    def generate_noise(self, batch_size: int = 1) -> torch.Tensor:
        """
        Sample per-pixel noise of shape (batch_size, n_bands, 4096, 1), uniform on [0, 1).

        NOTE(review): 4096 is a fixed size along the third axis; presumably the generator crops it to the
        (padded) spectrogram length, so longer inputs may not be supported -- confirm.
        """
        noise = torch.rand(batch_size, self._cfg.n_bands, 4096, 1)
        return self.move_to_correct_device(noise)

    def pad_spectrograms(self, spectrograms):
        """
        Right-pad the last (time) axis with zeros so its length is a multiple of `generator.upsample_factor`.

        Note: this always adds between 1 and `upsample_factor` frames -- a full extra `upsample_factor` when the
        length is already divisible. forward_with_custom_noise crops the output back to the original length.
        """
        multiplier = self.generator.upsample_factor
        *_, max_length = spectrograms.shape
        return F.pad(spectrograms, (0, multiplier - max_length % multiplier))

    @typecheck(
        input_types={
            "input_spectrograms": NeuralType(("B", "D", "T_spec"), MelSpectrogramType()),
            "lengths": NeuralType(("B",), LengthsType()),
            "mixing": NeuralType(None, BoolType(), optional=True),
            "normalize": NeuralType(None, BoolType(), optional=True),
        }
    )
    def forward(
        self,
        *,
        input_spectrograms: torch.Tensor,
        lengths: torch.Tensor,
        mixing: bool = False,
        normalize: bool = True,
    ):
        """
        Generator forward pass. Noise inputs will be generated.

        input_spectrograms: batch of spectrograms, typically synthetic
        lengths: length for every spectrogram in the batch
        mixing: style mixing, usually True during training
        normalize: normalize spectrogram range to ~[0, 1], True for normal use

        returns: batch of enhanced spectrograms

        For explanation of style mixing refer to [1]

        [1] Karras et. al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948)
        """
        return self.forward_with_custom_noise(
            input_spectrograms=input_spectrograms,
            lengths=lengths,
            mixing=mixing,
            normalize=normalize,
            zs=None,
            ws=None,
            noise=None,
        )

    def forward_with_custom_noise(
        self,
        input_spectrograms: torch.Tensor,
        lengths: torch.Tensor,
        zs: Optional[List[torch.Tensor]] = None,
        ws: Optional[List[torch.Tensor]] = None,
        noise: Optional[torch.Tensor] = None,
        mixing: bool = False,
        normalize: bool = True,
    ):
        """
        Generator forward pass. Noise inputs will be generated if None.

        input_spectrograms: batch of spectrograms (b, c, l), typically synthetic
        lengths: length for every spectrogram in the batch
        zs: per-layer latent inputs, mapped to ws with `generator.style_mapping` (either this or ws or neither);
            if neither is given, they are sampled with generate_zs
        ws: per-layer latent inputs in the style space (either this or zs or neither)
        noise: per-pixel noise input; if None, uniform noise from generate_noise is used
        mixing: style mixing, usually True during training (only used when zs are sampled here)
        normalize: normalize spectrogram range to ~[0, 1], True for normal use

        returns: batch of enhanced spectrograms (b, c, l)

        Raises ValueError if both zs and ws are given.

        For explanation of style mixing refer to [1]
        For definitions of z, w [2]

        [1] Karras et. al. - A Style-Based Generator Architecture for Generative Adversarial Networks, 2018 (https://arxiv.org/abs/1812.04948)
        [2] Karras et. al. - Analyzing and Improving the Image Quality of StyleGAN, 2019 (https://arxiv.org/abs/1912.04958)
        """
        batch_size, *_, max_length = input_spectrograms.shape

        # generate noise
        if zs is not None and ws is not None:
            raise ValueError(
                "Please specify either zs or ws or neither, but not both. It is not clear which one to use."
            )
        # zs are only needed to derive ws, so don't sample them when ws were supplied
        if ws is None:
            if zs is None:
                zs = self.generate_zs(batch_size, mixing)
            ws = [self.generator.style_mapping(z) for z in zs]
        if noise is None:
            noise = self.generate_noise(batch_size)

        input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")
        # normalize if needed, mask and pad appropriately
        if normalize:
            input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths)
        input_spectrograms = self.pad_spectrograms(input_spectrograms)

        # the main call
        enhanced_spectrograms = self.generator(input_spectrograms, lengths, ws, noise)

        # denormalize if needed, mask and remove padding
        if normalize:
            enhanced_spectrograms = self.unnormalize_spectrograms(enhanced_spectrograms, lengths)
        enhanced_spectrograms = enhanced_spectrograms[:, :, :, :max_length]
        enhanced_spectrograms = rearrange(enhanced_spectrograms, "b 1 c l -> b c l")

        return enhanced_spectrograms

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        One GAN training step on a batch of (input, target, lengths).

        optimizer_idx 0 updates the discriminator: hinge loss, plus a gradient penalty on real samples every
        `gradient_penalty_loss_every_n_steps` batches. optimizer_idx 1 updates the generator: adversarial loss
        plus a consistency loss between input and enhanced spectrograms; it also logs an illustration.

        NOTE(review): the `optimizer_idx` argument relies on Lightning's automatic optimization with multiple
        optimizers, which Lightning 2.x removed in favour of manual optimization -- confirm the installed version.
        """
        input_spectrograms, target_spectrograms, lengths = batch

        with torch.no_grad():
            input_spectrograms = self.normalize_spectrograms(input_spectrograms, lengths)
            target_spectrograms = self.normalize_spectrograms(target_spectrograms, lengths)

        # train discriminator
        if optimizer_idx == 0:
            enhanced_spectrograms = self.forward(
                input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False
            )
            enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l")
            # NOTE(review): here the condition `input_spectrograms` is still (b, c, l), whereas the generator
            # branch below reshapes it to (b, 1, c, l) before calling the discriminator -- confirm the
            # discriminator accepts both layouts.
            fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths)

            # real samples need gradients so the gradient penalty can differentiate w.r.t. them
            target_spectrograms_ = rearrange(target_spectrograms, "b c l -> b 1 c l").requires_grad_()
            real_logits = self.discriminator(target_spectrograms_, input_spectrograms, lengths)
            d_loss = self.discriminator_loss(real_logits, fake_logits)
            self.log("d_loss", d_loss, prog_bar=True)

            if batch_idx % self._cfg.gradient_penalty_loss_every_n_steps == 0:
                gp_loss = self.gradient_penalty_loss(target_spectrograms_, real_logits)
                self.log("d_loss_gp", gp_loss, prog_bar=True)
                return d_loss + gp_loss

            return d_loss

        # train generator
        if optimizer_idx == 1:
            enhanced_spectrograms = self.forward(
                input_spectrograms=input_spectrograms, lengths=lengths, mixing=True, normalize=False
            )
            input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")
            enhanced_spectrograms = rearrange(enhanced_spectrograms, "b c l -> b 1 c l")

            fake_logits = self.discriminator(enhanced_spectrograms, input_spectrograms, lengths)
            g_loss = self.generator_loss(fake_logits)
            c_loss = self.consistency_loss(input_spectrograms, enhanced_spectrograms, lengths)

            self.log("g_loss", g_loss, prog_bar=True)
            self.log("c_loss", c_loss, prog_bar=True)

            with torch.no_grad():
                target_spectrograms = rearrange(target_spectrograms, "b c l -> b 1 c l")
                self.log_illustration(target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths)
            return g_loss + c_loss

    def configure_optimizers(self):
        """
        Build both optimizers from the config. The order matters: index 0 is the discriminator optimizer and
        index 1 the generator optimizer, matching the `optimizer_idx` branches in training_step.
        """
        generator_opt = instantiate(
            self._cfg.generator_opt,
            params=self.generator.parameters(),
        )
        discriminator_opt = instantiate(self._cfg.discriminator_opt, params=self.discriminator.parameters())
        return [discriminator_opt, generator_opt], []

    def setup_training_data(self, train_data_config):
        """Instantiate the training dataset from config and wrap it in a DataLoader using its collate_fn."""
        dataset = instantiate(train_data_config.dataset)
        self._train_dl = torch.utils.data.DataLoader(
            dataset, collate_fn=dataset.collate_fn, **train_data_config.dataloader_params
        )

    def setup_validation_data(self, val_data_config):
        """
        There is no validation step for this model.
        It is not clear whether any of the used losses is a sensible metric for choosing between two models.
        This might change in the future.
        """
        pass

    @classmethod
    def list_available_models(cls):
        """Return PretrainedModelInfo entries for the published checkpoints of this model."""
        list_of_models = []

        # en, multi speaker, LibriTTS, 16000 Hz
        # stft 25ms 10ms matching ASR params
        # for use during English ASR training/adaptation
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_spectrogram_enhancer_for_asr_finetuning",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_spectrogram_enhancer_for_asr_finetuning.nemo",
            description="This model is trained to add details to synthetic spectrograms."
            " It was trained on pairs of real-synthesized spectrograms generated by FastPitch."
            " STFT parameters follow ASR with 25 ms window and 10 ms hop."
            " It is supposed to be used in conjunction with that model for ASR training/adaptation.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths):
        """
        Log a picture of the first batch item to every attached logger: residual (output - input), input,
        output and ground truth, stacked top to bottom and clamped to [0, 1].

        Only runs on global rank 0, when at least one logger is attached, and when the G/D step counter is a
        multiple of `log_every_n_steps`. Spectrograms are expected as (b, 1, c, l), as passed by training_step.

        Raises ImportError if torchvision is unavailable and an illustration is due.
        """
        if self.global_rank != 0:
            return

        if not self.loggers:
            return

        step = self.trainer.global_step // 2  # because of G/D training
        if step % self.trainer.log_every_n_steps != 0:
            return

        if not TORCHVISION_AVAILABLE:
            # explicit raise rather than `assert`, which is stripped under `python -O`
            raise ImportError("Torchvision imports failed but they are required.")

        idx = 0
        length = int(lengths.flatten()[idx].item())
        # select the single sample first so only it (not the whole batch) is copied to the CPU
        enhanced = enhanced_spectrograms[idx]
        inputs = input_spectrograms[idx]
        target = target_spectrograms[idx]
        tensor = torch.stack([enhanced - inputs, inputs, enhanced, target], dim=0)[..., :length].cpu()

        grid = torchvision.utils.make_grid(tensor, nrow=1).clamp(0.0, 1.0)

        for logger in self.loggers:
            if isinstance(logger, TensorBoardLogger):
                writer: SummaryWriter = logger.experiment
                writer.add_image("spectrograms", grid, global_step=step)
                writer.flush()
            elif isinstance(logger, WandbLogger):
                logger.log_image("spectrograms", [grid], caption=["residual, input, output, ground truth"], step=step)
            else:
                logging.warning("Unsupported logger type: %s", str(type(logger)))