# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from math import ceil
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.tts.losses.audio_codec_loss import (
    FeatureMatchingLoss,
    MultiResolutionMelLoss,
    MultiResolutionSTFTLoss,
    RelativeFeatureMatchingLoss,
    SISDRLoss,
    TimeDomainLoss,
)
from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder, default_precision
from nemo.collections.tts.modules.common import GaussianDropout
from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers
from nemo.core import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, TokenIndex
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
from nemo.utils import logging, model_utils

try:
    import torchaudio

    HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
    HAVE_TORCHAUDIO = False


class AudioCodecModel(ModelPT):
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_devices

        super().__init__(cfg=cfg, trainer=trainer)

        # Expected sample rate for the input audio
        self.sample_rate = cfg.sample_rate

        # Number of samples in each audio frame that is encoded
        self.samples_per_frame = cfg.samples_per_frame

        # Discriminator updates
        self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
        self.disc_update_period = cfg.get("disc_update_period", 1)
        if self.disc_updates_per_period > self.disc_update_period:
            raise ValueError(
                f'Number of discriminator updates per period ({self.disc_updates_per_period}) must be less than '
                f'or equal to the configured period ({self.disc_update_period}).'
            )

        # Encoder setup
        self.audio_encoder = instantiate(cfg.audio_encoder)

        # Optionally, add gaussian noise to encoder output as an information bottleneck
        encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0)
        if encoder_noise_stdev:
            self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev)
        else:
            self.encoder_noise = None

        if "vector_quantizer" in cfg:
            self.vector_quantizer = instantiate(cfg.vector_quantizer)

            vq_output_types = list(self.vector_quantizer.output_types.keys())
            if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss':
                self.vector_quantizer_has_commit_loss = True
                logging.info('Vector quantizer supports commit loss.')
            else:
                self.vector_quantizer_has_commit_loss = False
                logging.info('Vector quantizer does not support commit loss.')
        else:
            logging.warning('Vector quantizer will not be used.')
            self.vector_quantizer = None

        # Decoder setup
        self.audio_decoder = instantiate(cfg.audio_decoder)

        # Discriminator setup
        self.discriminator = instantiate(cfg.discriminator)

        # Mel loss setup
        loss_resolutions = cfg.loss_resolutions
        mel_loss_dims = cfg.get("mel_loss_dims")
        mel_loss_log_guard = cfg.get("mel_loss_log_guard", 1.0)
        self.mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0)
        self.mel_loss_l2_scale = cfg.get("mel_loss_l2_scale", 1.0)
        self.mel_loss_fn = MultiResolutionMelLoss(
            sample_rate=self.sample_rate,
            mel_dims=mel_loss_dims,
            resolutions=loss_resolutions,
            log_guard=mel_loss_log_guard,
        )

        # STFT loss setup
        stft_loss_log_guard = cfg.get("stft_loss_log_guard", 1.0)
        self.stft_loss_scale = cfg.get("stft_loss_scale", 0.0)
        self.stft_loss_fn = MultiResolutionSTFTLoss(
            resolutions=loss_resolutions,
            log_guard=stft_loss_log_guard,
        )

        # Time domain loss setup
        self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0)
        self.si_sdr_loss_scale = cfg.get("si_sdr_loss_scale", 0.0)
        self.time_domain_loss_fn = TimeDomainLoss()
        self.si_sdr_loss_fn = SISDRLoss()

        # Discriminator loss setup
        self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0)
        self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0)
        self.gen_loss_fn = instantiate(cfg.generator_loss)
        self.disc_loss_fn = instantiate(cfg.discriminator_loss)

        self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0)
        if "mmd_loss" in cfg:
            self.mmd_loss_fn = instantiate(cfg.mmd_loss)
            self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
        else:
            self.mmd_loss_fn = None
            self.mmd_loss_scale = None

        if "mmd_time_loss" in cfg:
            self.mmd_time_loss_fn = instantiate(cfg.mmd_time_loss)
            self.mmd_time_loss_scale = cfg.get("mmd_time_loss_scale", 1.0)
        else:
            self.mmd_time_loss_fn = None
            self.mmd_time_loss_scale = None

        feature_loss_type = cfg.get("feature_loss_type", "relative")
        if feature_loss_type == "relative":
            self.feature_loss_fn = RelativeFeatureMatchingLoss()
        elif feature_loss_type == "absolute":
            self.feature_loss_fn = FeatureMatchingLoss()
        else:
            raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

        # Codebook loss setup
        if self.vector_quantizer:
            self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0)
        else:
            self.commit_loss_scale = 0.0

        if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss:
            raise ValueError('Commit loss is enabled but the quantizer does not support it.')

        # Speaker consistency loss (SCL) setup
        self.use_scl_loss = cfg.get("use_scl_loss", False)
        self.scl_loss_scale = cfg.get("scl_loss_scale", False)
        if self.use_scl_loss:
            self.speaker_encoder = ResNetSpeakerEncoder()
            # Load the pretrained speaker encoder
            # self.speaker_encoder.load_checkpoint("https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar")
            self.speaker_encoder.load_checkpoint(
                "https://huggingface.co/Edresson/Speaker_Encoder_H_ASP/resolve/main/pytorch_model.bin", strict=False
            )
            # Freeze the pretrained speaker encoder
            self.speaker_encoder.freeze()
            logging.info("Speaker encoder loaded and frozen.")

        # ASR consistency loss is disabled for now, as it is not used in the final model
        self.use_asr_consitency_loss = False
        self.acl_loss_scale = False
        # self.use_asr_consitency_loss = cfg.get("use_asr_consitency_loss", False)
        # self.acl_loss_scale = cfg.get("acl_loss_scale", False)
        # if self.use_asr_consitency_loss:
        #     self.phoneme_asr_model = PhonemeASR(input_sr=self.sample_rate)
        #     self.phoneme_asr_model.freeze()
        #     # self.acl_loss = CrossEntropyLoss()
        #     print("Phoneme ASR model loaded and frozen !!")

        # Log setup
        self.log_config = cfg.get("log_config", None)

        # Optimizer setup
        self.lr_schedule_interval = None
        self.automatic_optimization = False
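
    # Config sketch for the constructor above (keys only; concrete values and _target_ classes are
    # deployment-specific assumptions and therefore elided -- refer to the audio codec configs shipped
    # with NeMo for complete examples):
    #
    #   sample_rate: ...          # required
    #   samples_per_frame: ...    # required
    #   audio_encoder: ...        # required, instantiated via Hydra
    #   audio_decoder: ...        # required
    #   discriminator: ...        # required
    #   generator_loss: ...       # required
    #   discriminator_loss: ...   # required
    #   loss_resolutions: ...     # required, shared by the mel and STFT losses
    #   mel_loss_dims: ...        # mel filterbank sizes for the mel loss
    #   vector_quantizer: ...     # optional; omit to train without quantization
    #   use_scl_loss: false       # optional speaker consistency loss
    #   log_config: ...           # optional logging callback config
    #   optim: ...                # optimizer (and optional sched) config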

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def num_codebooks(self):
        if self.vector_quantizer is None:
            raise ValueError("This AudioCodecModel does not have a vector quantizer.")
        return self.vector_quantizer.num_codebooks

    @property
    def codebook_size(self):
        if self.vector_quantizer is None:
            raise ValueError("This AudioCodecModel does not have a vector quantizer.")
        return self.vector_quantizer.codebook_size

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if hasattr(self, '_no_state_dict') and self._no_state_dict:
            return {}
        # Do not save the pretrained speaker encoder or the SLM (e.g. WavLM) discriminator backbone in the state dict
        state_dict = super().state_dict(destination, prefix, keep_vars)
        for key in list(state_dict.keys()):
            if self.use_scl_loss and "speaker_encoder." in key:
                del state_dict[key]
            if "discriminator" in key and ".slm_model.ssl_model." in key:
                del state_dict[key]
        return state_dict

    def load_state_dict(self, state_dict, strict=True):
        # Override to load all keys except the speaker encoder and the SLM (e.g. WavLM) discriminator backbone
        for key in list(state_dict.keys()):
            if self.use_scl_loss and "speaker_encoder." in key:
                del state_dict[key]
            if "discriminator" in key and ".slm_model.ssl_model." in key:
                del state_dict[key]
        return super().load_state_dict(state_dict, strict=False)

    def get_speaker_embedding(self, audio, requires_grad=False):
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio!')
            raise ModuleNotFoundError("torchaudio is not installed but is required for audio resampling.")

        def _embed():
            # Resample to the speaker encoder's expected sample rate and extract an L2-normalized embedding
            audio_resampled = torchaudio.functional.resample(
                audio, self.sample_rate, self.speaker_encoder.audio_config["sample_rate"]
            )
            return self.speaker_encoder(audio_resampled, l2_norm=True).unsqueeze(-1)

        if requires_grad:
            g = _embed()
        else:
            with torch.no_grad():
                g = _embed()
        return g

    def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply encoder on the input audio signal. Input will be padded with zeros so
        the last frame has full `self.samples_per_frame` samples.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Encoder output `encoded` and its length in number of frames `encoded_len`
        """
        audio, audio_len = self.pad_audio(audio, audio_len)
        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
        return encoded, encoded_len

    def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation.

        Args:
            inputs: encoded signal
            input_len: valid length for each example in the batch

        Returns:
            Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
            Note that `audio_len` will be a multiple of `self.samples_per_frame`.
        """
        audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
        return audio, audio_len

    def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
        """Quantize the continuous encoded representation into a discrete
        representation for each frame.

        Args:
            encoded: encoded signal representation
            encoded_len: valid length of the encoded representation in frames

        Returns:
            A tensor of tokens for each codebook for each frame.
        """
        if not self.vector_quantizer:
            raise ValueError("Cannot quantize without quantizer")
        with default_precision(torch.float32):
            # vector quantizer returns [C, B, T], where C is the number of codebooks
            tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
        # use batch first for the output
        tokens = rearrange(tokens, 'C B T -> B C T')
        return tokens

    def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor:
        """Convert the discrete tokens into a continuous encoded representation.

        Args:
            tokens: discrete tokens for each codebook for each time frame
            tokens_len: valid length of each example in the batch

        Returns:
            Continuous encoded representation of the discrete input representation.
        """
        if not self.vector_quantizer:
            raise ValueError("Cannot dequantize without quantizer")
        # vector quantizer is using [C, B, T], where C is the number of codebooks
        tokens = rearrange(tokens, 'B C T -> C B T')
        with default_precision(torch.float32):
            dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
        dequantized = dequantized.to(self.dtype)  # make sure dequantized is in the right dtype
        return dequantized
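
    # Shape note for quantize()/dequantize(): encoded features are [B, D, T_frames] (see _process_batch
    # below), tokens are [B, C, T_frames] with C == num_codebooks; the rearrange calls above convert to and
    # from the [C, B, T] layout used internally by the vector quantizer.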

    def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input time-domain audio signal into a discrete representation (tokens).

        Args:
            audio: input time-domain signal, shape `(batch, number of samples)`
            audio_len: valid length for each example in the batch, shape `(batch size,)`

        Returns:
            Tokens for each codebook for each frame, shape `(batch, number of codebooks, number of frames)`,
            and the corresponding valid lengths, shape `(batch,)`
        """
        # Apply encoder to obtain a continuous vector for each frame
        encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
        # Apply quantizer to obtain discrete representation per frame
        tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
        return tokens, encoded_len

    def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert discrete tokens into a continuous time-domain signal.

        Args:
            tokens: discrete tokens for each codebook for each time frame, shape `(batch, number of codebooks, number of frames)`
            tokens_len: valid lengths, shape `(batch,)`

        Returns:
            Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
            Note that `audio_len` will be a multiple of `self.samples_per_frame`.
        """
        # Convert a discrete representation to a dequantized vector for each frame
        dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
        dequantized = dequantized.to(self.dtype)  # make sure that the dequantized is in the model dtype
        # Apply decoder to obtain time-domain audio for each frame
        audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)
        return audio, audio_len
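
    # Round-trip sketch (shapes only): for input audio of shape [B, T_samples], encode() returns tokens of
    # shape [B, num_codebooks, T_frames] and decode() returns audio whose valid length is a multiple of
    # samples_per_frame (see decode_audio above). Assuming the encoder produces one frame per
    # samples_per_frame samples, T_frames == ceil(T_samples / samples_per_frame).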

    def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply encoder, quantizer, decoder on the input time-domain signal.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`.
        """
        encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)

        if self.vector_quantizer:
            # quantize to discrete tokens
            tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
            # decode tokens to audio
            output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len)
        else:
            # no quantization, directly decode to audio
            output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len)

        return output_audio, output_audio_len

    def pad_audio(self, audio, audio_len):
        """Zero pad the end of the audio so that we do not have a partial end frame.
        The output will be zero-padded to have an integer number of frames of
        length `self.samples_per_frame`.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Padded time-domain signal `padded_audio` and its length `padded_len`.
        """
        padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int()
        max_len = padded_len.max().item()
        num_padding = max_len - audio.shape[1]
        padded_audio = F.pad(audio, (0, num_padding))
        return padded_audio, padded_len
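
    # Worked example (hypothetical values): with samples_per_frame=1024 and audio_len=[2500, 4096],
    # padded_len=[3072, 4096] (3 and 4 full frames, respectively) and the audio tensor is zero-padded
    # up to the batch maximum of 4096 samples.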

    def _process_batch(self, batch):
        # [B, T_audio]
        audio = batch.get("audio")
        # [B]
        audio_len = batch.get("audio_lens")
        audio, audio_len = self.pad_audio(audio, audio_len)

        # [B, D, T_encoded]
        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)

        if self.encoder_noise is not None:
            encoded = self.encoder_noise(encoded)

        if self.vector_quantizer:
            with default_precision(torch.float32):
                if self.vector_quantizer_has_commit_loss:
                    encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
                else:
                    encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
                    commit_loss = 0.0
        else:
            commit_loss = 0.0

        encoded = encoded.to(self.dtype)  # make sure the decoder input is in the model dtype
        # [B, T_audio]
        audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)
        return audio, audio_len, audio_gen, commit_loss, encoded

    def disc_update_prob(self) -> float:
        """Probability of updating the discriminator."""
        return self.disc_updates_per_period / self.disc_update_period

    def should_update_disc(self, batch_idx) -> bool:
        """Decide whether to update the discriminator based
        on the batch index and the configured discriminator update period.
        """
        disc_update_step = batch_idx % self.disc_update_period
        return disc_update_step < self.disc_updates_per_period
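
    # Example (hypothetical config): with disc_update_period=3 and disc_updates_per_period=2,
    # should_update_disc() is True for the first two batches of every three-batch period, so the
    # discriminator is trained on roughly two thirds of the batches.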

    def training_step(self, batch, batch_idx):
        optim_gen, optim_disc = self.optimizers()

        audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)

        metrics = {
            "global_step": self.global_step,
            "lr": optim_gen.param_groups[0]['lr'],
        }

        if self.should_update_disc(batch_idx):
            # Train discriminator
            disc_scores_real, disc_scores_gen, _, _ = self.discriminator(
                audio_real=audio, audio_gen=audio_gen.detach()
            )
            loss_disc = self.disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen)
            metrics["d_loss"] = loss_disc

            optim_disc.zero_grad()
            self.manual_backward(loss_disc)
            optim_disc.step()

        # Train generator
        generator_losses = []

        # STFT does not support bf16, so run the spectral losses in fp32
        loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
            audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
        )
        if self.mel_loss_l1_scale:
            metrics["g_loss_mel_l1"] = loss_mel_l1
            generator_losses.append(self.mel_loss_l1_scale * loss_mel_l1)
        if self.mel_loss_l2_scale:
            metrics["g_loss_mel_l2"] = loss_mel_l2
            generator_losses.append(self.mel_loss_l2_scale * loss_mel_l2)

        if self.stft_loss_scale:
            loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len)
            metrics["g_loss_stft"] = loss_stft
            generator_losses.append(self.stft_loss_scale * loss_stft)

        if self.time_domain_loss_scale:
            loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
            metrics["g_loss_time_domain"] = loss_time_domain
            generator_losses.append(self.time_domain_loss_scale * loss_time_domain)

        if self.si_sdr_loss_scale:
            loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
            metrics["g_loss_si_sdr"] = loss_si_sdr
            generator_losses.append(self.si_sdr_loss_scale * loss_si_sdr)

        _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen)

        if self.gen_loss_scale:
            loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen)
            metrics["g_loss_gen"] = loss_gen
            generator_losses.append(self.gen_loss_scale * loss_gen)

        if self.feature_loss_scale:
            loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)
            metrics["g_loss_feature"] = loss_feature
            generator_losses.append(self.feature_loss_scale * loss_feature)

        if self.commit_loss_scale:
            metrics["g_loss_commit"] = commit_loss
            generator_losses.append(self.commit_loss_scale * commit_loss)

        if self.mmd_loss_scale:
            loss_mmd = self.mmd_loss_fn(inputs=codes)
            metrics["g_loss_mmd"] = loss_mmd
            if self.current_epoch >= self.mmd_loss_start_epoch:
                generator_losses.append(self.mmd_loss_scale * loss_mmd)

        if self.mmd_time_loss_scale:
            loss_mmd_time = self.mmd_time_loss_fn(inputs=codes)
            metrics["g_loss_mmd_time"] = loss_mmd_time
            if self.current_epoch >= self.mmd_loss_start_epoch:
                generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time)

        # Compute embeddings for the speaker consistency loss
        if self.use_scl_loss:
            # concatenate generated and ground-truth waveforms
            audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
            # get speaker embeddings with gradients
            pred_embs = self.get_speaker_embedding(audios_batch, requires_grad=True)
            # split ground-truth and generated speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
            # speaker consistency loss as in the YourTTS paper
            loss_scl = -1 * torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.scl_loss_scale
            metrics["g_loss_scl"] = loss_scl
            generator_losses.append(metrics["g_loss_scl"])

        if self.use_asr_consitency_loss:
            # concatenate generated and ground-truth waveforms
            audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
            logits, _ = self.phoneme_asr_model(audios_batch)
            logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
            # labels_gt, labels_pred = torch.chunk(labels, 2, dim=0)
            loss_acl = torch.nn.functional.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale
            metrics["g_loss_acl"] = loss_acl
            generator_losses.append(metrics["g_loss_acl"])

        loss_gen_all = sum(generator_losses)

        optim_gen.zero_grad()
        self.manual_backward(loss_gen_all)
        optim_gen.step()

        self.update_lr()

        self.log_dict(metrics, on_step=True, sync_dist=True)
        self.log("t_loss", loss_mel_l1, prog_bar=True, logger=False, sync_dist=True)

    def on_train_epoch_end(self):
        self.update_lr("epoch")

    def validation_step(self, batch, batch_idx):
        audio, audio_len, audio_gen, _, _ = self._process_batch(batch)

        loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
            audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
        )
        loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len)
        loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
        loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)

        # Use only the main reconstruction losses for val_loss
        val_loss = loss_mel_l1 + loss_stft + loss_time_domain

        metrics = {
            "val_loss": val_loss,
            "val_loss_mel_l1": loss_mel_l1,
            "val_loss_mel_l2": loss_mel_l2,
            "val_loss_stft": loss_stft,
            "val_loss_time_domain": loss_time_domain,
            "val_loss_si_sdr": loss_si_sdr,
        }

        # Compute embeddings for the speaker consistency loss
        if self.use_scl_loss:
            # concatenate generated and ground-truth waveforms
            audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
            # get speaker embeddings
            pred_embs = self.get_speaker_embedding(audios_batch, requires_grad=True)
            # split ground-truth and generated speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
            # speaker consistency loss as in the YourTTS paper
            loss_scl = -1 * torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.scl_loss_scale
            metrics["val_loss_scl"] = loss_scl
            metrics["val_loss"] += metrics["val_loss_scl"]

        if self.use_asr_consitency_loss:
            # concatenate generated and ground-truth waveforms
            audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
            logits, _ = self.phoneme_asr_model(audios_batch)
            logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
            loss_acl = torch.nn.functional.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale
            metrics["val_loss_acl"] = loss_acl
            metrics["val_loss"] += metrics["val_loss_acl"]

        self.log_dict(metrics, on_epoch=True, sync_dist=True)

    def get_dataset(self, cfg):
        with open_dict(cfg):
            is_sharded = cfg.dataset.pop('is_sharded', False)

        if is_sharded:
            with open_dict(cfg):
                cfg.dataset.global_rank = self.global_rank
                cfg.dataset.world_size = self.world_size
                cfg.dataset._target_ = 'nemo.collections.tts.data.vocoder_dataset.TarredVocoderDataset'

        dataset = instantiate(cfg.dataset)
        sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
        return dataset, sampler

    def _setup_train_dataloader(self, cfg):
        dataset, sampler = self.get_dataset(cfg)
        data_loader = torch.utils.data.DataLoader(
            dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params
        )
        return data_loader

    def _setup_test_dataloader(self, cfg):
        dataset = instantiate(cfg.dataset)
        data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
        return data_loader

    def setup_training_data(self, cfg):
        self._train_dl = self._setup_train_dataloader(cfg)
        batch_size = cfg['dataloader_params']['batch_size']
        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if (
            self._train_dl is not None
            and hasattr(self._train_dl, 'dataset')
            and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
        ):
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * ceil((len(self._train_dl.dataset) / self.world_size) / batch_size)
                )
            elif self._trainer is None:
                logging.warning(
                    "Model Trainer was not set before constructing the dataset, incorrect number of "
                    "training batches will be used. Please set the trainer and rebuild the dataset."
                )
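
    # Worked example for the adjustment above (hypothetical numbers): for an IterableDataset with 100_000
    # samples, world_size=8 and batch_size=16, each rank sees ceil(12_500 / 16) = 782 batches per epoch, so a
    # fractional limit_train_batches of 0.5 becomes int(0.5 * 782) = 391 batches.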

    def setup_validation_data(self, cfg):
        self._validation_dl = self._setup_test_dataloader(cfg)

    def setup_test_data(self, cfg):
        pass

    @property
    def max_steps(self):
        if "max_steps" in self._cfg:
            return self._cfg.get("max_steps")

        if "max_epochs" not in self._cfg:
            raise ValueError("Must specify 'max_steps' or 'max_epochs'.")

        if "steps_per_epoch" in self._cfg:
            return self._cfg.max_epochs * self._cfg.steps_per_epoch

        return compute_max_steps(
            max_epochs=self._cfg.max_epochs,
            accumulate_grad_batches=self.trainer.accumulate_grad_batches,
            limit_train_batches=self.trainer.limit_train_batches,
            num_workers=get_num_workers(self.trainer),
            num_samples=len(self._train_dl.dataset),
            batch_size=get_batch_size(self._train_dl),
            drop_last=self._train_dl.drop_last,
        )

    def configure_optimizers(self):
        optim_config = self._cfg.optim.copy()

        OmegaConf.set_struct(optim_config, False)
        sched_config = optim_config.pop("sched", None)
        OmegaConf.set_struct(optim_config, True)

        asr_ph_params = self.phoneme_asr_model.parameters() if self.use_asr_consitency_loss else []
        se_params = self.speaker_encoder.parameters() if self.use_scl_loss else []
        vq_params = self.vector_quantizer.parameters() if self.vector_quantizer else []
        gen_params = itertools.chain(
            self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params, asr_ph_params, se_params
        )
        optim_g = instantiate(optim_config, params=gen_params)

        disc_params = self.discriminator.parameters()
        optim_d = instantiate(optim_config, params=disc_params)

        if sched_config is None:
            logging.debug('Scheduler is not used')
            return [optim_g, optim_d]

        logging.debug('Setting up schedulers')

        OmegaConf.set_struct(sched_config, False)
        sched_config["max_steps"] = self.max_steps
        OmegaConf.set_struct(sched_config, True)

        scheduler_g = prepare_lr_scheduler(
            optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl
        )
        scheduler_d = prepare_lr_scheduler(
            optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
        )
        self.lr_schedule_interval = scheduler_g["interval"]

        return [optim_g, optim_d], [scheduler_g, scheduler_d]

    def update_lr(self, interval="step"):
        schedulers = self.lr_schedulers()
        if schedulers is not None and self.lr_schedule_interval == interval:
            sch1, sch2 = schedulers
            sch1.step()
            sch2.step()
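
    # Note: automatic_optimization is disabled in __init__, so Lightning does not step the LR schedulers
    # itself. training_step() calls update_lr() every step and on_train_epoch_end() calls update_lr("epoch"),
    # and only the schedulers matching the configured interval are stepped.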

    def configure_callbacks(self):
        if not self.log_config:
            return []

        data_loader = self._setup_test_dataloader(self.log_config)
        generators = instantiate(self.log_config.generators)
        log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None
        log_callback = LoggingCallback(
            generators=generators,
            data_loader=data_loader,
            log_epochs=self.log_config.log_epochs,
            epoch_frequency=self.log_config.epoch_frequency,
            output_dir=log_dir,
            loggers=self.trainer.loggers,
            log_tensorboard=self.log_config.log_tensorboard,
            log_wandb=self.log_config.log_wandb,
        )
        return [log_callback]

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """List pre-trained checkpoints that can be instantiated directly from NGC."""
        models = []

        model = PretrainedModelInfo(
            pretrained_model_name="audio_codec_16khz_small",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/audio_codec_16khz_small/versions/v1/files/audio_codec_16khz_small.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/audio_codec_16khz_small",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_22khz_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_medium/versions/v1/files/mel_codec_22khz_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_44khz_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_medium/versions/v1/files/mel_codec_44khz_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_22khz_fullband_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_fullband_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_44khz_fullband_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_fullband_medium",
        )
        models.append(model)

        return models
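

# Example usage (a minimal sketch, not part of the model definition): the checkpoint name comes from
# list_available_models() above, while "input.wav" and the use of soundfile for I/O are illustrative
# assumptions -- any loader that yields a float32 [1, T] tensor at model.sample_rate works equally well.
#
#     import soundfile as sf
#
#     model = AudioCodecModel.from_pretrained("audio_codec_16khz_small").eval()
#     samples, sr = sf.read("input.wav")  # must already be at model.sample_rate
#     audio = torch.tensor(samples, dtype=torch.float32).unsqueeze(0)
#     audio_len = torch.tensor([audio.shape[1]])
#     with torch.no_grad():
#         tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len)
#         audio_rec, audio_rec_len = model.decode(tokens=tokens, tokens_len=tokens_len)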