Spaces:
Runtime error
Runtime error
File size: 38,824 Bytes
0558aa4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 |
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from math import ceil
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import DictConfig, OmegaConf, open_dict
from nemo.collections.tts.losses.audio_codec_loss import (
FeatureMatchingLoss,
MultiResolutionMelLoss,
MultiResolutionSTFTLoss,
RelativeFeatureMatchingLoss,
SISDRLoss,
TimeDomainLoss,
)
from nemo.collections.tts.modules.audio_codec_modules import ResNetSpeakerEncoder, default_precision
from nemo.collections.tts.modules.common import GaussianDropout
from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers
from nemo.core import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, LengthsType, TokenIndex
from nemo.core.neural_types.neural_type import NeuralType
from nemo.core.optim.lr_scheduler import compute_max_steps, prepare_lr_scheduler
from nemo.utils import logging, model_utils
try:
    import torchaudio

    HAVE_TORCHAUDIO = True
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError. Catching the base class also covers an
    # installed-but-unloadable torchaudio (e.g. a binary/extension mismatch), so that only the
    # features needing resampling are disabled instead of the whole module failing to import.
    HAVE_TORCHAUDIO = False
class AudioCodecModel(ModelPT):
    """Audio codec model: audio encoder -> optional vector quantizer -> audio decoder.

    The model is trained with manual optimization using two optimizers: one for the generator
    (encoder, decoder and quantizer) and one for the discriminator. See `training_step` and
    `configure_optimizers`.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_devices

        super().__init__(cfg=cfg, trainer=trainer)

        # Expected sample rate for the input audio
        self.sample_rate = cfg.sample_rate

        # Number of samples in each audio frame that is encoded
        self.samples_per_frame = cfg.samples_per_frame

        # Discriminator updates: the discriminator is updated on the first
        # `disc_updates_per_period` batches of every `disc_update_period` batches.
        self.disc_updates_per_period = cfg.get("disc_updates_per_period", 1)
        self.disc_update_period = cfg.get("disc_update_period", 1)
        if self.disc_updates_per_period > self.disc_update_period:
            raise ValueError(
                f'Number of discriminator updates ({self.disc_updates_per_period}) per period must be less or equal to the configured period ({self.disc_update_period})'
            )

        # Encoder setup
        self.audio_encoder = instantiate(cfg.audio_encoder)

        # Optionally, add gaussian noise to encoder output as an information bottleneck
        encoder_noise_stdev = cfg.get("encoder_noise_stdev", 0.0)
        if encoder_noise_stdev:
            self.encoder_noise = GaussianDropout(stdev=encoder_noise_stdev)
        else:
            self.encoder_noise = None

        # Quantizer setup. The commit-loss flag is always defined (False without a quantizer)
        # so code reading it never hits an AttributeError.
        self.vector_quantizer_has_commit_loss = False
        if "vector_quantizer" in cfg:
            self.vector_quantizer = instantiate(cfg.vector_quantizer)

            # A quantizer whose third output is named 'commit_loss' returns a commitment loss
            vq_output_types = list(self.vector_quantizer.output_types.keys())
            if len(vq_output_types) == 3 and vq_output_types[-1] == 'commit_loss':
                self.vector_quantizer_has_commit_loss = True
                logging.info('Vector quantizer supports commit loss.')
            else:
                logging.info('Vector quantizer does not support commit loss.')
        else:
            logging.warning('Vector quantizer will not be used.')
            self.vector_quantizer = None

        # Decoder setup
        self.audio_decoder = instantiate(cfg.audio_decoder)

        # Discriminator setup
        self.discriminator = instantiate(cfg.discriminator)

        # Mel loss setup
        loss_resolutions = cfg.loss_resolutions
        mel_loss_dims = cfg.get("mel_loss_dims")
        mel_loss_log_guard = cfg.get("mel_loss_log_guard", 1.0)
        self.mel_loss_l1_scale = cfg.get("mel_loss_l1_scale", 1.0)
        self.mel_loss_l2_scale = cfg.get("mel_loss_l2_scale", 1.0)
        self.mel_loss_fn = MultiResolutionMelLoss(
            sample_rate=self.sample_rate,
            mel_dims=mel_loss_dims,
            resolutions=loss_resolutions,
            log_guard=mel_loss_log_guard,
        )

        # STFT loss setup
        stft_loss_log_guard = cfg.get("stft_loss_log_guard", 1.0)
        self.stft_loss_scale = cfg.get("stft_loss_scale", 0.0)
        self.stft_loss_fn = MultiResolutionSTFTLoss(
            resolutions=loss_resolutions,
            log_guard=stft_loss_log_guard,
        )

        # Time domain loss setup
        self.time_domain_loss_scale = cfg.get("time_domain_loss_scale", 1.0)
        self.si_sdr_loss_scale = cfg.get("si_sdr_loss_scale", 0.0)
        self.time_domain_loss_fn = TimeDomainLoss()
        self.si_sdr_loss_fn = SISDRLoss()

        # Discriminator loss setup
        self.gen_loss_scale = cfg.get("gen_loss_scale", 1.0)
        self.feature_loss_scale = cfg.get("feature_loss_scale", 1.0)
        self.gen_loss_fn = instantiate(cfg.generator_loss)
        self.disc_loss_fn = instantiate(cfg.discriminator_loss)

        # Optional MMD losses on the encoder/quantizer output; they are logged from the start
        # but only added to the generator loss from `mmd_loss_start_epoch` on.
        self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0)

        if "mmd_loss" in cfg:
            self.mmd_loss_fn = instantiate(cfg.mmd_loss)
            self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
        else:
            self.mmd_loss_fn = None
            self.mmd_loss_scale = None

        if "mmd_time_loss" in cfg:
            self.mmd_time_loss_fn = instantiate(cfg.mmd_time_loss)
            self.mmd_time_loss_scale = cfg.get("mmd_time_loss_scale", 1.0)
        else:
            self.mmd_time_loss_fn = None
            self.mmd_time_loss_scale = None

        feature_loss_type = cfg.get("feature_loss_type", "relative")
        if feature_loss_type == "relative":
            self.feature_loss_fn = RelativeFeatureMatchingLoss()
        elif feature_loss_type == "absolute":
            self.feature_loss_fn = FeatureMatchingLoss()
        else:
            raise ValueError(f'Unknown feature loss type {feature_loss_type}.')

        # Codebook loss setup
        if self.vector_quantizer is not None:
            self.commit_loss_scale = cfg.get("commit_loss_scale", 1.0)
        else:
            self.commit_loss_scale = 0.0

        if self.commit_loss_scale > 0 and not self.vector_quantizer_has_commit_loss:
            raise ValueError('Commit loss is enabled but the quantizer does not support it.')

        # Speaker consistency loss (SCL) setup
        self.use_scl_loss = cfg.get("use_scl_loss", False)
        self.scl_loss_scale = cfg.get("scl_loss_scale", 0.0)
        if self.use_scl_loss:
            self.speaker_encoder = ResNetSpeakerEncoder()
            # load pretrained model
            # self.speaker_encoder.load_checkpoint("https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar")
            self.speaker_encoder.load_checkpoint(
                "https://huggingface.co/Edresson/Speaker_Encoder_H_ASP/resolve/main/pytorch_model.bin", strict=False
            )
            # freeze the pretrained speaker encoder
            self.speaker_encoder.freeze()
            logging.info("Speaker encoder loaded and frozen !!")

        # ASR consistency loss: disabled for now as it is not used in final model
        self.use_asr_consitency_loss = False
        self.acl_loss_scale = 0.0
        # self.use_asr_consitency_loss = cfg.get("use_asr_consitency_loss", False)
        # self.acl_loss_scale = cfg.get("acl_loss_scale", False)
        # if self.use_asr_consitency_loss:
        #     self.phoneme_asr_model = PhonemeASR(input_sr=self.sample_rate)
        #     self.phoneme_asr_model.freeze()
        #     # self.acl_loss = CrossEntropyLoss()
        #     print("Phoneme ASR model loaded and frozen !!")

        # Log setup
        self.log_config = cfg.get("log_config", None)

        # Optimizer setup; the scheduler interval is filled in by `configure_optimizers`
        self.lr_schedule_interval = None
        self.automatic_optimization = False

    @property
    def dtype(self):
        """dtype of the model parameters (taken from the first parameter)."""
        return next(self.parameters()).dtype

    @property
    def num_codebooks(self):
        """Number of codebooks of the vector quantizer.

        Raises:
            ValueError: if the model has no vector quantizer.
        """
        if self.vector_quantizer is None:
            raise ValueError("This AudioCodecModel does not have a vector quantizer.")

        return self.vector_quantizer.num_codebooks

    @property
    def codebook_size(self):
        """Codebook size of the vector quantizer.

        Raises:
            ValueError: if the model has no vector quantizer.
        """
        if self.vector_quantizer is None:
            raise ValueError("This AudioCodecModel does not have a vector quantizer.")

        return self.vector_quantizer.codebook_size

    def _is_excluded_state_key(self, key: str) -> bool:
        """Return True for state-dict keys of frozen pretrained sub-models that are neither saved nor loaded.

        Excluded are the speaker encoder (when the speaker consistency loss is used) and the
        SSL model inside the discriminator's SLM sub-module.
        """
        if getattr(self, 'use_scl_loss', False) and "speaker_encoder." in key:
            return True
        return "discriminator" in key and ".slm_model.ssl_model." in key

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the model state dict without the frozen pretrained sub-model weights.

        Returns an empty dict when the `_no_state_dict` attribute is set and truthy.
        """
        if getattr(self, '_no_state_dict', False):
            return {}

        # Don't save the speaker verification and SSL model weights in the state dict
        state_dict = super().state_dict(destination, prefix, keep_vars)
        for key in list(state_dict.keys()):
            if self._is_excluded_state_key(key):
                del state_dict[key]

        return state_dict

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights, skipping keys of the frozen pretrained sub-models.

        Loading is always non-strict: `strict` is accepted for API compatibility but ignored,
        because the excluded sub-model weights are intentionally absent from saved checkpoints.
        Excluded keys are removed from `state_dict` in place. Extra keyword arguments are
        forwarded to `torch.nn.Module.load_state_dict`.

        Returns:
            The result of `torch.nn.Module.load_state_dict` (missing and unexpected keys).
        """
        for key in list(state_dict.keys()):
            if self._is_excluded_state_key(key):
                del state_dict[key]

        return super().load_state_dict(state_dict, strict=False, **kwargs)

    def get_speaker_embedding(self, audio, requires_grad=False):
        """Compute L2-normalized speaker embeddings with the frozen speaker encoder.

        The audio is resampled from `self.sample_rate` to the speaker encoder's sample rate first.

        Args:
            audio: batch of time-domain audio at `self.sample_rate`
            requires_grad: if False, the computation runs under `torch.no_grad()`; if True, the
                current autograd mode is kept so gradients can flow back into `audio`

        Returns:
            Speaker embeddings with an extra trailing dimension (`unsqueeze(-1)`).

        Raises:
            ModuleNotFoundError: if torchaudio, needed for resampling, is not available.
        """
        if not HAVE_TORCHAUDIO:
            logging.error('Could not import torchaudio!')
            raise ModuleNotFoundError("torchaudio is not installed but is necessary to audio resample !!")

        def _embed():
            audio_resampled = torchaudio.functional.resample(
                audio, self.sample_rate, self.speaker_encoder.audio_config["sample_rate"]
            )
            return self.speaker_encoder(audio_resampled, l2_norm=True).unsqueeze(-1)

        if requires_grad:
            return _embed()

        with torch.no_grad():
            return _embed()

    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def encode_audio(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply encoder on the input audio signal. Input will be padded with zeros so
        the last frame has full `self.samples_per_frame` samples.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Encoder output `encoded` and its length in number of frames `encoded_len`
        """
        audio, audio_len = self.pad_audio(audio, audio_len)
        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
        return encoded, encoded_len

    @typecheck(
        input_types={
            "inputs": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "input_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def decode_audio(self, inputs: torch.Tensor, input_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation.

        Args:
            inputs: encoded signal
            input_len: valid length for each example in the batch

        Returns:
            Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
            Note that `audio_len` will be a multiple of `self.samples_per_frame`.
        """
        audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
        return audio, audio_len

    @typecheck(
        input_types={
            "encoded": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
            "encoded_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={"tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex())},
    )
    def quantize(self, encoded: torch.Tensor, encoded_len: torch.Tensor) -> torch.Tensor:
        """Quantize the continuous encoded representation into a discrete
        representation for each frame.

        Args:
            encoded: encoded signal representation
            encoded_len: valid length of the encoded representation in frames

        Returns:
            A tensor of tokens for each codebook for each frame.

        Raises:
            ValueError: if the model has no vector quantizer.
        """
        if self.vector_quantizer is None:
            raise ValueError("Cannot quantize without quantizer")

        # Quantization always runs in fp32
        with default_precision(torch.float32):
            # vector quantizer is returning [C, B, T], where C is the number of codebooks
            tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)

        # use batch first for the output
        tokens = rearrange(tokens, 'C B T -> B C T')
        return tokens

    @typecheck(
        input_types={
            "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
            "tokens_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "dequantized": NeuralType(('B', 'D', 'T_encoded'), EncodedRepresentation()),
        },
    )
    def dequantize(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> torch.Tensor:
        """Convert the discrete tokens into a continuous encoded representation.

        Args:
            tokens: discrete tokens for each codebook for each time frame
            tokens_len: valid length of each example in the batch

        Returns:
            Continuous encoded representation of the discrete input representation, in the model dtype.

        Raises:
            ValueError: if the model has no vector quantizer.
        """
        if self.vector_quantizer is None:
            raise ValueError("Cannot dequantize without quantizer")

        # vector quantizer is using [C, B, T], where C is the number of codebooks
        tokens = rearrange(tokens, 'B C T -> C B T')
        with default_precision(torch.float32):
            dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
        dequantized = dequantized.to(self.dtype)  # make sure dequantized is in the right dtype
        return dequantized

    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
            "tokens_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def encode(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input time-domain audio signal into a discrete representation (tokens).

        Args:
            audio: input time-domain signal, shape `(batch, number of samples)`
            audio_len: valid length for each example in the batch, shape `(batch size,)`

        Returns:
            Tokens for each codebook for each frame, shape `(batch, number of codebooks, number of frames)`,
            and the corresponding valid lengths, shape `(batch,)`
        """
        # Apply encoder to obtain a continuous vector for each frame
        encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
        # Apply quantizer to obtain discrete representation per frame
        tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
        return tokens, encoded_len

    @typecheck(
        input_types={
            "tokens": NeuralType(('B', 'C', 'T_encoded'), TokenIndex()),
            "tokens_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def decode(self, tokens: torch.Tensor, tokens_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert discrete tokens into a continuous time-domain signal.

        Args:
            tokens: discrete tokens for each codebook for each time frame, shape `(batch, number of codebooks, number of frames)`
            tokens_len: valid lengths, shape `(batch,)`

        Returns:
            Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
            Note that `audio_len` will be a multiple of `self.samples_per_frame`.
        """
        # Convert a discrete representation to a dequantized vector for each frame
        # (`dequantize` already casts to the model dtype)
        dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
        # Apply decoder to obtain time-domain audio for each frame
        audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)
        return audio, audio_len

    @typecheck(
        input_types={
            "audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "audio_len": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "output_audio": NeuralType(('B', 'T_audio'), AudioSignal()),
            "output_audio_len": NeuralType(tuple('B'), LengthsType()),
        },
    )
    def forward(self, audio: torch.Tensor, audio_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply encoder, quantizer, decoder on the input time-domain signal.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Reconstructed time-domain signal `output_audio` and its length in number of samples `output_audio_len`.
        """
        encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)

        if self.vector_quantizer is not None:
            # quantize to discrete tokens
            tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
            # decode tokens to audio
            output_audio, output_audio_len = self.decode(tokens=tokens, tokens_len=encoded_len)
        else:
            # no quantization, directly decode to audio
            output_audio, output_audio_len = self.decode_audio(inputs=encoded, input_len=encoded_len)

        return output_audio, output_audio_len

    def pad_audio(self, audio, audio_len):
        """Zero pad the end of the audio so that we do not have a partial end frame.
        The output will be zero-padded to have an integer number of frames of
        length `self.samples_per_frame`.

        Args:
            audio: input time-domain signal
            audio_len: valid length for each example in the batch

        Returns:
            Padded time-domain signal `padded_audio` and its length `padded_len`.
        """
        padded_len = self.samples_per_frame * torch.ceil(audio_len / self.samples_per_frame).int()
        max_len = padded_len.max().item()
        # Negative padding (input longer than the longest padded length) trims the excess
        num_padding = max_len - audio.shape[1]
        padded_audio = F.pad(audio, (0, num_padding))
        return padded_audio, padded_len

    def _process_batch(self, batch):
        """Run the full codec on a batch.

        Args:
            batch: dict with keys "audio" and "audio_lens"

        Returns:
            Tuple of (padded input audio, padded lengths, generated audio, commit loss
            (0.0 when not available), encoder/quantizer output fed to the decoder).
        """
        # [B, T_audio]
        audio = batch.get("audio")
        # [B]
        audio_len = batch.get("audio_lens")
        audio, audio_len = self.pad_audio(audio, audio_len)

        # [B, D, T_encoded]
        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)

        if self.encoder_noise is not None:
            encoded = self.encoder_noise(encoded)

        commit_loss = 0.0
        if self.vector_quantizer is not None:
            with default_precision(torch.float32):
                if self.vector_quantizer_has_commit_loss:
                    encoded, _, commit_loss = self.vector_quantizer(inputs=encoded, input_len=encoded_len)
                else:
                    encoded, _ = self.vector_quantizer(inputs=encoded, input_len=encoded_len)

        # make sure the (possibly fp32) quantizer output is in the model dtype
        encoded = encoded.to(self.dtype)

        # [B, T]
        audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)

        return audio, audio_len, audio_gen, commit_loss, encoded

    @property
    def disc_update_prob(self) -> float:
        """Probability of updating the discriminator."""
        return self.disc_updates_per_period / self.disc_update_period

    def should_update_disc(self, batch_idx) -> bool:
        """Decide whether to update the discriminator based
        on the batch index and configured discriminator update period.
        """
        disc_update_step = batch_idx % self.disc_update_period
        return disc_update_step < self.disc_updates_per_period

    def _compute_scl_loss(self, audio, audio_gen, requires_grad):
        """Speaker consistency loss (as in the YourTTS paper), already scaled by `scl_loss_scale`.

        The loss is the negative mean cosine similarity between speaker embeddings of the
        ground-truth and generated audio.
        """
        # concatenate GT and generated waveforms so both go through the speaker encoder together
        audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
        pred_embs = self.get_speaker_embedding(audios_batch, requires_grad=requires_grad)
        # split GT and generated speaker embeddings
        gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        return -1 * F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.scl_loss_scale

    def _compute_acl_loss(self, audio, audio_gen):
        """ASR consistency loss (MSE between phoneme ASR logits), already scaled by `acl_loss_scale`."""
        # concatenate GT and generated waveforms
        audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
        logits, _ = self.phoneme_asr_model(audios_batch)
        logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
        return F.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale

    def training_step(self, batch, batch_idx):
        """One manual-optimization step: optional discriminator update, then a generator update."""
        optim_gen, optim_disc = self.optimizers()

        audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)

        metrics = {
            "global_step": self.global_step,
            "lr": optim_gen.param_groups[0]['lr'],
        }

        if self.should_update_disc(batch_idx):
            # Train discriminator; generated audio is detached so no gradient reaches the generator
            disc_scores_real, disc_scores_gen, _, _ = self.discriminator(
                audio_real=audio, audio_gen=audio_gen.detach()
            )
            loss_disc = self.disc_loss_fn(disc_scores_real=disc_scores_real, disc_scores_gen=disc_scores_gen)
            metrics["d_loss"] = loss_disc

            optim_disc.zero_grad()
            self.manual_backward(loss_disc)
            optim_disc.step()

        generator_losses = []

        # stft does not support bf16, so make it run in fp32
        loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
            audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
        )
        if self.mel_loss_l1_scale:
            metrics["g_loss_mel_l1"] = loss_mel_l1
            generator_losses.append(self.mel_loss_l1_scale * loss_mel_l1)
        if self.mel_loss_l2_scale:
            metrics["g_loss_mel_l2"] = loss_mel_l2
            generator_losses.append(self.mel_loss_l2_scale * loss_mel_l2)

        if self.stft_loss_scale:
            loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len)
            metrics["g_loss_stft"] = loss_stft
            generator_losses.append(self.stft_loss_scale * loss_stft)

        if self.time_domain_loss_scale:
            loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
            metrics["g_loss_time_domain"] = loss_time_domain
            generator_losses.append(self.time_domain_loss_scale * loss_time_domain)

        if self.si_sdr_loss_scale:
            loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
            metrics["g_loss_si_sdr"] = loss_si_sdr
            generator_losses.append(self.si_sdr_loss_scale * loss_si_sdr)

        _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen)

        if self.gen_loss_scale:
            loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen)
            metrics["g_loss_gen"] = loss_gen
            generator_losses.append(self.gen_loss_scale * loss_gen)

        if self.feature_loss_scale:
            loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)
            metrics["g_loss_feature"] = loss_feature
            generator_losses.append(self.feature_loss_scale * loss_feature)

        if self.commit_loss_scale:
            metrics["g_loss_commit"] = commit_loss
            generator_losses.append(self.commit_loss_scale * commit_loss)

        # MMD losses are always logged, but only optimized from `mmd_loss_start_epoch` on
        if self.mmd_loss_scale:
            loss_mmd = self.mmd_loss_fn(inputs=codes)
            metrics["g_loss_mmd"] = loss_mmd
            if self.current_epoch >= self.mmd_loss_start_epoch:
                generator_losses.append(self.mmd_loss_scale * loss_mmd)

        if self.mmd_time_loss_scale:
            loss_mmd_time = self.mmd_time_loss_fn(inputs=codes)
            metrics["g_loss_mmd_time"] = loss_mmd_time
            if self.current_epoch >= self.mmd_loss_start_epoch:
                generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time)

        # speaker consistency loss, with gradients flowing into the generated audio
        if self.use_scl_loss:
            loss_scl = self._compute_scl_loss(audio, audio_gen, requires_grad=True)
            metrics["g_loss_scl"] = loss_scl
            generator_losses.append(loss_scl)

        if self.use_asr_consitency_loss:
            loss_acl = self._compute_acl_loss(audio, audio_gen)
            metrics["g_loss_acl"] = loss_acl
            generator_losses.append(loss_acl)

        loss_gen_all = sum(generator_losses)

        optim_gen.zero_grad()
        self.manual_backward(loss_gen_all)
        optim_gen.step()

        self.update_lr()

        self.log_dict(metrics, on_step=True, sync_dist=True)
        self.log("t_loss", loss_mel_l1, prog_bar=True, logger=False, sync_dist=True)

    def on_train_epoch_end(self):
        """Step epoch-interval LR schedulers."""
        self.update_lr("epoch")

    def validation_step(self, batch, batch_idx):
        """Compute and log reconstruction losses (plus optional consistency losses) on a validation batch."""
        audio, audio_len, audio_gen, _, _ = self._process_batch(batch)

        loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
            audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
        )
        loss_stft = self.stft_loss_fn(audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len)
        loss_time_domain = self.time_domain_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)
        loss_si_sdr = self.si_sdr_loss_fn(audio_real=audio, audio_gen=audio_gen, audio_len=audio_len)

        # Use only main reconstruction losses for val_loss
        val_loss = loss_mel_l1 + loss_stft + loss_time_domain

        metrics = {
            "val_loss_mel_l1": loss_mel_l1,
            "val_loss_mel_l2": loss_mel_l2,
            "val_loss_stft": loss_stft,
            "val_loss_time_domain": loss_time_domain,
            "val_loss_si_sdr": loss_si_sdr,
        }

        # speaker consistency loss; no gradients are needed during validation
        if self.use_scl_loss:
            loss_scl = self._compute_scl_loss(audio, audio_gen, requires_grad=False)
            metrics["val_loss_scl"] = loss_scl
            val_loss = val_loss + loss_scl

        if self.use_asr_consitency_loss:
            loss_acl = self._compute_acl_loss(audio, audio_gen)
            metrics["val_loss_acl"] = loss_acl
            val_loss = val_loss + loss_acl

        metrics["val_loss"] = val_loss

        self.log_dict(metrics, on_epoch=True, sync_dist=True)

    def get_dataset(self, cfg):
        """Instantiate the dataset from `cfg.dataset` and return `(dataset, sampler)`.

        When `cfg.dataset.is_sharded` is set, the dataset target is switched to
        `TarredVocoderDataset` and given the global rank and world size.
        """
        with open_dict(cfg):
            is_sharded = cfg.dataset.pop('is_sharded', False)

        if is_sharded:
            with open_dict(cfg):
                cfg.dataset.global_rank = self.global_rank
                cfg.dataset.world_size = self.world_size
                cfg.dataset._target_ = 'nemo.collections.tts.data.vocoder_dataset.TarredVocoderDataset'

        dataset = instantiate(cfg.dataset)
        sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
        return dataset, sampler

    def _setup_train_dataloader(self, cfg):
        """Build the training dataloader with the dataset's sampler and collate function."""
        dataset, sampler = self.get_dataset(cfg)
        data_loader = torch.utils.data.DataLoader(
            dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params
        )
        return data_loader

    def _setup_test_dataloader(self, cfg):
        """Build an evaluation dataloader (no custom sampler)."""
        dataset = instantiate(cfg.dataset)
        data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
        return data_loader

    def setup_training_data(self, cfg):
        """Set up the training dataloader and, for iterable datasets, the number of training batches."""
        self._train_dl = self._setup_train_dataloader(cfg)
        batch_size = cfg['dataloader_params']['batch_size']
        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if (
            self._train_dl is not None
            and hasattr(self._train_dl, 'dataset')
            and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
        ):
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * ceil((len(self._train_dl.dataset) / self.world_size) / batch_size)
                )
            elif self._trainer is None:
                logging.warning(
                    "Model Trainer was not set before constructing the dataset, incorrect number of "
                    "training batches will be used. Please set the trainer and rebuild the dataset."
                )

    def setup_validation_data(self, cfg):
        """Set up the validation dataloader."""
        self._validation_dl = self._setup_test_dataloader(cfg)

    def setup_test_data(self, cfg):
        """Test data is not supported; this is a no-op."""
        pass

    @property
    def max_steps(self):
        """Total number of training steps.

        Taken from `max_steps` if configured, else `max_epochs * steps_per_epoch`, else
        computed from the trainer and the training dataloader.

        Raises:
            ValueError: if neither `max_steps` nor `max_epochs` is configured.
        """
        if "max_steps" in self._cfg:
            return self._cfg.get("max_steps")

        if "max_epochs" not in self._cfg:
            raise ValueError("Must specify 'max_steps' or 'max_epochs'.")

        if "steps_per_epoch" in self._cfg:
            return self._cfg.max_epochs * self._cfg.steps_per_epoch

        return compute_max_steps(
            max_epochs=self._cfg.max_epochs,
            accumulate_grad_batches=self.trainer.accumulate_grad_batches,
            limit_train_batches=self.trainer.limit_train_batches,
            num_workers=get_num_workers(self.trainer),
            num_samples=len(self._train_dl.dataset),
            batch_size=get_batch_size(self._train_dl),
            drop_last=self._train_dl.drop_last,
        )

    def configure_optimizers(self):
        """Create the generator and discriminator optimizers (and schedulers if `optim.sched` is set)."""
        optim_config = self._cfg.optim.copy()

        OmegaConf.set_struct(optim_config, False)
        sched_config = optim_config.pop("sched", None)
        OmegaConf.set_struct(optim_config, True)

        # NOTE: the frozen speaker encoder / ASR parameters are kept in the generator parameter
        # list so the optimizer's parameter groups stay compatible with existing checkpoints.
        asr_ph_params = self.phoneme_asr_model.parameters() if self.use_asr_consitency_loss else []
        se_params = self.speaker_encoder.parameters() if self.use_scl_loss else []
        vq_params = self.vector_quantizer.parameters() if self.vector_quantizer is not None else []
        gen_params = itertools.chain(
            self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params, asr_ph_params, se_params
        )
        optim_g = instantiate(optim_config, params=gen_params)

        disc_params = self.discriminator.parameters()
        optim_d = instantiate(optim_config, params=disc_params)

        if sched_config is None:
            logging.debug('Scheduler is not used')
            return [optim_g, optim_d]

        logging.debug('Setting up schedulers')
        OmegaConf.set_struct(sched_config, False)
        sched_config["max_steps"] = self.max_steps
        OmegaConf.set_struct(sched_config, True)

        scheduler_g = prepare_lr_scheduler(
            optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl
        )

        scheduler_d = prepare_lr_scheduler(
            optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
        )

        self.lr_schedule_interval = scheduler_g["interval"]

        return [optim_g, optim_d], [scheduler_g, scheduler_d]

    def update_lr(self, interval="step"):
        """Step both LR schedulers when their configured interval matches `interval`."""
        schedulers = self.lr_schedulers()
        if schedulers is not None and self.lr_schedule_interval == interval:
            sch1, sch2 = schedulers
            sch1.step()
            sch2.step()

    def configure_callbacks(self):
        """Return the audio logging callback configured by `log_config`, or no callbacks."""
        if not self.log_config:
            return []

        data_loader = self._setup_test_dataloader(self.log_config)
        generators = instantiate(self.log_config.generators)
        log_dir = Path(self.log_config.log_dir) if self.log_config.log_dir else None
        log_callback = LoggingCallback(
            generators=generators,
            data_loader=data_loader,
            log_epochs=self.log_config.log_epochs,
            epoch_frequency=self.log_config.epoch_frequency,
            output_dir=log_dir,
            loggers=self.trainer.loggers,
            log_tensorboard=self.log_config.log_tensorboard,
            log_wandb=self.log_config.log_wandb,
        )

        return [log_callback]

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """List the pretrained codec checkpoints available on NGC."""
        models = []

        model = PretrainedModelInfo(
            pretrained_model_name="audio_codec_16khz_small",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/audio_codec_16khz_small/versions/v1/files/audio_codec_16khz_small.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/audio_codec_16khz_small",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_22khz_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_medium/versions/v1/files/mel_codec_22khz_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_44khz_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_medium/versions/v1/files/mel_codec_44khz_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_22khz_fullband_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_fullband_medium",
        )
        models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="mel_codec_44khz_fullband_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo",
            description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_fullband_medium",
        )
        models.append(model)

        return models
|