# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from math import ceil
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import DictConfig, OmegaConf, open_dict
from nemo.collections.tts.losses.audio_codec_loss import (
FeatureMatchingLoss,
MultiResolutionMelLoss,
MultiResolutionSTFTLoss,
RelativeFeatureMatchingLoss,
SISDRLoss,
TimeDomainLoss,
)
from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder, default_precision
from nemo.collections.tts.modules.common import GaussianDropout
from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers
from nemo.core import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, TokenIndex
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
from nemo.utils import logging, model_utils
try:
import torchaudio
HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
HAVE_TORCHAUDIO = False
class AudioCodecModel(ModelPT):
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Convert to Hydra 1.0 compatible DictConfig
cfg = model_utils.convert_model_config_to_dict_config(cfg)
cfg = model_utils.maybe_update_config_version(cfg)
self.world_size = 1
if trainer is not None:
self.world_size = trainer.num_nodes * trainer.num_devices
super().__init__(cfg=cfg, trainer=trainer)
# Expected sample rate for the input audio
self.sample_rate = cfg.sample_rate
# Number of samples in each audio frame that is encoded
self.samples_per_frame = cfg.samples_per_frame
# Discriminator updates
self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
self.disc_update_period = cfg.get("disc_update_period", 1)
if self.disc_updates_per_period > self.disc_update_period:
raise ValueError(
                f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less than or equal to the configured period ({self.disc_update_period})'
)
# Encoder setup
self.audio_encoder = instantiate(cfg.audio_encoder)
        # Optionally, add Gaussian noise to the encoder output as an information bottleneck
encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0)
if encoder_noise_stdev:
self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev)
else:
self.encoder_noise = None
if "vector_quantizer" in cfg:
self.vector_quantizer = instantiate(cfg.vector_quantizer)
vq_output_types = list(self.vector_quantizer.output_types.keys())
if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss':
self.vector_quantizer_has_commit_loss = True
logging.info('Vector quantizer supports commit loss.')
else:
self.vector_quantizer_has_commit_loss = False
logging.info('Vector quantizer does not support commit loss.')
else:
logging.warning('Vector quantizer will not be used.')
self.vector_quantizer = None
# Decoder setup
self.audio_decoder = instantiate(cfg.audio_decoder)
# Discriminator setup
self.discriminator = instantiate(cfg.discriminator)
# Mel loss setup
loss_resolutions = cfg.loss_resolutions
mel_loss_dims = cfg.get("mel_loss_dims")
mel_loss_log_guard = cfg.get("mel_loss_log_guard", 1.0)
self.mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0)
self.mel_loss_l2_scale = cfg.get("mel_loss_l2_scale", 1.0)
self.mel_loss_fn = MultiResolutionMelLoss(
sample_rate=self.sample_rate,
mel_dims=mel_loss_dims,
resolutions=loss_resolutions,
log_guard=mel_loss_log_guard,
)
# STFT loss setup
stft_loss_log_guard = cfg.get("stft_loss_log_guard", 1.0)
self.stft_loss_scale = cfg.get("stft_loss_scale", 0.0)
self.stft_loss_fn = MultiResolutionSTFTLoss(
resolutions=loss_resolutions,
log_guard=stft_loss_log_guard,
)
# Time domain loss setup
self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0)
self.si_sdr_loss_scale = cfg.get("si_sdr_loss_scale", 0.0)
self.time_domain_loss_fn = TimeDomainLoss()
self.si_sdr_loss_fn = SISDRLoss()
# Discriminator loss setup
self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0)
self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0)
self.gen_loss_fn = instantiate(cfg.generator_loss)
self.disc_loss_fn = instantiate(cfg.discriminator_loss)
self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0)
if "mmd_loss" in cfg:
self.mmd_loss_fn = instantiate(cfg.mmd_loss)
self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
else:
self.mmd_loss_fn = None
self.mmd_loss_scale = None
if "mmd_time_loss" in cfg:
self.mmd_time_loss_fn = instantiate(cfg.mmd_time_loss)
self.mmd_time_loss_scale = cfg.get("mmd_time_loss_scale", 1.0)
else:
self.mmd_time_loss_fn = None
self.mmd_time_loss_scale = None
feature_loss_type = cfg.get("feature_loss_type", "relative")
if feature_loss_type == "relative":
self.feature_loss_fn = RelativeFeatureMatchingLoss()
elif feature_loss_type == "absolute":
self.feature_loss_fn = FeatureMatchingLoss()
else:
raise ValueError(f'Unknown feature loss type {feature_loss_type}.')
# Codebook loss setup
if self.vector_quantizer:
self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0)
else:
self.commit_loss_scale = 0.0
if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss:
raise ValueError('Commit loss is enabled but the quantizer does not support it.')
self.use_scl_loss = cfg.get("use_scl_loss", False)
self.scl_loss_scale = cfg.get("scl_loss_scale", False)
if self.use_scl_loss:
self.speaker_encoder = ResNetSpeakerEncoder()
# load pretrained model
# self.speaker_encoder.load_checkpoint("https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar")
self.speaker_encoder.load_checkpoint(
"https://huggingface.co/Edresson/Speaker_Encoder_H_ASP/resolve/main/pytorch_model.bin", strict=False
)
# freeze the pretrained speaker encoder
self.speaker_encoder.freeze()
print("Speaker encoder loaded and frozen !!")
        # Disabled for now as it is not used in the final model
self.use_asr_consitency_loss = False
self.acl_loss_scale = False
# self.use_asr_consitency_loss = cfg.get("use_asr_consitency_loss", False)
# self.acl_loss_scale = cfg.get("acl_loss_scale", False)
# if self.use_asr_consitency_loss:
# self.phoneme_asr_model = PhonemeASR(input_sr=self.sample_rate)
# self.phoneme_asr_model.freeze()
# # self.acl_loss = CrossEntropyLoss()
# print("Phoneme ASR model loaded and frozen !!")
# Log setup
self.log_config = cfg.get("log_config", None)
# Optimizer setup
self.lr_schedule_interval = None
self.automatic_optimization = False
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def num_codebooks(self):
if self.vector_quantizer is None:
raise ValueError("This AudioCodecModel does not have a vector quantizer.")
return self.vector_quantizer.num_codebooks
@property
def codebook_size(self):
if self.vector_quantizer is None:
raise ValueError("This AudioCodecModel does not have a vector quantizer.")
return self.vector_quantizer.codebook_size
def state_dict(self, destination=None, prefix='', keep_vars=False):
if hasattr(self, '_no_state_dict') and self._no_state_dict:
return {}
        # Do not save the frozen speaker encoder or the discriminator's WavLM (SSL) model in the state dict
state_dict = super().state_dict(destination, prefix, keep_vars)
for key in list(state_dict.keys()):
if self.use_scl_loss and "speaker_encoder." in key:
del state_dict[key]
if "discriminator" in key and ".slm_model.ssl_model." in key:
del state_dict[key]
return state_dict
def load_state_dict(self, state_dict, strict=True):
        # Override to load all keys except the speaker encoder and the discriminator's WavLM (SSL) model
for key in list(state_dict.keys()):
if self.use_scl_loss and "speaker_encoder." in key:
del state_dict[key]
if "discriminator" in key and ".slm_model.ssl_model." in key:
del state_dict[key]
        return super().load_state_dict(state_dict, strict=False)
    def get_speaker_embedding(self, audio, requires_grad=False):
        """Compute an L2-normalized speaker embedding with the frozen speaker encoder.
        The audio is resampled from `self.sample_rate` to the speaker encoder's expected
        sample rate. Gradients are tracked only when `requires_grad` is True.
        """
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio!')
            raise ModuleNotFoundError("torchaudio is not installed but is required to resample audio.")
        def _embed(audio_in):
            audio_resampled = torchaudio.functional.resample(
                audio_in, self.sample_rate, self.speaker_encoder.audio_config["sample_rate"]
            )
            return self.speaker_encoder(audio_resampled, l2_norm=True).unsqueeze(-1)
        if not requires_grad:
            with torch.no_grad():
                return _embed(audio)
        return _embed(audio)
@typecheck(
input_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
},
)
def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply encoder on the input audio signal. Input will be padded with zeros so
the last frame has full `self.samples_per_frame` samples.
Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch
Returns:
Encoder output `encoded` and its length in number of frames `encoded_len`
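        Example (illustrative sketch; ``model`` is an instantiated ``AudioCodecModel`` and the
        encoded dimension D depends on the configured encoder):
            >>> audio = torch.randn(2, 4 * model.samples_per_frame)        # [B, T_audio]
            >>> audio_len = torch.tensor([4 * model.samples_per_frame, 3 * model.samples_per_frame])
            >>> encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len)
            >>> # encoded: [B, D, T_encoded]; encoded_len: number of valid frames per example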
"""
audio, audio_len = self.pad_audio(audio, audio_len)
encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
return encoded, encoded_len
@typecheck(
input_types={
"inputs": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"input_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
)
def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation.
Args:
inputs: encoded signal
input_len: valid length for each example in the batch
Returns:
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
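        Example (illustrative sketch; ``model`` is an instantiated ``AudioCodecModel``, ``audio`` is a
        [B, T_audio] batch with valid lengths ``audio_len``):
            >>> encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len)
            >>> audio_rec, audio_rec_len = model.decode_audio(inputs=encoded, input_len=encoded_len)
            >>> # audio_rec: [B, T_audio], with lengths that are multiples of model.samples_per_frame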
"""
audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
return audio, audio_len
@typecheck(
input_types={
"encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
"encoded_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex())},
)
def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
"""Quantize the continuous encoded representation into a discrete
representation for each frame.
Args:
encoded: encoded signal representation
encoded_len: valid length of the encoded representation in frames
Returns:
A tensor of tokens for each codebook for each frame.
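        Example (illustrative sketch; assumes the model was configured with a vector quantizer):
            >>> encoded, encoded_len = model.encode_audio(audio=audio, audio_len=audio_len)
            >>> tokens = model.quantize(encoded=encoded, encoded_len=encoded_len)
            >>> # tokens: [B, model.num_codebooks, T_encoded], integer codebook indices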
"""
if not self.vector_quantizer:
raise ValueError("Cannot quantize without quantizer")
with default_precision(torch.float32):
# vector quantizer is returning [C, B, T], where C is the number of codebooks
tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
# use batch first for the output
tokens = rearrange(tokens, 'C B T -> B C T')
return tokens
@typecheck(
input_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
},
)
def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor:
"""Convert the discrete tokens into a continuous encoded representation.
Args:
tokens: discrete tokens for each codebook for each time frame
tokens_len: valid length of each example in the batch
Returns:
Continuous encoded representation of the discrete input representation.
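        Example (illustrative sketch; ``tokens`` and ``tokens_len`` are as returned by ``encode``):
            >>> dequantized = model.dequantize(tokens=tokens, tokens_len=tokens_len)
            >>> # dequantized: [B, D, T_encoded], continuous embeddings in the model dtype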
"""
if not self.vector_quantizer:
raise ValueError("Cannot dequantize without quantizer")
# vector quantizer is using [C, B, T], where C is the number of codebooks
tokens = rearrange(tokens, 'B C T -> C B T')
with default_precision(torch.float32):
dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
dequantized = dequantized.to(self.dtype) # make sure dequantized is in the right dtype
return dequantized
@typecheck(
input_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
)
def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert input time-domain audio signal into a discrete representation (tokens).
Args:
audio: input time-domain signal, shape `(batch, number of samples)`
audio_len: valid length for each example in the batch, shape `(batch size,)`
Returns:
Tokens for each codebook for each frame, shape `(batch, number of codebooks, number of frames)`,
and the corresponding valid lengths, shape `(batch,)`
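        Example (illustrative sketch; ``model`` is an instantiated ``AudioCodecModel`` configured
        with a vector quantizer):
            >>> audio = torch.randn(1, model.sample_rate)          # one second of (random) audio
            >>> audio_len = torch.tensor([model.sample_rate])
            >>> tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len)
            >>> # tokens: [B, model.num_codebooks, T_encoded]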
"""
# Apply encoder to obtain a continuous vector for each frame
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
# Apply quantizer to obtain discrete representation per frame
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
return tokens, encoded_len
@typecheck(
input_types={
"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
"tokens_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
)
def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert discrete tokens into a continuous time-domain signal.
Args:
tokens: discrete tokens for each codebook for each time frame, shape `(batch, number of codebooks, number of frames)`
tokens_len: valid lengths, shape `(batch,)`
Returns:
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
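        Example (illustrative sketch; round trip through ``encode`` and ``decode``):
            >>> tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len)
            >>> audio_rec, audio_rec_len = model.decode(tokens=tokens, tokens_len=tokens_len)
            >>> # audio_rec: [B, T_audio]; audio_rec_len is a multiple of model.samples_per_frame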
"""
# Convert a discrete representation to a dequantized vector for each frame
dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
        dequantized = dequantized.to(self.dtype)  # make sure the dequantized representation is in the model dtype
# Apply decoder to obtain time-domain audio for each frame
audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)
return audio, audio_len
@typecheck(
input_types={
"audio": NeuralType(('B', 'T_audio'), AudioSignal()),
"audio_len": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"output_audio": NeuralType(('B', 'T_audio'), EncodedRepresentation()),
"output_audio_len": NeuralType(tuple('B'), LengthsType()),
},
)
def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply encoder, quantizer, decoder on the input time-domain signal.
Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch
Returns:
Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`.
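        Example (illustrative sketch; keyword arguments are required by the ``typecheck`` decorator):
            >>> audio = torch.randn(1, model.sample_rate)
            >>> audio_len = torch.tensor([model.sample_rate])
            >>> output_audio, output_audio_len = model(audio=audio, audio_len=audio_len)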
"""
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
if self.vector_quantizer:
# quantize to discrete tokens
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
# decode tokens to audio
output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len)
else:
# no quantization, directly decode to audio
output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len)
return output_audio, output_audio_len
def pad_audio(self, audio, audio_len):
"""Zero pad the end of the audio so that we do not have a partial end frame.
The output will be zero-padded to have an integer number of frames of
length `self.samples_per_frame`.
Args:
audio: input time-domain signal
audio_len: valid length for each example in the batch
Returns:
Padded time-domain signal `padded_audio` and its length `padded_len`.
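        Example (illustrative): with ``samples_per_frame=320`` and an input of 500 valid samples,
        the audio is zero-padded to 640 samples and ``padded_len`` becomes 640.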
"""
padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int()
max_len = padded_len.max().item()
num_padding = max_len - audio.shape[1]
padded_audio = F.pad(audio, (0, num_padding))
return padded_audio, padded_len
def _process_batch(self, batch):
# [B, T_audio]
audio = batch.get("audio")
# [B]
audio_len = batch.get("audio_lens")
audio, audio_len = self.pad_audio(audio, audio_len)
# [B, D, T_encoded]
encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
if self.encoder_noise is not None:
encoded = self.encoder_noise(encoded)
if self.vector_quantizer:
with default_precision(torch.float32):
if self.vector_quantizer_has_commit_loss:
encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
else:
encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
commit_loss = 0.0
        else:
            commit_loss = 0.0
        encoded = encoded.to(self.dtype)  # make sure the encoder/quantizer output is in the model dtype
        # [B, T_audio]
        audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)
return audio, audio_len, audio_gen, commit_loss, encoded
@property
def disc_update_prob(self) -> float:
"""Probability of updating the discriminator."""
return self.disc_updates_per_period / self.disc_update_period
def should_update_disc(self, batch_idx) -> bool:
"""Decide whether to update the descriminator based
on the batch index and configured discriminator update period.
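        Example (illustrative): with ``disc_update_period=3`` and ``disc_updates_per_period=1``,
        the discriminator is updated on batch indices 0, 3, 6, ..., i.e., once every third batch.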
"""
disc_update_step = batch_idx % self.disc_update_period
return disc_update_step < self.disc_updates_per_period
def training_step(self, batch, batch_idx):
optim_gen, optim_disc = self.optimizers()
audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)
metrics = {
"global_step": self.global_step,
"lr": optim_gen.param_groups[0]['lr'],
}
if self.should_update_disc(batch_idx):
# Train discriminator
disc_scores_real, disc_scores_gen, _, _ = self.discriminator(
audio_real=audio, audio_gen=audio_gen.detach()
)
loss_disc = self.disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen)
metrics["d_loss"] = loss_disc
optim_disc.zero_grad()
self.manual_backward(loss_disc)
optim_disc.step()
generator_losses = []
# stft does not support bf16, so make it run in fp32
loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
)
if self.mel_loss_l1_scale:
metrics["g_loss_mel_l1"] = loss_mel_l1
generator_losses.append(self.mel_loss_l1_scale * loss_mel_l1)
if self.mel_loss_l2_scale:
metrics["g_loss_mel_l2"] = loss_mel_l2
generator_losses.append(self.mel_loss_l2_scale * loss_mel_l2)
if self.stft_loss_scale:
loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len)
metrics["g_loss_stft"] = loss_stft
generator_losses.append(self.stft_loss_scale * loss_stft)
if self.time_domain_loss_scale:
loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
metrics["g_loss_time_domain"] = loss_time_domain
generator_losses.append(self.time_domain_loss_scale * loss_time_domain)
if self.si_sdr_loss_scale:
loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
metrics["g_loss_si_sdr"] = loss_si_sdr
generator_losses.append(self.si_sdr_loss_scale * loss_si_sdr)
_, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen)
if self.gen_loss_scale:
loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen)
metrics["g_loss_gen"] = loss_gen
generator_losses.append(self.gen_loss_scale * loss_gen)
if self.feature_loss_scale:
loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)
metrics["g_loss_feature"] = loss_feature
generator_losses.append(self.feature_loss_scale * loss_feature)
if self.commit_loss_scale:
metrics["g_loss_commit"] = commit_loss
generator_losses.append(self.commit_loss_scale * commit_loss)
if self.mmd_loss_scale:
loss_mmd = self.mmd_loss_fn(inputs=codes)
metrics["g_loss_mmd"] = loss_mmd
if self.current_epoch >= self.mmd_loss_start_epoch:
generator_losses.append(self.mmd_loss_scale * loss_mmd)
if self.mmd_time_loss_scale:
loss_mmd_time = self.mmd_time_loss_fn(inputs=codes)
metrics["g_loss_mmd_time"] = loss_mmd_time
if self.current_epoch >= self.mmd_loss_start_epoch:
generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time)
# compute embeddings for speaker consistency loss
if self.use_scl_loss:
            # concatenate generated and ground-truth waveforms
audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
# get speaker embeddings with grads
pred_embs = self.get_speaker_embedding(audios_batch, requires_grad=True)
# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
            # speaker consistency loss, as in the YourTTS paper
loss_scl = -1 * torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.scl_loss_scale
metrics["g_loss_scl"] = loss_scl
generator_losses.append(metrics["g_loss_scl"])
if self.use_asr_consitency_loss:
            # concatenate generated and ground-truth waveforms
audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
logits, _ = self.phoneme_asr_model(audios_batch)
logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
# labels_gt, labels_pred = torch.chunk(labels, 2, dim=0)
loss_acl = torch.nn.functional.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale
metrics["g_loss_acl"] = loss_acl
generator_losses.append(metrics["g_loss_acl"])
loss_gen_all = sum(generator_losses)
optim_gen.zero_grad()
self.manual_backward(loss_gen_all)
optim_gen.step()
self.update_lr()
self.log_dict(metrics, on_step=True, sync_dist=True)
self.log("t_loss", loss_mel_l1, prog_bar=True, logger=False, sync_dist=True)
def on_train_epoch_end(self):
self.update_lr("epoch")
def validation_step(self, batch, batch_idx):
audio, audio_len, audio_gen, _, _ = self._process_batch(batch)
loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
)
loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len)
loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
# Use only main reconstruction losses for val_loss
val_loss = loss_mel_l1 + loss_stft + loss_time_domain
metrics = {
"val_loss": val_loss,
"val_loss_mel_l1": loss_mel_l1,
"val_loss_mel_l2": loss_mel_l2,
"val_loss_stft": loss_stft,
"val_loss_time_domain": loss_time_domain,
"val_loss_si_sdr": loss_si_sdr,
}
# compute embeddings for speaker consistency loss
if self.use_scl_loss:
            # concatenate generated and ground-truth waveforms
audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
# get speaker embeddings with grads
pred_embs = self.get_speaker_embedding(audios_batch, requires_grad=True)
# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
            # speaker consistency loss, as in the YourTTS paper
loss_scl = -1 * torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.scl_loss_scale
metrics["val_loss_scl"] = loss_scl
metrics["val_loss"] += metrics["val_loss_scl"]
if self.use_asr_consitency_loss:
            # concatenate generated and ground-truth waveforms
audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
logits, _ = self.phoneme_asr_model(audios_batch)
logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
loss_acl = torch.nn.functional.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale
metrics["val_loss_acl"] = loss_acl
metrics["val_loss"] += metrics["val_loss_acl"]
self.log_dict(metrics, on_epoch=True, sync_dist=True)
def get_dataset(self, cfg):
with open_dict(cfg):
is_sharded = cfg.dataset.pop('is_sharded', False)
if is_sharded:
with open_dict(cfg):
cfg.dataset.global_rank = self.global_rank
cfg.dataset.world_size = self.world_size
cfg.dataset._target_ = 'nemo.collections.tts.data.vocoder_dataset.TarredVocoderDataset'
dataset = instantiate(cfg.dataset)
sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
return dataset, sampler
def _setup_train_dataloader(self, cfg):
dataset, sampler = self.get_dataset(cfg)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params
)
return data_loader
def _setup_test_dataloader(self, cfg):
dataset = instantiate(cfg.dataset)
data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
return data_loader
def setup_training_data(self, cfg):
self._train_dl = self._setup_train_dataloader(cfg)
batch_size = cfg['dataloader_params']['batch_size']
# Need to set this because if using an IterableDataset, the length of the dataloader is the total number
# of samples rather than the number of batches, and this messes up the tqdm progress bar.
# So we set the number of steps manually (to the correct number) to fix this.
if (
self._train_dl is not None
and hasattr(self._train_dl, 'dataset')
and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
):
# We also need to check if limit_train_batches is already set.
# If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
# and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float):
self._trainer.limit_train_batches = int(
self._trainer.limit_train_batches
* ceil((len(self._train_dl.dataset) / self.world_size) / batch_size)
)
elif self._trainer is None:
logging.warning(
"Model Trainer was not set before constructing the dataset, incorrect number of "
"training batches will be used. Please set the trainer and rebuild the dataset."
)
def setup_validation_data(self, cfg):
self._validation_dl = self._setup_test_dataloader(cfg)
def setup_test_data(self, cfg):
pass
@property
def max_steps(self):
if "max_steps" in self._cfg:
return self._cfg.get("max_steps")
if "max_epochs" not in self._cfg:
raise ValueError("Must specify 'max_steps' or 'max_epochs'.")
if "steps_per_epoch" in self._cfg:
return self._cfg.max_epochs * self._cfg.steps_per_epoch
return compute_max_steps(
max_epochs=self._cfg.max_epochs,
accumulate_grad_batches=self.trainer.accumulate_grad_batches,
limit_train_batches=self.trainer.limit_train_batches,
num_workers=get_num_workers(self.trainer),
num_samples=len(self._train_dl.dataset),
batch_size=get_batch_size(self._train_dl),
drop_last=self._train_dl.drop_last,
)
def configure_optimizers(self):
optim_config = self._cfg.optim.copy()
OmegaConf.set_struct(optim_config, False)
sched_config = optim_config.pop("sched", None)
OmegaConf.set_struct(optim_config, True)
asr_ph_params = self.phoneme_asr_model.parameters() if self.use_asr_consitency_loss else []
se_params = self.speaker_encoder.parameters() if self.use_scl_loss else []
vq_params = self.vector_quantizer.parameters() if self.vector_quantizer else []
gen_params = itertools.chain(
self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params, asr_ph_params, se_params
)
optim_g = instantiate(optim_config, params=gen_params)
disc_params = self.discriminator.parameters()
optim_d = instantiate(optim_config, params=disc_params)
if sched_config is None:
logging.debug('Scheduler is not used')
return [optim_g, optim_d]
logging.debug('Setting up schedulers')
OmegaConf.set_struct(sched_config, False)
sched_config["max_steps"] = self.max_steps
OmegaConf.set_struct(sched_config, True)
scheduler_g = prepare_lr_scheduler(
optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl
)
scheduler_d = prepare_lr_scheduler(
optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
)
self.lr_schedule_interval = scheduler_g["interval"]
return [optim_g, optim_d], [scheduler_g, scheduler_d]
def update_lr(self, interval="step"):
schedulers = self.lr_schedulers()
if schedulers is not None and self.lr_schedule_interval == interval:
sch1, sch2 = schedulers
sch1.step()
sch2.step()
def configure_callbacks(self):
if not self.log_config:
return []
data_loader = self._setup_test_dataloader(self.log_config)
generators = instantiate(self.log_config.generators)
log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None
log_callback = LoggingCallback(
generators=generators,
data_loader=data_loader,
log_epochs=self.log_config.log_epochs,
epoch_frequency=self.log_config.epoch_frequency,
output_dir=log_dir,
loggers=self.trainer.loggers,
log_tensorboard=self.log_config.log_tensorboard,
log_wandb=self.log_config.log_wandb,
)
return [log_callback]
@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
models = []
model = PretrainedModelInfo(
pretrained_model_name="audio_codec_16khz_small",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/audio_codec_16khz_small/versions/v1/files/audio_codec_16khz_small.nemo",
description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/audio_codec_16khz_small",
)
models.append(model)
model = PretrainedModelInfo(
pretrained_model_name="mel_codec_22khz_medium",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_medium/versions/v1/files/mel_codec_22khz_medium.nemo",
description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_medium",
)
models.append(model)
model = PretrainedModelInfo(
pretrained_model_name="mel_codec_44khz_medium",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_medium/versions/v1/files/mel_codec_44khz_medium.nemo",
description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_medium",
)
models.append(model)
model = PretrainedModelInfo(
pretrained_model_name="mel_codec_22khz_fullband_medium",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo",
description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_fullband_medium",
)
models.append(model)
model = PretrainedModelInfo(
pretrained_model_name="mel_codec_44khz_fullband_medium",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo",
description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_fullband_medium",
)
models.append(model)
return models