# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import time
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Union
import numpy as np
import soundfile as sf
import torch
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import nn
from torch.utils.data import get_worker_info
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers
from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss
from nemo.collections.tts.models import AudioCodecModel
from nemo.collections.tts.modules import transformer_2501
from nemo.collections.tts.modules.aligner import AlignmentEncoder
from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter
from nemo.collections.tts.modules.magpietts_modules import (
CharAwareSubwordEncoder,
EOSDetectionMethod,
LocalTransformerType,
SpecialAudioToken,
cosine_schedule,
)
from nemo.collections.tts.parts.utils.helpers import (
binarize_attention_parallel,
get_mask_from_lengths,
plot_alignment_to_numpy,
)
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging
@dataclass
class InferBatchOutput:
"""Output dataclass for MagpieTTS infer_batch method.
This provides a consistent return type regardless of which optional outputs
are requested.
Attributes:
predicted_audio: Generated audio waveforms. Shape: (B, T_audio).
predicted_audio_lens: Length of each audio in samples. Shape: (B,).
predicted_codes: Generated audio codec tokens. Shape: (B, num_codebooks, T_frames).
predicted_codes_lens: Length of each code sequence in frames. Shape: (B,).
rtf_metrics: Dictionary containing real-time factor and timing metrics.
cross_attention_maps: Optional cross-attention visualization maps.
List of numpy arrays, one per batch item. Only populated if
return_cross_attn_probs=True.
headwise_cross_attention_maps: Optional per-head cross-attention maps.
Only populated if return_cross_attn_probs=True and
compute_all_heads_attn_maps=True.
"""
predicted_audio: torch.Tensor
predicted_audio_lens: torch.Tensor
predicted_codes: torch.Tensor
predicted_codes_lens: torch.Tensor
rtf_metrics: Dict[str, Any]
cross_attention_maps: Optional[List[Any]] = None
headwise_cross_attention_maps: Optional[List[Any]] = None
def worker_init_fn(worker_id):
# For mp.set_start_method("spawn", force=True)
# The dataset class should be picklable, so we initialize non-picklable objects here
logging.info(f"Worker {worker_id} initializing...")
worker_info = get_worker_info()
dataset = worker_info.dataset # Get the dataset instance in this worker
tokenizer = setup_tokenizers(dataset.tokenizer_config, mode=dataset.dataset_type)
dataset.text_tokenizer = tokenizer
class MagpieTTSModel(ModelPT):
"""
    Magpie-TTS base model class for training a TTS model that generates audio codes from a transcript and
    context audio/text.
    Supports multiple model types:
    - multi_encoder_context_tts: Transcript and context audio go to different encoders. The transcript encoding
      feeds into the decoder layers given by cfg.model.transcript_decoder_layers and the context encoding feeds
      into the layers given by context_decoder_layers. Also supports text context, which is encoded by the same
      encoder as context audio. Only one of context audio or context text is supported.
    - decoder_context_tts: Text goes into the encoder; context & target audio go to the decoder. Also supports text
      context. Supports a fixed-size context, so context_duration_min and context_duration_max are set to the same
      value (5 seconds). Text context, which is usually shorter than the number of codec frames in 5 seconds of
      audio, is padded to the max context duration in this model.
- decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and
the decoder input.
"""
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
self.world_size = 1
if trainer is not None:
self.world_size = trainer.num_nodes * trainer.num_devices
# load codec, disable loading of loss modules not needed during inference
codec_model_path = cfg.get('codecmodel_path')
if codec_model_path.startswith('nvidia/'):
codec_model = AudioCodecModel.from_pretrained(codec_model_path)
else:
codec_model_cfg = AudioCodecModel.restore_from(codec_model_path, return_config=True)
if "use_scl_loss" in codec_model_cfg:
codec_model_cfg.use_scl_loss = False
codec_model = AudioCodecModel.restore_from(
codec_model_path, strict=False, override_config_path=codec_model_cfg
)
self.sample_rate = codec_model.sample_rate
self.codec_model_samples_per_frame = codec_model.samples_per_frame
# del codec discriminator to free memory
del codec_model.discriminator
# When using FSQ tokens, the codebook structure can be changed at any time.
# An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure
# that is different than in the audio codec checkpoint.
vector_quantizer = cfg.get('vector_quantizer')
if vector_quantizer is not None:
vector_quantizer = instantiate(vector_quantizer)
num_audio_codebooks = vector_quantizer.num_codebooks
codebook_size = vector_quantizer.codebook_size
codec_converter = VectorQuantizerIndexConverter(
vector_quantizer_original=codec_model.vector_quantizer,
vector_quantizer_new=vector_quantizer,
)
data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks
else:
num_audio_codebooks = codec_model.num_codebooks
data_num_audio_codebooks = num_audio_codebooks
codebook_size = codec_model.codebook_size
codec_converter = None
        # The dataloader needs to know the number of codebooks that the context codes were stored in.
        # In the case where there are no context codes saved and there is no context audio (i.e. the text context
        # path), we create a dummy context code tensor containing only [context_BOS, context_EOS], repeated for
        # data_num_audio_codebooks.
self.data_num_audio_codebooks = data_num_audio_codebooks
self.num_audio_codebooks = num_audio_codebooks
self.codebook_size = codebook_size
# Our codebooks start with actual audio codec tokens, followed by special tokens.
# The `forced_*` options are for backward compatibility for models trained with older code.
get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size)
self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS))
self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS))
self.context_audio_bos_id = cfg.get(
'forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS)
)
self.context_audio_eos_id = cfg.get(
'forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS)
)
self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN))
self.num_all_tokens_per_codebook = cfg.get(
'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken)
)
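        # Per codebook, indices [0, codebook_size) are regular codec tokens; the special tokens (AUDIO_BOS,
        # AUDIO_EOS, context BOS/EOS, MASK, ...) are appended after them via SpecialAudioToken.get_index,
        # giving codebook_size + len(SpecialAudioToken) entries per codebook unless forced_* overrides are set.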
self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False)
# The frame stacking factor controls how many consecutive frames are processed together by the base decoder
# (and then refined into individual frames by the local transformer). A frame stacking factor of 1 means no
# frame stacking. We have a separate embedding table for each of the stacked frames, e.g. for frame stacking
# factor of 3, the entries of codebook 0 appear 3 times in the embedding table.
self.frame_stacking_factor = cfg.get('frame_stacking_factor', 1)
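        # For example, with 8 codebooks and frame_stacking_factor=3, `self.audio_embeddings` (built below) holds
        # 8 * 3 = 24 embedding tables: tables 0-7 embed the first frame of each stack, 8-15 the second, 16-23 the third.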
assert 'downsample_factor' not in cfg, '`downsample_factor` is deprecated, use `frame_stacking_factor` instead'
# Setup tokenizer
if hasattr(cfg, 'text_tokenizer'):
# For backward compatibility for English-only models
with open_dict(cfg):
cfg.text_tokenizers = {"english_phoneme": cfg.text_tokenizer}
del cfg['text_tokenizer']
self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False)
# Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility.
self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None)
self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False)
if self.legacy_text_conditioning:
if self.text_conditioning_tokenizer_name is None:
self.text_conditioning_tokenizer_name = "google-t5/t5-small"
tokenizer_target = "AutoTokenizer"
if self.text_conditioning_tokenizer_name == "google-t5/t5-small":
tokenizer_target = "T5Tokenizer"
with open_dict(cfg):
cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = {
'_target_': tokenizer_target,
'pretrained_model': self.text_conditioning_tokenizer_name,
}
elif self.text_conditioning_tokenizer_name is None:
# If no text_conditioning_tokenizer_name is specified, use the first one as default
# For text context tokenization
self.text_conditioning_tokenizer_name = list(cfg.text_tokenizers.keys())[0]
        # TODO @xueyang: both tokenizers are only used to get some token ids. We should drop them to save a small
        # amount of memory, since the dataloader will initialize them again after the worker processes are spawned.
self.tokenizer = setup_tokenizers(
all_tokenizers_config=cfg.text_tokenizers,
mode='train',
)
num_tokens_tokenizer = len(self.tokenizer.tokens)
if self.legacy_text_conditioning:
            # Text context tokens are not part of the regular transcript embedding table in legacy models
num_tokens_tokenizer -= self.tokenizer.num_tokens_per_tokenizer[self.text_conditioning_tokenizer_name]
num_tokens = num_tokens_tokenizer + 2 # +2 for BOS and EOS
self.bos_id = num_tokens - 2
self.eos_id = num_tokens - 1
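        # E.g. if the tokenizers provide 256 tokens in total, the text embedding table has 258 entries,
        # with BOS = 256 and EOS = 257.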
self.model_type = cfg.get('model_type', None)
self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce']
self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False)
        # The args below (text_context_remapping_json, text_context_remapping_prob) are
        # for combining multiple context_texts into a single one during training.
        # E.g. if we want to treat Emma_neutral and Emma_conversational as one speaker,
        # we can create an override dict {'Emma_neutral': 'Emma', 'Emma_conversational': 'Emma'}.
        # This dict is saved in a json file given by cfg.model.text_context_remapping_json.
        # If we want to preserve both behaviours, i.e. (Emma_neutral, Emma_conversational) as well as just (Emma),
        # we can apply this mapping with a probability during training, as specified by text_context_remapping_prob.
self.text_context_remapping = None
text_context_remapping_json = cfg.get('text_context_remapping_json', None)
self.text_context_remapping_prob = cfg.get('text_context_remapping_prob', 0.0)
if text_context_remapping_json is not None:
with open(text_context_remapping_json, 'r') as f:
self.text_context_remapping = json.load(f)
super().__init__(cfg=cfg, trainer=trainer)
if self.legacy_text_conditioning:
tc_tokenizer = self.tokenizer.tokenizers[self.text_conditioning_tokenizer_name]
self.context_text_embedding = nn.Embedding(tc_tokenizer.vocab_size, cfg.embedding_dim)
# This needs to happen after super().__init__()
self._codec_model = codec_model
self._codec_model.freeze() # Lightning does requires_grad = False and self.eval()
self._codec_converter = codec_converter
audio_embeddings = []
for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim))
self.audio_embeddings = nn.ModuleList(audio_embeddings)
if self.use_bpe_char_tokenizer:
# BPE char tokenizer
assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer"
tokenizer_name = self.tokenizer.tokenizer_names[0]
tokenizer = self.tokenizer.tokenizers[tokenizer_name]
subword_vocab = tokenizer.get_vocab()
            # Special tokens are stored as-is in the char_vocab;
            # each special token is mapped to a single char id.
special_vocab = {
'<BOS>': self.bos_id,
'<EOS>': self.eos_id,
}
self.cas_encoder = CharAwareSubwordEncoder(
d_embed=cfg.embedding_dim,
llm_tokenizer_vocab=subword_vocab,
subword_padding_idx=self.tokenizer.pad,
special_vocab=special_vocab,
)
else:
# Regular text embedding
self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim)
self.encoder = transformer_2501.Transformer(**dict(cfg.encoder))
self.decoder = transformer_2501.Transformer(**dict(cfg.decoder))
self.final_proj = nn.Linear(
cfg.decoder.d_model,
self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor,
)
self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower())
logging.info(f"Local transformer type: {self.local_transformer_type}")
if self.local_transformer_type != LocalTransformerType.NO_LT:
local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256)
if local_transformer_hidden_dim != cfg.decoder.d_model:
self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim)
else:
self.local_transformer_in_projection = nn.Identity()
self.local_transformer = transformer_2501.Transformer(
n_layers=self.cfg.get('local_transformer_n_layers', 2),
d_model=local_transformer_hidden_dim,
d_ffn=local_transformer_hidden_dim * 4,
sa_n_heads=self.cfg.get('local_transformer_n_heads', 1),
kernel_size=1,
is_causal=self.local_transformer_type == LocalTransformerType.AR,
max_length_causal_mask=self.frame_stacking_factor * self.num_audio_codebooks + 2,
use_learnable_pos_emb=True,
)
local_transformer_out_projections = []
for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
# Have a separate projection layer for each codebook, to distinguish between them
local_transformer_out_projections.append(
nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook)
)
self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections)
if cfg.get('use_alignment_encoder', False):
self.alignment_encoder = AlignmentEncoder(
n_mel_channels=cfg.embedding_dim,
n_text_channels=cfg.embedding_dim,
dist_type="cosine",
temperature=15.0,
)
if self.model_type == 'multi_encoder_context_tts':
logging.warning(f"The multi_encoder_context_tts model type for {self} is deprecated.")
# Transcript and context audio/text go to different encoders.
# Output of the encoders goes to the decoder through the cross-attention layers
self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8])
self.context_decoder_layers = cfg.get(
'context_decoder_layers', [0, 1, 2, 9, 10, 11]
) # For backward compatibility
multi_encoder_mapping = [None for _ in range(self.decoder.n_layers)]
for layer in self.transcript_decoder_layers:
multi_encoder_mapping[layer] = 0 # 0 means text goes to this layer, 1 means context goes to this layer
for layer in self.context_decoder_layers:
multi_encoder_mapping[layer] = 1
self.multi_encoder_mapping = multi_encoder_mapping
self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
elif self.model_type == 'decoder_context_tts':
# Context audio/text goes directly to the decoder (before the target audio codes)
self.transcript_decoder_layers = [
idx for idx in range(self.decoder.n_layers)
] # All layers are used for text
elif self.model_type == 'decoder_ce':
# Similar to decoder_context_tts, but we use context encoder
# Decoder gets output from context encoder instead of raw context tokens embeddings
self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
self.transcript_decoder_layers = [
idx for idx in range(cfg.decoder.n_layers)
] # All layers are used for text
# Register buffers for baked context embedding (initially None/empty)
# These will be populated when loading a checkpoint with baked embedding
self.register_buffer('baked_context_embedding', None)
self.register_buffer('baked_context_embedding_len', None)
else:
raise ValueError(f"Unsupported model type {self.model_type}")
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0)
self.alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0)
if self.alignment_loss_scale > 0.0:
self.alignment_loss = ForwardSumLoss(loss_scale=self.alignment_loss_scale)
if self.alignment_encoder_loss_scale > 0.0:
self.alignment_encoder_loss = ForwardSumLoss(loss_scale=self.alignment_encoder_loss_scale)
        # Copy cfg parameters into attributes on self
self.prior_end_step = self.cfg.prior_end_step
self.prior_scaledown_start_step = self.cfg.prior_scaledown_start_step
self.indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0)
self.ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers)
self.cfg_unconditional_prob = self.cfg.get('cfg_unconditional_prob', 0.0)
self.decoder_input_dropout_prob = self.cfg.get('decoder_input_dropout_prob', 0.0)
self.binarize_attn_method = self.cfg.get('binarize_attn_method', 'argmax')
self.binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2)
self.prior_future_decay = self.cfg.get('prior_future_decay', 1.0)
self.prior_past_decay = self.cfg.get('prior_past_decay', 1.0)
self.binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0)
self.prior_future_context = self.cfg.get('prior_future_context', 1)
self.prior_past_context = self.cfg.get('prior_past_context', 1)
self.binarize_prior_after_step = self.cfg.get('binarize_prior_after_step', 0)
self.codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0)
self.local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0)
self.use_alignment_encoder = self.cfg.get('use_alignment_encoder', False)
self.use_prior_for_aligner = self.cfg.get('use_prior_for_aligner', False)
self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf'))
self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook)
# Configuration validity checks
self.check_frame_stacking_config_validity()
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""
Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model
from the checkpoint. The codec model is saved in a separate checkpoint.
        _speaker_verification_model is only included in older checkpoints using the legacy single_encoder_sv_tts
        model_type, which is no longer supported, and can likely be removed in a future version.
If the model has a baked context embedding, the context_encoder weights are also excluded
since they are no longer needed for inference.
"""
if hasattr(self, '_no_state_dict') and self._no_state_dict:
return {}
# Don't save the speaker verification and codec model in the state dict
state_dict = super().state_dict(destination, prefix, keep_vars)
keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model']
# If we have a baked context embedding, exclude context_encoder weights
if self.has_baked_context_embedding:
keys_substrings_to_exclude.append('context_encoder')
for key in list(state_dict.keys()):
if any([substring in key for substring in keys_substrings_to_exclude]):
del state_dict[key]
return state_dict
def check_frame_stacking_config_validity(self):
"""
Check if the configuration is compatible with frame stacking.
"""
if self.frame_stacking_factor > 1:
# The settings below are not supported with frame stacking.
# Some of them may work - but they have not been tested.
# disallow alignment encoder
if self.use_alignment_encoder:
raise ValueError("Alignment encoder is not supported for frame stacking")
# disallow alignment loss
if self.alignment_loss_scale > 0.0:
raise ValueError("Alignment loss is not supported for frame stacking")
# disallow training prior
if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0:
raise ValueError("Training-time attention prior is not supported for frame stacking")
# disallow text conditioning
if self.use_text_conditioning_encoder:
raise ValueError("Text conditioning is not supported for frame stacking")
@property
def has_baked_context_embedding(self) -> bool:
"""Check if the model has a baked context embedding.
Returns:
True if baked_context_embedding buffer is set, not None, and has elements.
"""
return (
self.model_type == 'decoder_ce'
and hasattr(self, 'baked_context_embedding')
and self.baked_context_embedding is not None
and self.baked_context_embedding.numel() > 0
)
def update_ckpt(self, state_dict):
"""
Backward compatibility for checkpoints saved with old model names.
"""
new_state_dict = {}
for key in state_dict.keys():
if 't5_encoder' in key:
new_key = key.replace('t5_encoder', 'encoder')
new_state_dict[new_key] = state_dict[key]
elif 't5_decoder' in key:
new_key = key.replace('t5_decoder', 'decoder')
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[key] = state_dict[key]
return new_state_dict
def load_state_dict(self, state_dict, strict=True):
"""
Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when
strict is True.
When strict is False, we can call pytorch's load_state_dict.
When strict is True, we loop through all parameters and rename them to enable loading.
        _speaker_verification_model is only included in older checkpoints using the legacy single_encoder_sv_tts
        model_type, which is no longer supported, and can likely be removed in a future version.
Also handles loading baked context embeddings. If the checkpoint contains baked_context_embedding,
context_encoder weights are not expected to be present.
"""
state_dict = self.update_ckpt(state_dict)
# Check if checkpoint has baked context embedding
has_baked_embedding_in_ckpt = (
'baked_context_embedding' in state_dict and state_dict['baked_context_embedding'] is not None
)
# Load baked embedding buffers if present
if has_baked_embedding_in_ckpt:
self.baked_context_embedding = state_dict['baked_context_embedding']
self.baked_context_embedding_len = state_dict['baked_context_embedding_len']
logging.info(
f"Loaded baked context embedding with shape {self.baked_context_embedding.shape}, "
f"length {self.baked_context_embedding_len.item()}"
)
if not strict:
super().load_state_dict(state_dict, strict=False)
# Build list of modules to skip
modules_to_skip = [
'_speaker_verification_model',
'_codec_model',
'_reference_model',
'eval_asr_model',
'eval_speaker_verification_model',
'whisper_model',
'squim_objective_model',
]
# Skip context_encoder if checkpoint has baked embedding (weights won't be in checkpoint)
if has_baked_embedding_in_ckpt:
modules_to_skip.append('context_encoder')
for name, child in self.named_children():
if name in modules_to_skip:
continue
if any(param.numel() > 0 for param in child.parameters()):
# If the module has parameters, we want to change the default mapping so that the state_dict gets
# loaded.
# Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight]
new_state_dict = {}
for key in state_dict.keys():
name_with_dot = f"{name}."
if key.startswith(name_with_dot):
new_state_dict[key[len(name_with_dot) :]] = state_dict[key]
child.load_state_dict(new_state_dict)
def audio_to_codes(self, audio, audio_len, audio_type='target'):
# audio: (B, T)
# audio_len: (B,)
if audio_type == 'target':
audio_eos_id = self.audio_eos_id
audio_bos_id = self.audio_bos_id
elif audio_type == 'context':
audio_eos_id = self.context_audio_eos_id
audio_bos_id = self.context_audio_bos_id
else:
raise ValueError(f"Received audio_type of {audio_type}. Must be `target` or `context`")
self._codec_model.eval()
with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32):
codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len)
if self._codec_converter is not None:
codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)
            # Add a timestep to the beginning and end of the codes tensor
bos_tensor = torch.full(
(codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device
)
# pad at the end to make room for the EOS token; the EOS token's actual position
# varies per batch element depending on each element's length.
pad_tensor = torch.full(
(codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device
) # 0 is the padding token in the audio codebook
codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1)
# codes: (B, C, T')
# codes_len: (B,)
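            # Write EOS immediately after the last valid frame of each item (position codes_len + 1,
            # accounting for the BOS inserted at position 0).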
for idx in range(codes.size(0)):
codes[idx, :, codes_len[idx] + 1] = audio_eos_id
codes_len = codes_len + 2 # +1 for bos and +1 for eos
return codes.long(), codes_len.long()
def codes_to_audio(self, codes, codes_len):
# codes: (B, C, T')
# codes_len: (B,)
self._codec_model.eval()
with torch.no_grad(), torch.autocast(device_type=codes.device.type, dtype=torch.float32):
# Make a copy to avoid modifying the original tensor if it's used elsewhere
codes_copy = codes.clone()
# Replace eos and bos tokens with padding in the copied tensor
codes_copy[codes == self.audio_bos_id] = 0 # zero is the padding token
codes_copy[codes == self.audio_eos_id] = 0
# Pass the modified integer token IDs
if self._codec_converter is not None:
codes_copy = self._codec_converter.convert_new_to_original(
audio_tokens=codes_copy, audio_lens=codes_len
)
audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len)
# audio: (B, T)
# audio_len: (B,)
return audio, audio_len
def embed_audio_tokens(self, audio_tokens):
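        """
        Embed (possibly frame-stacked) audio tokens.
        audio_tokens: (B, C, T), where T contains `frame_stacking_factor` consecutive frames per decoder step.
        Returns the average of the per-codebook (and per-stack-position) embeddings,
        shape (B, T // frame_stacking_factor, E).
        """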
B, C, T = audio_tokens.shape
audio_embedding = None
for i in range(self.frame_stacking_factor):
for c in range(C):
tokens = audio_tokens[:, c, i :: self.frame_stacking_factor]
embedding = self.audio_embeddings[c + i * C](tokens)
if audio_embedding is None:
audio_embedding = embedding
else:
audio_embedding += embedding
audio_embedding = audio_embedding / (C * self.frame_stacking_factor)
return audio_embedding
def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False):
"""
Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes.
This function is used in training and validation, not inference/sampling.
The sequence layout is slightly different between AR and MG modes, as shown in the diagram below,
(using an 8-codebook setup as an example):
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| AR target | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none |
| codebook | | | | | | | | | |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| MG target | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
| codebook | | | | | | | | | |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
| codebook | latent | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| seq. index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
dec_out: (B, T', E)
audio_codes_target: (B, C, T')
targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive)
if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit)
"""
C = self.num_audio_codebooks
dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E)
local_transformer_input = [dec_out_all]
# Build the teacher-forced input to the LT.
for fs_index in range(self.frame_stacking_factor):
for codebook_num in range(C):
                # Collect ground truth codes for the current codebook and frame stack index combination.
codes = audio_codes_target[:, codebook_num, fs_index :: self.frame_stacking_factor] # (B, T')
                # Individual timesteps are handled independently by the LT, so we fold time into the batch dimension.
codes = codes.reshape(-1) # (B*T',)
# Embed the codes
codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E)
local_transformer_input.append(codebook_embedding)
# Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively.
local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E)
local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128)
_mask = torch.ones(
local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device
)
local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E)
if not targets_offset_by_one:
# for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc.
local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E)
else:
# for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc.
local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E)
all_code_logits = []
for fs_index in range(self.frame_stacking_factor):
for codebook_num in range(audio_codes_target.size(1)):
                # Using a separate projection layer for each codebook (to distinguish between them).
                # Profiling showed this loop is cheap compared to the local transformer forward pass.
codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index * C](
local_transformer_output[:, codebook_num + fs_index * C, :]
) # (B*T', num_all_tokens_per_codebook)
all_code_logits.append(codebook_logits)
all_code_logits = torch.cat(
all_code_logits, dim=1
) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor)
all_code_logits = all_code_logits.view(
audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1
) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor)
return all_code_logits
def maskgit_create_random_mask(self, codes):
"""
Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN.
"""
# Codes: (B, C, T)
B, C, T = codes.shape
        # Get a random vector uniformly sampled from [0, 1)  # TODO: does it need to be inclusive on the right?
rand_values = torch.rand(B, T, device=codes.device)
# apply the cosine schedule
frac_masked = cosine_schedule(rand_values)
# how many positions to mask
n_masked = torch.ceil(frac_masked * C).long() # B,T
# The code further below is the vectorized version of this:
# for b in range(B):
# for t in range(T):
# if n_masked[b,t] > 0:
# # get a random permutation of the codebook indices
# perm = torch.randperm(C)
# # mask the top n_masked positions
# mask[b, perm[:n_masked[b,t]], t] = True
#
# Create random permutations
random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T)
# Create a mask tensor where each position indicates if it should be masked
mask_indices = torch.arange(C, device=codes.device).view(1, C, 1)
mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T)
# Apply the random permutations to the mask
mask = torch.gather(mask, 1, random_permutations)
return mask # (B, C, T)
def maskgit_apply_random_mask(self, codes):
        # Randomly replaces some codes with the MASK_TOKEN, in a proportion following the cosine schedule.
# Codes: (B, C, T)
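        # Example (shapes only): for codes of shape (2, 8, 100), this returns codes_with_mask of the same shape,
        # where at each timestep a random fraction (per the cosine schedule) of the 8 codebook entries has been
        # replaced by MASK_TOKEN, along with the boolean mask of the replaced positions.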
mask = self.maskgit_create_random_mask(codes)
# replace some tokens with MASK_TOKEN
codes_with_mask = torch.where(mask, self.mask_token_id, codes)
return codes_with_mask, mask
def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None, frame_stacking_factor=1):
"""
Computes the audio codebook loss. Used by
(1) The main Magpie-TTS transformer
(2) The local transformer, for both autoregressive and MaskGit methods
logits: (B, T', num_codebooks * num_tokens_per_codebook)
audio_codes: (B, C, T')
audio_codes_lens: (B,)
mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should
therefore be the only ones included in the loss computation (for MaskGit).
frame_stacking_factor: int, the stacking factor used in the model
"""
loss_mask = get_mask_from_lengths(audio_codes_lens, pad_to_factor=frame_stacking_factor)
if mask_tokens_mask is not None:
# For MaskGit we only compute loss for the masked tokens.
# *Both* conditions must be true:
# 1. the token is masked
# 2. the token is not padding
loss_mask = loss_mask.unsqueeze(1) * mask_tokens_mask
if not loss_mask.any():
# Without this we were very rarely getting NaNs in the loss
logging.warning("No tokens valid were found in compute_loss()!")
return torch.tensor(0.0, device=loss_mask.device), loss_mask
else:
# repeat loss mask for each codebook to simplify code below
loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1)
total_codebook_loss = None
for fs_index in range(frame_stacking_factor):
for codebook in range(audio_codes.size(1)):
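                # Logits are laid out as `frame_stacking_factor` groups of `num_audio_codebooks` codebooks,
                # each `num_all_tokens_per_codebook` wide, along the last dim; slice out the current block.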
si = (codebook + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
ei = si + self.num_all_tokens_per_codebook
codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook)
codebook_targets = audio_codes[:, codebook, fs_index::frame_stacking_factor] # (B, T')
codebook_loss = self.cross_entropy_loss(
codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T')
) # (B, T')
codebook_loss_mask = loss_mask[:, codebook, fs_index::frame_stacking_factor]
codebook_loss = codebook_loss * codebook_loss_mask
if codebook_loss_mask.sum() == 0:
logging.warning(f"Loss mask for codebook {codebook} is all zeros, global_step: {self.global_step}")
continue
codebook_loss = codebook_loss.sum() / codebook_loss_mask.sum()
if total_codebook_loss is None:
total_codebook_loss = codebook_loss
else:
total_codebook_loss = total_codebook_loss + codebook_loss
total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor)
return total_codebook_loss, loss_mask
def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping):
decoder_out = self.decoder(
dec_input_embedded,
dec_input_mask,
cond=cond,
cond_mask=cond_mask,
attn_prior=attn_prior,
multi_encoder_mapping=multi_encoder_mapping,
)
attn_probabilities = decoder_out['attn_probabilities']
all_code_logits = self.final_proj(decoder_out['output']) # (B, T', num_codebooks * num_tokens_per_codebook)
return all_code_logits, attn_probabilities, decoder_out['output']
def logits_to_audio_codes(self, all_code_logits, audio_codes_lens):
# all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook)
# audio_codes_lens: (B,)
all_preds = [[] for _ in range(self.frame_stacking_factor)]
for fs_index in range(self.frame_stacking_factor):
for idx in range(self.num_audio_codebooks):
si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
ei = si + self.num_all_tokens_per_codebook
codebook_logits = all_code_logits[:, :, si:ei]
codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook)
# argmax to get the tokens
codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T')
all_preds[fs_index].append(codebook_preds)
all_preds = [
torch.stack(p, dim=1) for p in all_preds
        ] # list of `frame_stacking_factor` elements of shape (B, C, T) each
all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor
# undo the frame stacking
all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor
pred_max_len = all_preds.size(2)
real_max_len = audio_codes_lens.max()
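        # Frame stacking pads the sequence up to a multiple of frame_stacking_factor, so predictions can be at most
        # frame_stacking_factor - 1 frames longer than the true length.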
assert (pred_max_len - real_max_len) < self.frame_stacking_factor
# trim padding introduced for frame stacking
all_preds = all_preds[:, :, :real_max_len]
audio_mask = get_mask_from_lengths(audio_codes_lens)
all_preds = all_preds * audio_mask.unsqueeze(1)
return all_preds
def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2):
"""
Visualize codes for analysis purposes
codes: (B, C)
"""
def code_to_str(code):
if code == mask_id:
return "M "
else:
return f"{code:04d} "
B, C = codes.shape
if B > 1:
logging.debug("Warning: visualizing only first batch element")
codes = codes.clone().detach().cpu().numpy()[0]
codes = [code_to_str(c) for c in codes]
output_str = ""
for i, c in enumerate(codes):
if (i) % (C / frame_stacking_rate) == 0:
output_str += "|timestep| "
output_str += c
logging.debug(output_str)
def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor:
"""
Sets logits of forbidden tokens to `-inf` so they will never be sampled.
Specifically, we forbid sampling of all special tokens except AUDIO_EOS
which is allowed by default.
Args:
logits: (B, C, num_audio_tokens_per_codebook)
forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens
from being sampled. Default: False.
"""
logits[
:,
:,
SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos),
] = float('-inf')
return logits
def local_transformer_sample_maskgit(
self,
dec_output: torch.Tensor,
temperature: float = 0.7,
topk: int = 80,
unfinished_items: Dict[int, bool] = {},
finished_items: Dict[int, bool] = {},
use_cfg: bool = False,
cfg_scale: float = 1.0,
n_steps: int = 3,
noise_scale: float = 0.0,
fixed_schedule: Optional[List[int]] = None,
dynamic_cfg_scale: bool = False,
sampling_type: Optional[str] = None,
forbid_audio_eos: bool = False,
) -> torch.Tensor:
"""
Sample audio codes for the current timestep using MaskGit-like iterative
prediction with the local transformer. If frame-stacking is enabled, the
codes for all frames in the stack are sampled, treated as one long sequence.
The MaskGit process starts with all positions masked and iteratively unmasks the
most confident positions over multiple steps. By "masked" we mean that a
dedicated MASK token is used (as opposed to attention masking). The LT in this
case is a non-causal transformer decoder. At each step the model predicts all
positions at once. Of those predictions, a subset of the most confident
previously-masked positions is kept and unmasked in the next step. The number of
positions that are unmasked at each step is determined by the unmasking
schedule. We support a cosine schedule and a fixed schedule provided by the
user.
Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).
Special handling:
* forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
* forces / forbids EOS for finished / unfinished items respectively
* optionally, globally forbids audio EOS for all items in the batch.
This is useful early in the generation process.
* supports different unmasking methods, see `sampling_type` argument for details.
Args:
dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
and E is primary decoder's embedding dimension.
temperature (float, optional): Sampling temperature
topk (int, optional): Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional): Dictionary containing indices of batch
items that we are confident have not completed generation. For these items, audio EOS
sampling is forbidden.
finished_items (dict, optional): Dictionary containing indices of batch
items that we are confident are completed. For these items, audio EOS sampling
is forced.
use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
to be doubled with conditional and unconditional outputs from the primary decoder.
cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling.
noise_scale (float, optional): Scale factor for noise to add to confidence scores
during sampling (experimental).
fixed_schedule (list, optional): Fixed schedule for number of tokens to unmask at each step.
If None, uses cosine schedule.
dynamic_cfg_scale (bool, optional): Whether to dynamically adjust CFG scale during
sampling (experimental).
sampling_type (str, optional): Type of sampling strategy. Options are:
["default", "causal", "purity_causal", "purity_default"].
* Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity"
is not specified, confidence sampling is used as in the original MaskGit paper.
* "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled.
If "causal" is specified, frames are unmasked in causal order. "default"
doesn't impose any constraints on the unmasking order.
forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
batch.
Returns:
torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
"""
# dec_output: (B, E)
device = dec_output.device
# disable KV cache since our transformer is not causal
self.local_transformer.reset_cache(use_cache=False)
dec_output = dec_output.unsqueeze(1) # (B, 1, E)
local_transformer_input_init = self.local_transformer_in_projection(
dec_output
) # (B, 1, D) where D is the dimension of the local transformer
codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor
B = dec_output.size(0)
min_confidence = 0
# this needs to be large enough that unmasked items will always remain unmasked (even after noise addition)
# Setting it smaller could allow "regret", i.e. re-masking a codebook that was previously unmasked; we might want to try that
max_confidence = 5
confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device)
# initialize to all masked
codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long)
sampled_codes = codes.clone()
topk_indices = None
if fixed_schedule is not None:
n_steps = len(fixed_schedule)
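            # fixed_schedule[step] is the cumulative number of positions left unmasked after that step
            # (e.g. [4, 8, 16] for 16 positions: unmask 4 first, then 4 more, then everything).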
for step in range(n_steps):
# how far along we are in the unmasking process
progress = step / n_steps
# get mask fraction
frac_masked = cosine_schedule(torch.tensor(progress))
if sampling_type == "causal" or sampling_type == "purity_causal":
frac_masked = torch.ones_like(frac_masked) * (1.0 - progress)
# how many codebooks to mask
if fixed_schedule is None:
n_masked = torch.ceil(codebook_seq_len * frac_masked).long()
else:
n_masked = codebook_seq_len - fixed_schedule[step]
n_unmasked = codebook_seq_len - n_masked
if (
sampling_type == "causal" or sampling_type == "purity_causal"
): # and n_unmasked <= self.num_audio_codebooks:
                # force frames beyond the currently allowed number to remain masked
n_frames_to_allow = int(np.floor(progress * self.frame_stacking_factor + 1))
confidences[:, n_frames_to_allow * self.num_audio_codebooks :] = (
min_confidence - 1
) # only tested for frame_stacking_factor=2
# pick top-confidence codebooks up to n_unmasked
_, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1)
if use_cfg:
actual_batch_size = topk_indices.size(0) // 2
assert (
topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size]
).all(), "Topk indices are not the same for conditional and unconditional codes"
# replace masks of the top-k confident codebooks with the codes that were sampled for them
unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices)
codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes)
# build transformer input
local_transformer_input = local_transformer_input_init
for codebook_num in range(codebook_seq_len):
next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(
1
) # (B, 1, 768)
next_local_transformer_input = self.local_transformer_in_projection(
next_local_transformer_input
) # (B, 1, d_local)
local_transformer_input = torch.cat(
[local_transformer_input, next_local_transformer_input], dim=1
) # (B, codebook_num+1, d_local)
# run transformer
_mask = torch.ones(B, codebook_seq_len + 1, device=device)
local_transformer_output = self.local_transformer(local_transformer_input, _mask)[
'output'
] # (B, C+1, d_local)
# get logits
logits = []
for codebook_num in range(codebook_seq_len):
                # The `codebook_num + 1` drops the first position, which corresponds to the Magpie latent
codebook_logits = self.local_transformer_out_projections[codebook_num](
local_transformer_output[:, codebook_num + 1, :]
) # (B, num_audio_tokens_per_codebook)
logits.append(codebook_logits)
logits = torch.stack(logits, dim=1) # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook)
# apply CFG
if use_cfg:
actual_batch_size = logits.size(0) // 2
conditional_logits = logits[:actual_batch_size]
unconditional_logits = logits[actual_batch_size:]
if not dynamic_cfg_scale:
current_cfg_scale = cfg_scale
else:
# gradually increase the scale until mid point through sampling, then reduce it again
progress = step / (n_steps - 1)
# interp = -abs(progress-0.5)+0.5 # increase from 0..1 in the interval from start to midpoint and then go back to zero
# interp = 1.0 - progress # decrease from 1 to 0
interp = progress # gradually increase from 0 to 1
current_cfg_scale = (cfg_scale - 1) * interp + 1.0 # 1.0 --> cfg_scale --> 1.0
cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits
logits[:actual_batch_size] = cfg_logits
# Disallow generation of special tokens
logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)
# handle unfinished and finished items
for item_idx in unfinished_items:
                logits[item_idx, :, self.audio_eos_id] = float('-inf')
for item_idx in finished_items:
logits[item_idx, :, :] = float('-inf')
logits[item_idx, :, self.audio_eos_id] = 0.0
# sample with top-k
logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk)
indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook)
logits_rescored = logits.clone()
logits_rescored[indices_to_remove] = float('-inf')
probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook)
sampled_codes = torch.multinomial(probs.view(B * codebook_seq_len, -1), 1).view(B, codebook_seq_len)
if use_cfg:
sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size]
probs[actual_batch_size:] = probs[:actual_batch_size]
if sampling_type != "purity_causal" and sampling_type != "purity_default":
confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1)
else:
                # Use the max probability across all tokens as the confidence for each codebook; known as "purity sampling"
confidences = probs.max(dim=2)[0]
# replace entries in sampled_codes with previously unmasked codebooks
sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes)
# add noise to confidences (as in token-critic paper, https://arxiv.org/abs/2209.04439)
if noise_scale > 0.0:
# get noise from uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`,
# and anneal it to 0 as we approach the end of the unmasking process
noise = (
(torch.rand_like(confidences) - 0.5) * noise_scale * (1 - (step + 2) / n_steps)
) # the +2 makes sure that by the last iteration the noise is exactly 0
confidences += noise
                # the conditional and unconditional get different noise and must be fixed to be the same again
                if use_cfg:
                    confidences[actual_batch_size:] = confidences[:actual_batch_size]
confidence_eps = 0.1
assert (
confidences.max() + confidence_eps < max_confidence
), f"Predicted confidence is approaching max_confidence: {confidences.max()}"
# for unmasked codebooks, set confidence to max so that they will remain unmasked
confidences.scatter_(
index=topk_indices, dim=1, src=max_confidence * torch.ones_like(topk_indices, dtype=torch.float)
)
codes = sampled_codes
assert not (
codes == self.mask_token_id
).any(), "Codes contain mask tokens after completion of MaskGit sampling"
# break stacked groups of frames into individual frames
codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute(
0, 2, 1
) # B, C, frame_stacking_factor
if use_cfg:
# drop unconditional codes
codes = codes[:actual_batch_size]
return codes
def local_transformer_sample_autoregressive(
self,
dec_output: torch.Tensor,
temperature: float = 0.7,
topk: int = 80,
unfinished_items: Dict[int, bool] = {},
finished_items: Dict[int, bool] = {},
use_cfg: bool = False,
cfg_scale: float = 1.0,
use_kv_cache: bool = True,
forbid_audio_eos: bool = False,
) -> torch.Tensor:
"""
Sample audio codes autoregressively across codebooks using the local
transformer. Uses multinomial sampling with temperature, top-k, and
classifier-free guidance (CFG).
        The sequence is initialized with the primary decoder's hidden output as the only
        input and is gradually extended by one code (one codebook) at a time, appending each
        sampled code to the input sequence for the next step. At the last step the sequence
        is `num_codebooks` codes long. If frame stacking is enabled, codes for all frames in
        the stack are sampled as one long sequence and the final sequence is
        `num_codebooks * frame_stacking_factor` codes long.
Special handling:
* forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
* forces / forbids EOS for finished / unfinished items respectively
* optionally, globally forbids audio EOS (useful early in the generation process)
Args:
dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
and E is primary decoder's embedding dimension.
temperature (float, optional): Sampling temperature.
topk (int, optional): Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional): Dictionary containing indices of batch
items that we are confident have not completed generation. For these items, audio EOS
sampling is forbidden.
finished_items (dict, optional): Dictionary containing indices of batch
items that we are confident are completed. For these items, audio EOS sampling
is forced.
use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
to be doubled with conditional and unconditional outputs from the primary decoder.
cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
use_kv_cache (bool, optional): Whether to use key-value caching in the transformer.
forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
batch.
Returns:
torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
where B is batch size (or actual_batch_size if use_cfg=True).
"""
self.local_transformer.reset_cache(use_cache=use_kv_cache)
dec_output = dec_output.unsqueeze(1) # (B, 1, E)
local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128)
all_preds = []
for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor):
_mask = torch.ones(
local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device
)
local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128)
codebook_logits = self.local_transformer_out_projections[codebook_num](
local_transformer_output[:, -1, :]
) # (B, num_all_tokens_per_codebook)
if use_cfg:
actual_batch_size = codebook_logits.size(0) // 2
conditional_logits = codebook_logits[:actual_batch_size]
unconditional_logits = codebook_logits[actual_batch_size:]
cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits
codebook_logits[:actual_batch_size] = cfg_logits
for item_idx in unfinished_items:
codebook_logits[item_idx, self.audio_eos_id] = float('-inf')
for item_idx in finished_items:
codebook_logits[item_idx, :] = float('-inf')
codebook_logits[item_idx, self.audio_eos_id] = 0.0
# Disallow generation of special tokens
codebook_logits = self.clear_forbidden_logits(
codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
).squeeze(1)
codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk)
indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
-1
) # (B, num_tokens_per_codebook)
codebook_logits_rescored = codebook_logits.clone()
codebook_logits_rescored[indices_to_remove] = float('-inf')
codebook_probs = torch.softmax(
codebook_logits_rescored / temperature, dim=-1
) # (B, num_tokens_per_codebook)
codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1)
if use_cfg:
codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size]
all_preds.append(codebook_preds)
next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze(
1
) # (B, 1, 128)
next_local_transformer_input = self.local_transformer_in_projection(
next_local_transformer_input
) # (B, 1, 128)
local_transformer_input = torch.cat(
[local_transformer_input, next_local_transformer_input], dim=1
) # (B, T+1, 128)
all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor)
all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute(
0, 2, 1
) # (B, num_codebooks, frame_stacking_factor)
if use_cfg:
all_preds = all_preds[:actual_batch_size]
return all_preds
def sample_codes_from_logits(
self,
all_code_logits_t: torch.Tensor,
temperature: float = 0.7,
topk: int = 80,
unfinished_items: Dict[int, bool] = {},
finished_items: Dict[int, bool] = {},
forbid_audio_eos: bool = False,
) -> torch.Tensor:
"""
Sample codes for all codebooks at a given timestep. Uses multinomial sampling
with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor
> 1`), this function will sample across the entire frame stack.
Special handling:
* forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
* forces / forbids EOS for finished / unfinished items respectively
* optionally, globally forbids audio EOS (useful early in the generation process)
Args:
all_code_logits_t (torch.Tensor): Logits at a given timestep with shape
(B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor)
temperature (float, optional): Sampling temperature
topk (int, optional): Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional): Dictionary containing indices of batch
items that we are confident have not completed generation. For these items, audio EOS
sampling is forbidden.
finished_items (dict, optional): Dictionary containing indices of batch
items that we are confident are completed. For these items, audio EOS sampling
is forced.
forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
batch.
Returns:
torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
"""
all_preds = [[] for _ in range(self.frame_stacking_factor)]
for fs_index in range(self.frame_stacking_factor):
for idx in range(self.num_audio_codebooks):
si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
ei = si + self.num_all_tokens_per_codebook
codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook)
for item_idx in unfinished_items:
codebook_logits[item_idx, self.audio_eos_id] = float('-inf')
for item_idx in finished_items:
codebook_logits[item_idx, :] = float('-inf')
codebook_logits[item_idx, self.audio_eos_id] = 0.0
# Disallow generation of special tokens
codebook_logits = self.clear_forbidden_logits(
codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
).squeeze(1)
codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk)
indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
-1
) # (B, num_tokens_per_codebook)
codebook_logits_rescored = codebook_logits.clone()
codebook_logits_rescored[indices_to_remove] = float('-inf')
codebook_probs = torch.softmax(
codebook_logits_rescored / temperature, dim=-1
) # (B, num_tokens_per_codebook)
codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1)
all_preds[fs_index].append(codebook_preds)
all_preds = [
torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds
] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks)
all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor)
return all_preds
def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0):
# attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps)
wandb_images_log = {}
with torch.no_grad():
attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps)
attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps)
for logger in self.loggers:
is_wandb = isinstance(logger, WandbLogger)
is_tb = isinstance(logger, TensorBoardLogger)
if not is_wandb and not is_tb:
raise ValueError(
f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
)
wandb_images_log[f"Image/{prefix}/attention_matrix"] = list()
for idx in range(min(3, attention_prob_matrix_mean.size(0))):
item_attn_matrix = attention_prob_matrix_mean[idx][
dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx]
]
item_attn_matrix = item_attn_matrix.detach().cpu().numpy()
img_np = plot_alignment_to_numpy(item_attn_matrix.T)
if is_wandb:
wandb_images_log[f"Image/{prefix}/attention_matrix"].append(
wandb.Image(img_np, caption=f"Example_{idx}")
)
if is_tb:
logger.experiment.add_image(
f'{prefix}/attention_matrix/Example_{idx}',
img_np,
global_step=self.global_step,
dataformats="HWC",
)
return wandb_images_log
def log_val_audio_example(
self,
logits,
target_audio_codes,
audio_codes_lens_target,
context_audio_codes=None,
context_audio_codes_lens=None,
):
wandb_audio_log = {}
pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target)
pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target)
target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target)
context_audio, context_audio_lens = None, None
if context_audio_codes is not None and context_audio_codes.shape[2] > 3:
# > 3 ensures it is a valid context audio tensor (and not the dummy tensor used for text context)
context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens)
for logger in self.loggers:
is_wandb = isinstance(logger, WandbLogger)
is_tb = isinstance(logger, TensorBoardLogger)
if not is_wandb and not is_tb:
raise ValueError(
f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
)
for idx in range(min(3, pred_audio.size(0))):
pred_audio_np = pred_audio[idx].float().detach().cpu().numpy()
target_audio_np = target_audio[idx].float().detach().cpu().numpy()
pred_audio_np = pred_audio_np[: pred_audio_lens[idx]]
target_audio_np = target_audio_np[: target_audio_lens[idx]]
context_audio_np = None
if context_audio is not None:
context_audio_np = context_audio[idx].float().detach().cpu().numpy()
context_audio_np = context_audio_np[: context_audio_lens[idx]]
if is_wandb:
wandb_audio_log[f"Audio/Example_{idx}"] = list()
if context_audio_np is not None:
wandb_audio_log[f"Audio/Example_{idx}"].append(
wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context")
)
wandb_audio_log[f"Audio/Example_{idx}"].append(
wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction")
)
wandb_audio_log[f"Audio/Example_{idx}"].append(
wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target")
)
if is_tb:
if context_audio_np is not None:
logger.experiment.add_audio(
f'Example_{idx}/context',
context_audio_np,
global_step=self.global_step,
sample_rate=self.sample_rate,
)
logger.experiment.add_audio(
f'Example_{idx}/prediction',
pred_audio_np,
global_step=self.global_step,
sample_rate=self.sample_rate,
)
logger.experiment.add_audio(
f'Example_{idx}/target',
target_audio_np,
global_step=self.global_step,
sample_rate=self.sample_rate,
)
return wandb_audio_log
def scale_prior(self, prior, global_step):
if prior is None:
return None
if global_step < self.prior_scaledown_start_step:
return prior
elif global_step >= self.prior_end_step:
if random.random() < self.indefinite_prior_prob:
logging.debug("Using prior")
return prior
else:
logging.debug("Not using prior")
return None
else:
with torch.no_grad():
# Interpolate between all ones and the prior
residual = 1.0 - prior
new_prior = prior + (
residual
* (global_step - self.prior_scaledown_start_step)
/ (self.prior_end_step - self.prior_scaledown_start_step)
)
return new_prior
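# Illustrative schedule for scale_prior (values assumed for clarity): with
# prior_scaledown_start_step=10000 and prior_end_step=20000, a prior value p at
# global_step=15000 is interpolated halfway toward 1.0, i.e. p + (1 - p) * 0.5.
# Before the start step the prior is returned unchanged; after the end step it is
# kept only with probability indefinite_prior_prob.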
def embed_text(self, text, text_mask):
if self.use_bpe_char_tokenizer:
text_embedded = self.cas_encoder(text, subword_mask=text_mask)
else:
text_embedded = self.text_embedding(text)
return text_embedded
def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0):
# attention scores: List of (B, C, audio_timesteps, text_timesteps)
attention_scores_combined = torch.cat(attention_scores, dim=1) # (B, C, audio_timesteps, text_timesteps)
attention_scores_mean = attention_scores_combined.mean(
dim=1, keepdim=True
) # (B, 1, audio_timesteps, text_timesteps)
attention_scores_mean = attention_scores_mean[
:, :, dec_context_size:, :
] # Remove the context audio embeddings from the attention scores
alignment_loss = self.alignment_loss(
attn_logprob=attention_scores_mean, in_lens=text_lens, out_lens=audio_lens
)
return alignment_loss
def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int = 0):
"""
Pads the time dimension of the audio codes to a multiple of the frame stacking factor.
Args:
audio_codes (torch.Tensor): B, C, T
frame_stacking_factor (int): The factor that frames will be stacked by.
pad_token (int): The token ID to pad with.
Returns:
B, C, T_padded
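Example:
    With T=10 and frame_stacking_factor=4, T_padded = ceil(10 / 4) * 4 = 12, so two
    frames filled with `pad_token` are appended along the time dimension.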
"""
T = audio_codes.size(2)
T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor)
if T_padded > T:
padding = pad_token * torch.ones(
audio_codes.size(0),
audio_codes.size(1),
T_padded - T,
device=audio_codes.device,
dtype=audio_codes.dtype,
)
audio_codes = torch.cat([audio_codes, padding], dim=2)
return audio_codes
def embed_context_text(self, context_text_tokens):
if self.legacy_text_conditioning:
context_text_tokens = (
context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name]
)
context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E)
else:
context_text_embedded = self.text_embedding(context_text_tokens) # (B, L, E)
return context_text_embedded
def prepare_context_tensors(self, batch):
dec_context_size = 0
additional_decoder_input = None
additional_decoder_mask = None
context_audio_codes = None
context_audio_codes_lens = None
_attn_prior = None
attn_prior = None
cond = None
cond_mask = None
multi_encoder_mapping = None
text = None
text_lens = None
# self.model_type must be one of [multi_encoder_context_tts, decoder_context_tts, decoder_ce]
text = batch['text']
text_lens = batch['text_lens']
text_mask = get_mask_from_lengths(text_lens) # (B, T)
text_embedded = self.embed_text(text, text_mask) # (B, T, E)
text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E)
_attn_prior = batch.get('align_prior_matrix', None)
_attn_prior = self.scale_prior(_attn_prior, self.global_step)
if self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']:
if 'context_audio_codes' in batch:
context_audio_codes = batch['context_audio_codes']
context_audio_codes_lens = batch['context_audio_codes_lens']
if self._codec_converter is not None:
context_audio_codes = self._codec_converter.convert_original_to_new(
audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens
).long()
else:
context_audio_codes, context_audio_codes_lens = self.audio_to_codes(
batch['context_audio'], batch['context_audio_lens'], audio_type='context'
)
context_audio_codes = self.pad_audio_codes(context_audio_codes, self.frame_stacking_factor, pad_token=0)
context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T/frame_stacking_factor, E)
if self.use_text_conditioning_encoder:
context_text_tokens = batch['context_text_tokens']
context_text_lens = batch['context_text_tokens_lens']
context_text_embedded = self.embed_context_text(context_text_tokens) # (B, L, E)
# Pad context_audio_embedded or context_text_embedded so that they have same number of timesteps
if context_audio_embedded.size(1) < context_text_embedded.size(1):
padding = torch.zeros(
context_audio_embedded.size(0),
context_text_embedded.size(1) - context_audio_embedded.size(1),
context_audio_embedded.size(2),
device=context_audio_embedded.device,
)
context_audio_embedded = torch.cat([context_audio_embedded, padding], dim=1)
elif context_audio_embedded.size(1) > context_text_embedded.size(1):
padding = torch.zeros(
context_text_embedded.size(0),
context_audio_embedded.size(1) - context_text_embedded.size(1),
context_text_embedded.size(2),
device=context_text_embedded.device,
)
context_text_embedded = torch.cat([context_text_embedded, padding], dim=1) # (B, T, E)
has_text_context = batch['has_text_context'].unsqueeze(-1).unsqueeze(-1).float() # (B, 1, 1)
context_input_embedded = (
has_text_context * context_text_embedded + (1 - has_text_context) * context_audio_embedded
)
context_input_lens = (
batch['has_text_context'].float() * context_text_lens
+ (1 - batch['has_text_context'].float()) * context_audio_codes_lens
) # (B,)
else:
context_input_embedded = context_audio_embedded
context_input_lens = context_audio_codes_lens
context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to(
context_input_lens.dtype
)
context_mask = get_mask_from_lengths(context_input_lens)
if self.model_type == 'multi_encoder_context_tts':
context_embeddings = self.context_encoder(
context_input_embedded, context_mask, cond=None, cond_mask=None
)['output']
cond = [text_encoder_out, context_embeddings]
cond_mask = [text_mask, context_mask]
multi_encoder_mapping = self.multi_encoder_mapping
attn_prior = [_attn_prior, None]
elif self.model_type in ['decoder_context_tts', 'decoder_ce']:
context_embeddings = None # Address CodeQL
if self.model_type == 'decoder_context_tts':
context_embeddings = context_input_embedded
elif self.model_type == 'decoder_ce':
# Check for baked context embedding first
if self.has_baked_context_embedding:
# self.baked_context_embedding is a fixed context embedding that is baked into the model.
# This is used when we do not want users to generate speech with context audio or context text.
# This is done to disable zero-shot inference: users can only generate speech in a single voice chosen
# by the model development team.
batch_size = text.size(0)
# Expand baked embedding to batch size: (T, E) -> (B, T, E)
context_embeddings = self.baked_context_embedding.unsqueeze(0).expand(batch_size, -1, -1)
# Create context mask from baked length
context_input_lens = (
self.baked_context_embedding_len.unsqueeze(0).expand(batch_size).to(text.device)
)
context_mask = get_mask_from_lengths(context_input_lens)
else:
context_embeddings = self.context_encoder(
context_input_embedded, context_mask, cond=None, cond_mask=None
)['output']
dec_context_size = context_mask.size(1)
attn_prior = _attn_prior
if attn_prior is not None:
# B, audio_timesteps, text_timesteps
padding_zeros = torch.zeros(
attn_prior.size(0), dec_context_size, attn_prior.size(2), device=attn_prior.device
)
attn_prior = torch.cat([padding_zeros, attn_prior], dim=1)
cond = text_encoder_out
cond_mask = text_mask
multi_encoder_mapping = None
additional_decoder_input = context_embeddings
additional_decoder_mask = context_mask
else:
raise ValueError(f"Unsupported model type {self.model_type}")
if attn_prior is not None and self.ctc_prior_layer_ids is not None:
# Convert prior to a list of tensors, one for each layer
# Set None for layers not in ctc_prior_layer_ids
if self.model_type == 'multi_encoder_context_tts':
text_attn_prior = [
attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None
for layer_idx in range(self.decoder.n_layers)
]
attn_prior = [text_attn_prior, attn_prior[1]]
else:
attn_prior = [
attn_prior if layer_idx in self.ctc_prior_layer_ids else None
for layer_idx in range(self.decoder.n_layers)
]
return {
'beta_binomial_attn_prior': batch.get('align_prior_matrix', None),
'text_encoder_out': text_encoder_out,
'cond': cond,
'cond_mask': cond_mask,
'attn_prior': attn_prior,
'prior_used': _attn_prior is not None,
'multi_encoder_mapping': multi_encoder_mapping,
'additional_decoder_input': additional_decoder_input,
'additional_decoder_mask': additional_decoder_mask,
'dec_context_size': dec_context_size,
'text': text,
'text_embedded': text_embedded,
'text_mask': text_mask,
'text_lens': text_lens,
'context_audio_codes': context_audio_codes,
'context_audio_codes_lens': context_audio_codes_lens,
}
def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_hard):
# aligner_attn_hard B, audio_timesteps, text_timesteps
if self.model_type == 'multi_encoder_context_tts':
text_attn_prior = attn_prior[0]
else:
text_attn_prior = attn_prior
assert text_attn_prior is not None, "Prior is None"
if isinstance(text_attn_prior, list):
# Layer wise prior
prior_updated = False
for idx, prior in enumerate(text_attn_prior):
if prior is not None:
text_attn_prior[idx][:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard
prior_updated = True
assert prior_updated, "Did not find any prior to update"
else:
# Same prior for all layers
text_attn_prior[:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard
if self.model_type == 'multi_encoder_context_tts':
attn_prior[0] = text_attn_prior
else:
attn_prior = text_attn_prior
return attn_prior
def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens):
# aligner_attn_soft B, 1, audio_timesteps, text_timesteps
if self.binarize_attn_method == 'nemo_binarize':
logging.debug("Binarizing attention using nemo_binarize")
binarize_repeat_audio_factor = self.binarize_repeat_audio_factor
aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(
binarize_repeat_audio_factor, dim=2
) # B, 1, 2*audio_timesteps, text_timesteps
aligner_attn_hard = binarize_attention_parallel(
aligner_attn_soft_repeated, text_lens, audio_lens * binarize_repeat_audio_factor
).squeeze(
1
) # B, 2*audio_timesteps, text_timesteps
aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps
elif self.binarize_attn_method == 'argmax':
logging.debug("Binarizing attention using argmax")
aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1)
aligner_attn_hard = torch.nn.functional.one_hot(
aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)
).float()
else:
raise ValueError(
f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'."
)
aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon
for future_timestep in range(self.prior_future_context):
decay_factor = self.prior_future_decay ** (future_timestep + 1)
aligner_attn_hard_wider[:, :, future_timestep + 1 :] += (
decay_factor * aligner_attn_hard[:, :, : -(future_timestep + 1)]
)
for past_timestep in range(self.prior_past_context):
decay_factor = self.prior_past_decay ** (past_timestep + 1)
aligner_attn_hard_wider[:, :, : -past_timestep - 1] += (
decay_factor * aligner_attn_hard[:, :, past_timestep + 1 :]
)
aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0)
return aligner_attn_hard_wider
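# Illustrative effect of the widening above (values assumed for clarity): with
# prior_future_context=2 and prior_future_decay=0.5, a hard alignment row that is
# one-hot at text position t becomes roughly [..., 1.0 at t, 0.5 at t+1, 0.25 at t+2, ...]
# (each entry offset by binarized_prior_epsilon and clamped to [0, 1]); past context
# is widened the same way using prior_past_decay.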
def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask):
dummy_additional_decoder_input = None
dummy_additional_dec_mask = None
if additional_decoder_input is not None:
dummy_additional_decoder_input = torch.zeros_like(additional_decoder_input)
# an all-ones mask means don't ignore any timesteps (so that it is consistent with the usual decoder mask)
dummy_additional_dec_mask = torch.ones_like(additional_dec_mask)
if isinstance(cond, list):
# multi encoder conditioning
dummy_cond = [torch.zeros_like(cond_item) for cond_item in cond]
attn_prior = [None for _ in cond]
dummy_mask = []
for mask_item in cond_mask:
# ignore all timesteps except the first one
mask = torch.zeros_like(mask_item)
mask[:, 0] = 1 # attend only to the first timestep
dummy_mask.append(mask)
elif isinstance(cond, torch.Tensor):
# single encoder conditioning
dummy_cond = torch.zeros_like(cond)
dummy_mask = torch.zeros_like(cond_mask)
dummy_mask[:, 0] = 1 # ignore all timesteps except the first one
attn_prior = None
else:
raise ValueError(f"Unsupported type for cond {type(cond)}")
return dummy_cond, dummy_mask, dummy_additional_decoder_input, dummy_additional_dec_mask, attn_prior
def process_batch(self, batch, mode="train"):
context_tensors = self.prepare_context_tensors(batch)
disable_alignment_loss = False
if 'audio_codes' not in batch:
audio_codes, audio_codes_lens = self.audio_to_codes(batch['audio'], batch['audio_lens'])
else:
audio_codes = batch['audio_codes']
audio_codes_lens = batch['audio_codes_lens']
if self._codec_converter:
audio_codes = self._codec_converter.convert_original_to_new(
audio_tokens=audio_codes, audio_lens=audio_codes_lens
).long()
if self.frame_stacking_factor > 1:
# repeat the BOS token to frame_stacking_factor times. This is necessary since at inference
# we need to start autoregressive generation from a full stack indicating BOS.
# TODO: @rfejgin: this assert might be slow due to GPU/CPU sync
assert (audio_codes[:, :, 0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token"
audio_codes = torch.cat(
[
torch.full(
(audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1),
self.audio_bos_id,
device=audio_codes.device,
dtype=audio_codes.dtype,
),
audio_codes,
],
dim=2,
)
audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat
audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0)
# Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to be in the frame-stacked domain
# drop last (stacked) frame since it is not part of *input*
audio_codes_input_unstacked = audio_codes[:, :, : -self.frame_stacking_factor] # B, C, T'
# drop first (stacked) frame which contains BOS token(s) which are not part of *target*
audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor :]
audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input
audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target
audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long()
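# Illustrative example of the slicing above (assuming frame_stacking_factor=2): for one
# item whose codes along time are [BOS, BOS, c1, c2, c3, EOS] (per codebook), the decoder
# input drops the last stacked frame -> [BOS, BOS, c1, c2] while the target drops the
# leading BOS frame -> [c1, c2, c3, EOS], i.e. input and target are offset by one stack.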
audio_codes_embedded_all = self.embed_audio_tokens(
audio_codes
) # (B, T, E) # Computing this to be used in the alignment encoder
audio_codes_embedded = audio_codes_embedded_all[
:, :-1, :
] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`)
audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input)
use_cfg = (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None)
if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob:
cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = (
self.prepare_dummy_cond_for_cfg(
context_tensors['cond'],
context_tensors['cond_mask'],
context_tensors['additional_decoder_input'],
context_tensors['additional_decoder_mask'],
)
)
disable_alignment_loss = True
else:
cond = context_tensors['cond']
cond_mask = context_tensors['cond_mask']
additional_decoder_input = context_tensors['additional_decoder_input']
additional_decoder_mask = context_tensors['additional_decoder_mask']
attn_prior = context_tensors['attn_prior']
if mode == "train" and self.decoder_input_dropout_prob > 0.0 and torch.rand(1).item() < 0.5:
# For some batches (half of them), replace decoder_input_dropout_prob of the timesteps with random tokens
max_codebook_val = self.dec_random_input_max
# @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens
# which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on
# audio_codes_input, so it should not matter if we don't supply dec_random_input_max.
random_audio_tokens = torch.randint(
0, max_codebook_val, audio_codes_input_unstacked.size(), device=audio_codes_input_unstacked.device
)
random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1)
dec_dropout_mask = (
torch.rand((1, 1, audio_codes_input_unstacked.size(2)), device=audio_codes_input_unstacked.device)
> self.decoder_input_dropout_prob
)
# dec_dropout_mask is True for timesteps to be kept
audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * (
~dec_dropout_mask
)
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E)
if context_tensors['additional_decoder_input'] is not None:
dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1)
dec_input_mask = torch.cat([additional_decoder_mask, audio_codes_mask], dim=1)
else:
dec_input_embedded = audio_codes_embedded
dec_input_mask = audio_codes_mask
aligner_encoder_loss = None
aligner_attn_soft = None
aligner_attn_hard = None
if self.use_alignment_encoder and not disable_alignment_loss:
aligner_prior = None
if self.use_prior_for_aligner:
aligner_prior = context_tensors['beta_binomial_attn_prior']
# Passing target audio embeddings to the alignment encoder
if self.global_step < self.aligner_encoder_train_steps:
aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder(
queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T'
keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T
mask=~context_tensors['text_mask'].unsqueeze(-1),
attn_prior=aligner_prior,
)
aligner_encoder_loss = self.alignment_encoder_loss(
attn_logprob=aligner_attn_logprobs,
in_lens=context_tensors['text_lens'],
out_lens=audio_codes_lens_input,
)
else:
with torch.no_grad():
# Just get the attention matrix without computing the loss or gradients
aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder(
queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T'
keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T
mask=~context_tensors['text_mask'].unsqueeze(-1),
attn_prior=aligner_prior,
)
with torch.no_grad():
aligner_attn_hard = self.get_binarized_prior_matrix(
aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens']
)
if (self.global_step > self.binarize_prior_after_step) and context_tensors['prior_used']:
attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard)
logits, attn_info, dec_out = self.forward(
dec_input_embedded=dec_input_embedded,
dec_input_mask=dec_input_mask,
cond=cond,
cond_mask=cond_mask,
attn_prior=attn_prior,
multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
)
# logits: (B, T', num_codebooks * num_tokens_per_codebook)
# dec_out: (B, T', E)
dec_context_size = context_tensors['dec_context_size']
logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits
# Codebook loss (parallel)
codebook_loss, loss_mask = self.compute_loss(
logits,
audio_codes_target_unstacked,
audio_codes_lens_target_unstacked,
frame_stacking_factor=self.frame_stacking_factor,
)
# Alignment loss
alignment_loss = None
if self.alignment_loss_scale > 0.0 and not disable_alignment_loss:
text_lens = context_tensors['text_lens']
cross_attention_scores = [
attn['cross_attn_probabilities'][1]
for layer_idx, attn in enumerate(attn_info)
if layer_idx in self.ctc_prior_layer_ids
]
alignment_loss = self.compute_alignment_loss(
cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size
)
loss = self.codebook_loss_scale * codebook_loss + alignment_loss
else:
loss = self.codebook_loss_scale * codebook_loss
# Local Transformer loss
local_transformer_loss = None
local_transformer_logits = None
if self.local_transformer_type != LocalTransformerType.NO_LT:
if self.local_transformer_type == LocalTransformerType.MASKGIT:
# Maskgit
# randomly replace some positions with MASK_TOKEN
audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked)
# TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part
# of a pair where the first position is valid. Is this an issue?
local_transformer_logits = self.compute_local_transformer_logits(
dec_out[:, dec_context_size:, :], audio_codes_masked, targets_offset_by_one=True
)
local_transformer_loss, _ = self.compute_loss(
local_transformer_logits,
audio_codes_target_unstacked,
audio_codes_lens_target_unstacked,
mask_tokens_mask,
frame_stacking_factor=self.frame_stacking_factor,
)
else:
# Autoregressive
assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type"
local_transformer_logits = self.compute_local_transformer_logits(
dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False
)
local_transformer_loss, _ = self.compute_loss(
local_transformer_logits,
audio_codes_target_unstacked,
audio_codes_lens_target_unstacked,
None,
frame_stacking_factor=self.frame_stacking_factor,
)
loss = loss + self.local_transformer_loss_scale * local_transformer_loss
if aligner_encoder_loss is not None:
loss = loss + aligner_encoder_loss
return {
'logits': logits,
'attn_info': attn_info,
'loss': loss,
'codebook_loss': codebook_loss,
'local_transformer_loss': local_transformer_loss,
'local_transformer_logits': local_transformer_logits,
'loss_mask': loss_mask,
'alignment_loss': alignment_loss,
'aligner_encoder_loss': aligner_encoder_loss,
'audio_codes_target': audio_codes_target_unstacked,
'audio_codes_lens_target': audio_codes_lens_target_unstacked,
'text': context_tensors['text'],
'text_lens': context_tensors['text_lens'],
'context_audio_codes': context_tensors['context_audio_codes'],
'context_audio_codes_lens': context_tensors['context_audio_codes_lens'],
'dec_context_size': dec_context_size,
'aligner_attn_soft': aligner_attn_soft,
'aligner_attn_hard': aligner_attn_hard,
}
def training_step(self, batch, batch_idx):
batch_output = self.process_batch(batch)
loss = batch_output['loss']
codebook_loss = batch_output['codebook_loss']
self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True)
if self.cfg_unconditional_prob == 0.0:
# Only log alignment loss when not using cfg to avoid sync issues when
# alignment loss is None on some ranks
alignment_loss = batch_output['alignment_loss']
if alignment_loss is not None:
self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True)
self.log('train/loss', loss, prog_bar=True, sync_dist=True)
local_transformer_loss = batch_output['local_transformer_loss']
if local_transformer_loss is not None:
self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True)
# Log batch info
batch_size, text_token_max_len = batch["text"].shape
text_token_total_num = batch["text_lens"].sum()
batch_info_dict = {
"train/batch_size": batch_size,
"train/text_token_max_len": text_token_max_len,
"train/text_token_total_num_in_batch": text_token_total_num.item(),
"train/text_token_pad_ratio_percent_in_batch": 100
* (1 - text_token_total_num / (batch_size * text_token_max_len)),
}
if "audio_codes" in batch:
audio_codes_max_len = batch["audio_codes"].shape[-1]
audio_codes_total_num = batch["audio_codes_lens"].sum()
batch_info_dict.update(
{
"train/audio_codes_max_len": audio_codes_max_len,
"train/audio_codes_total_num_in_batch": audio_codes_total_num.item(),
"train/audio_codes_pad_ratio_percent_in_batch": 100
* (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)),
}
)
else:
audio_samples_max_len = batch["audio"].shape[-1]
audio_samples_total_num = batch["audio_lens"].sum()
batch_info_dict.update(
{
"train/audio_samples_max_len": audio_samples_max_len,
"train/audio_samples_total_num_in_batch": audio_samples_total_num.item(),
"train/audio_samples_pad_ratio_percent_in_batch": 100
* (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)),
}
)
self.log_dict(batch_info_dict, on_step=True)
return loss
def validation_step(self, batch, batch_idx):
batch_output = self.process_batch(batch, mode="val")
# self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction
# head. If we use local_transformer, then the local_transformer returns "local_transformer_logits"
loss = batch_output['loss']
codebook_loss = batch_output['codebook_loss']
alignment_loss = batch_output['alignment_loss']
aligner_encoder_loss = batch_output['aligner_encoder_loss']
logits = batch_output['logits']
audio_codes_target = batch_output['audio_codes_target']
audio_codes_lens_target = batch_output['audio_codes_lens_target']
context_audio_codes = batch_output['context_audio_codes']
context_audio_codes_lens = batch_output['context_audio_codes_lens']
attn_info = batch_output['attn_info']
text_lens = batch_output['text_lens']
dec_context_size = batch_output['dec_context_size']
if alignment_loss is None:
alignment_loss = torch.tensor(0.0, device=loss.device)
if aligner_encoder_loss is None:
aligner_encoder_loss = torch.tensor(0.0, device=loss.device)
if batch_idx == 0 and self.global_rank == 0:
# Prepare dictionary for aggregated wandb logging
wandb_log_dict = {}
# Get audio data for logging
wandb_log_dict.update(
self.log_val_audio_example(
logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens
)
)
# Get attention image data for logging
if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1:
# cross_attn_probabilities only returned when not using flash attention
cross_attention_probs = [
attn['cross_attn_probabilities'][0]
for layer_idx, attn in enumerate(attn_info)
if layer_idx in self.ctc_prior_layer_ids
]
wandb_log_dict.update(
self.log_attention_probs(
cross_attention_probs,
audio_codes_lens_target,
text_lens,
prefix="val",
dec_context_size=dec_context_size,
)
)
for layer_idx in self.transcript_decoder_layers:
cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]]
wandb_log_dict.update(
self.log_attention_probs(
cross_attention_probs,
audio_codes_lens_target,
text_lens,
prefix=f"val/layer_{layer_idx}",
dec_context_size=dec_context_size,
)
)
if batch_output['aligner_attn_soft'] is not None:
wandb_log_dict.update(
self.log_attention_probs(
[batch_output['aligner_attn_soft']],
audio_codes_lens_target,
text_lens,
prefix="val/aligner_encoder_attn",
)
)
if batch_output['aligner_attn_hard'] is not None:
wandb_log_dict.update(
self.log_attention_probs(
[batch_output['aligner_attn_hard'].unsqueeze(1)],
audio_codes_lens_target,
text_lens,
prefix="val/aligner_encoder_attn_hard",
)
)
# Perform single wandb log call if wandb is active and there is data
for logger in self.loggers:
if isinstance(logger, WandbLogger) and wandb_log_dict:
logger.experiment.log(wandb_log_dict)
local_transformer_loss = batch_output['local_transformer_loss']
val_output = {
'val_loss': loss,
'val_codebook_loss': codebook_loss,
'val_alignment_loss': alignment_loss,
'val_local_transformer_loss': local_transformer_loss,
'val_aligner_encoder_loss': aligner_encoder_loss,
}
self.validation_step_outputs.append(val_output)
return val_output
def get_cross_attention_scores(self, attn_probs, filter_layers=None):
"""
Returns the cross attention probabilities for the last audio timestep
"""
mean_cross_attn_scores = []
all_heads_cross_attn_scores = []
for lidx, layerwise_attn_prob in enumerate(attn_probs):
if (filter_layers is not None and lidx not in filter_layers) or (
lidx not in self.transcript_decoder_layers
):
continue
cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][
0
] # B, H, audio_timesteps, text_timesteps
mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps
for head_idx in range(cross_attn_prob.size(1)):
all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps
mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps
mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps
last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps
return last_audio_timestep_scores, all_heads_cross_attn_scores
def get_most_attended_text_timestep(
self,
alignment_attention_scores,
last_attended_timesteps,
text_lens,
lookahead_window_size,
attended_timestep_counter,
batch_size,
):
"""
Returns the most attended timestep for each batch item
"""
text_time_step_attended = []
for bidx in range(batch_size):
last_attended_timestep = last_attended_timesteps[-1][bidx]
if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8:
# This is probably an attention sink! Move to the next timestep
last_attended_timestep += 1
window_size = lookahead_window_size
window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps
item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end]
if item_attention_scores.size(0) == 0:
# This means the sentence has ended
attended_timestep = text_lens[bidx].item() - 1
else:
attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep
text_time_step_attended.append(attended_timestep)
attended_timestep_counter[bidx][attended_timestep] = (
attended_timestep_counter[bidx].get(attended_timestep, 0) + 1
)
return text_time_step_attended, attended_timestep_counter
def construct_inference_prior(
self,
prior_epsilon,
cross_attention_scores,
text_lens,
text_time_step_attended,
attended_timestep_counter,
unfinished_texts,
finished_texts_counter,
end_indices,
lookahead_window_size,
batch_size,
):
# Attn prior for the next timestep
_attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon
_attn_prior = _attn_prior.to(cross_attention_scores.device)
for bidx in range(cross_attention_scores.shape[0]):
if bidx < batch_size:
_text_len = text_lens[bidx]
if text_lens[bidx] <= 5:
# Very short sentences, No Prior
_attn_prior[bidx, 0, :] = 1.0
else:
_attn_prior[bidx, 0, max(1, text_time_step_attended[bidx] - 1)] = (
1.0 # Slight exposure to history for better pronunciation. Not very important.
)
_attn_prior[bidx, 0, text_time_step_attended[bidx]] = (
1.0 # Slightly bias to continue moving forward. Not very important.
)
for ind in range(1, lookahead_window_size + 1):
_attn_prior[bidx, 0, min(text_time_step_attended[bidx] + ind, _text_len - 1)] = 1.0
# Penalize timesteps that have been attended to more than 10 times
for _timestep in attended_timestep_counter[bidx]:
if attended_timestep_counter[bidx][_timestep] >= 10:
# This means the timestep has been attended to more than 10 times (To avoid getting stuck)
_attn_prior[bidx, 0, : _timestep + 1] = prior_epsilon
unfinished_texts[bidx] = False
if text_time_step_attended[bidx] < text_lens[bidx] - 3:
# This means the sentence has not ended
if bidx not in end_indices:
unfinished_texts[bidx] = True
if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices:
if bidx not in finished_texts_counter:
finished_texts_counter[bidx] = 0
for bidx in finished_texts_counter:
finished_texts_counter[bidx] += 1
if finished_texts_counter[bidx] > 5:
# This means we have been within the text EOS window for at least 5 timesteps
# We should allow EOS to be predicted now.
unfinished_texts[bidx] = False
return _attn_prior, unfinished_texts, finished_texts_counter
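# Illustrative prior row produced above (values assumed for clarity): with
# lookahead_window_size=3 and text_time_step_attended[bidx]=5 (not near the end of the
# text), the row equals prior_epsilon everywhere except positions 4..8, which are set
# to 1.0 (one step of history, the attended position, and the lookahead window).
# Very short texts (<= 5 tokens) receive an all-ones prior instead.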
def get_inference_attention_plots(
self,
cross_attention_scores_all_timesteps,
all_heads_cross_attn_scores_all_timesteps,
text_lens,
predicted_codes_lens,
batch_size,
compute_all_heads_attn_maps,
last_attended_timestep,
):
last_attended_timestep = np.array(last_attended_timestep).T
cross_attention_scores_all_timesteps = torch.stack(
cross_attention_scores_all_timesteps, dim=2
) # B, text_timesteps, T'
headwise_cross_attention_scores_all_timesteps = []
for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])):
head_cross_attention_all_timesteps = torch.stack(
[x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2
) # B, text_timesteps, T'
headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps)
cross_attention_maps = []
headwise_cross_attention_maps = []
for bidx in range(batch_size):
item_cross_attention_scores = cross_attention_scores_all_timesteps[
bidx, : text_lens[bidx], : predicted_codes_lens[bidx]
]
cross_attn_np = plot_alignment_to_numpy(
item_cross_attention_scores.cpu().numpy(),
attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]],
)
cross_attention_maps.append(cross_attn_np)
item_all_head_cross_attn_maps = []
if compute_all_heads_attn_maps:
for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])):
item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][
bidx, : text_lens[bidx], : predicted_codes_lens[bidx]
]
headwise_cross_attn_np = plot_alignment_to_numpy(
item_headwise_cross_attention_scores.cpu().numpy(),
attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]],
)
item_all_head_cross_attn_maps.append(headwise_cross_attn_np)
headwise_cross_attention_maps.append(item_all_head_cross_attn_maps)
return cross_attention_maps, headwise_cross_attention_maps
def find_eos_frame_index(self, codes, eos_detection_method) -> Union[int, float]:
"""
Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack
that triggers EOS detection according to `eos_detection_method`, or `float('inf')` if no EOS is found.
Args:
codes: (num_codebooks, frame_stacking_factor)
Returns:
index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found
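Example:
    With two codebooks and frame_stacking_factor=2, codes equal to
    [[x, audio_eos_id], [y, z]] yield 1 under an "any" detection type (EOS appears in
    the second frame), but float('inf') under "all" since no frame has EOS in every codebook.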
"""
eos_mask = codes == self.audio_eos_id # (codebooks, frame_stacking_factor)
detection_type = EOSDetectionMethod.detection_type(eos_detection_method)
if detection_type == "any":
eos_per_frame = eos_mask.any(
dim=0
) # (frame_stacking_factor,) - True if any codebook has EOS in this frame
elif detection_type == "all":
eos_per_frame = eos_mask.all(
dim=0
) # (frame_stacking_factor,) - True if all codebooks have EOS in this frame
elif detection_type == "zero_cb":
eos_per_frame = eos_mask[:1, :].any(
dim=0
) # (frame_stacking_factor,) - True if zeroth codebook has EOS in this frame
else:
raise ValueError(f"Invalid EOS detection method: {eos_detection_method}")
# find first frame with EOS
if eos_per_frame.any():
# return index of the first frame with EOS
return eos_per_frame.nonzero()[0].item()
return float('inf')
def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method) -> Union[int, float]:
"""
Detects EOS in the predicted codes. Returns the index of the first frame within the frame stack
that triggers EOS detection, or `float('inf')` if no EOS is found.
Args:
audio_codes_multinomial: (num_codebooks, frame_stacking_factor) - Multinomial samples
audio_codes_argmax: (num_codebooks, frame_stacking_factor) - Argmax samples
eos_detection_method: EOS detection method
Returns:
index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found
"""
sampling_type = EOSDetectionMethod.sampling_type(eos_detection_method)
if sampling_type == "argmax":
return self.find_eos_frame_index(audio_codes_argmax, eos_detection_method)
elif sampling_type == "argmax_or_multinomial":
argmax_eos_frame = self.find_eos_frame_index(audio_codes_argmax, eos_detection_method)
multinomial_eos_frame = self.find_eos_frame_index(audio_codes_multinomial, eos_detection_method)
return min(argmax_eos_frame, multinomial_eos_frame)
else:
raise ValueError(f"Invalid EOS detection method: {eos_detection_method}")
def infer_batch(
self,
batch,
max_decoder_steps=500,
temperature=0.7,
topk=80,
use_cfg=False,
cfg_scale=1.0,
return_cross_attn_probs=False,
apply_attention_prior=False,
prior_epsilon=1e-5,
lookahead_window_size=10,
estimate_alignment_from_layers=None,
apply_prior_to_layers=None,
start_prior_after_n_audio_steps=10,
compute_all_heads_attn_maps=False,
use_local_transformer_for_inference=False,
use_LT_kv_cache=True,
maskgit_n_steps=3,
maskgit_noise_scale=0.0,
maskgit_fixed_schedule=None,
maskgit_dynamic_cfg_scale=False,
maskgit_sampling_type=None,
ignore_finished_sentence_tracking=False,
eos_detection_method="argmax_or_multinomial_any",
# Setting this greater than 0 prevents rare cases of first-frame termination. Any value between 1 and 4 should work, but 4
# lines up with the codec's minimum frame requirement.
min_generated_frames=4,
):
eos_detection_method = EOSDetectionMethod(eos_detection_method)
with torch.no_grad():
start_time = time.time()
self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference)
context_tensors = self.prepare_context_tensors(batch)
text = context_tensors['text']
audio_codes_bos = torch.full(
(text.size(0), self.num_audio_codebooks, self.frame_stacking_factor),
self.audio_bos_id,
device=text.device,
).long()
audio_codes_lens = torch.full(
(text.size(0),), 1, device=text.device
).long() # intentionally 1 rather than self.frame_stacking_factor since this is in stacked form
audio_codes_input = audio_codes_bos
audio_codes_mask = get_mask_from_lengths(audio_codes_lens)
all_predictions = []
end_indices = {}
if use_cfg:
dummy_cond, dummy_cond_mask, dummy_additional_decoder_input, dummy_addition_dec_mask, _ = (
self.prepare_dummy_cond_for_cfg(
context_tensors['cond'],
context_tensors['cond_mask'],
context_tensors['additional_decoder_input'],
context_tensors['additional_decoder_mask'],
)
)
cross_attention_scores_all_timesteps = []
all_heads_cross_attn_scores_all_timesteps = []
_attn_prior = None
unfinished_texts = {}
finished_texts_counter = {}
attended_timestep_counter = [{} for _ in range(text.size(0))]
last_attended_timesteps = [
[1 for _ in range(text.size(0))]
] # Maintain a list of attended timesteps as we predict audio for each batch item
time_to_first_prediction = 0.0
for idx in range(max_decoder_steps // self.frame_stacking_factor):
if idx == 1:
time_to_first_prediction = time.time() - start_time
if idx % 20 == 0:
print(f"Decoding timestep {idx}")
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input)
if context_tensors['additional_decoder_input'] is not None:
_audio_codes_embedded = torch.cat(
[context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1
)
_audio_codes_mask = torch.cat(
[context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1
)
else:
_audio_codes_embedded = audio_codes_embedded
_audio_codes_mask = audio_codes_mask
if apply_prior_to_layers is not None:
attn_prior = [None for _ in range(self.decoder.n_layers)]
for layer_idx in apply_prior_to_layers:
attn_prior[layer_idx] = _attn_prior
else:
attn_prior = _attn_prior
if self.model_type == 'multi_encoder_context_tts':
attn_prior = [attn_prior, None]
if use_cfg:
batch_size = audio_codes_embedded.size(0)
if isinstance(context_tensors['cond'], list):
cfg_cond = [
torch.cat([cond_item, dummy_cond_item], dim=0)
for cond_item, dummy_cond_item in zip(context_tensors['cond'], dummy_cond)
]
cfg_cond_mask = [
torch.cat([cond_mask_item, dummy_cond_mask_item], dim=0)
for cond_mask_item, dummy_cond_mask_item in zip(
context_tensors['cond_mask'], dummy_cond_mask
)
]
else:
cfg_cond = torch.cat([context_tensors['cond'], dummy_cond], dim=0)
cfg_cond_mask = torch.cat([context_tensors['cond_mask'], dummy_cond_mask], dim=0)
cfg_audio_codes_embedded = torch.cat([_audio_codes_embedded, _audio_codes_embedded], dim=0)
cfg_audio_codes_mask = torch.cat([_audio_codes_mask, _audio_codes_mask], dim=0)
if dummy_additional_decoder_input is not None:
cfg_audio_codes_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_additional_decoder_input
)
cfg_audio_codes_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_addition_dec_mask
)
combined_logits, attn_probs, dec_out = self.forward(
dec_input_embedded=cfg_audio_codes_embedded,
dec_input_mask=cfg_audio_codes_mask,
cond=cfg_cond,
cond_mask=cfg_cond_mask,
attn_prior=attn_prior,
multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
)
cond_logits = combined_logits[:batch_size]
uncond_logits = combined_logits[batch_size:]
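# Classifier-free guidance: blend conditional and unconditional logits. With
# cfg_scale=1.0 this reduces to the conditional logits; values > 1.0 extrapolate
# away from the unconditional prediction to strengthen the conditioning.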
all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits
else:
batch_size = audio_codes_embedded.size(0)
all_code_logits, attn_probs, dec_out = self.forward(
dec_input_embedded=_audio_codes_embedded,
dec_input_mask=_audio_codes_mask,
cond=context_tensors['cond'],
cond_mask=context_tensors['cond_mask'],
attn_prior=attn_prior,
multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
)
if return_cross_attn_probs or apply_attention_prior:
cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores(
attn_probs
) # B, text_timesteps
alignment_attention_scores = cross_attention_scores
if estimate_alignment_from_layers is not None:
alignment_attention_scores, _ = self.get_cross_attention_scores(
attn_probs, filter_layers=estimate_alignment_from_layers
) # B, text_timesteps
cross_attention_scores_all_timesteps.append(cross_attention_scores)
all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores)
if apply_attention_prior and idx >= start_prior_after_n_audio_steps:
text_time_step_attended, attended_timestep_counter = self.get_most_attended_text_timestep(
alignment_attention_scores=alignment_attention_scores,
last_attended_timesteps=last_attended_timesteps,
text_lens=context_tensors['text_lens'],
lookahead_window_size=lookahead_window_size,
attended_timestep_counter=attended_timestep_counter,
batch_size=batch_size,
)
last_attended_timesteps.append(text_time_step_attended)
_attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior(
prior_epsilon=prior_epsilon,
cross_attention_scores=cross_attention_scores,
text_lens=context_tensors['text_lens'],
text_time_step_attended=text_time_step_attended,
attended_timestep_counter=attended_timestep_counter,
unfinished_texts=unfinished_texts,
finished_texts_counter=finished_texts_counter,
end_indices=end_indices,
lookahead_window_size=lookahead_window_size,
batch_size=batch_size,
)
if ignore_finished_sentence_tracking:
finished_items = {}
unfinished_items = {}
else:
finished_items = {
k: v for k, v in finished_texts_counter.items() if v >= 20
} # Items that have been close to the end for at least 20 timesteps
unfinished_items = {k: v for k, v in unfinished_texts.items() if v}
# Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor)
# This guards against rare cases of termination right at the start of generation.
forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames
all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook)
if use_local_transformer_for_inference:
if self.local_transformer_type == LocalTransformerType.AR:
# Autoregressive sampling with local transformer
audio_codes_next = self.local_transformer_sample_autoregressive(
dec_output=dec_out[:, -1, :],
temperature=temperature,
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
use_kv_cache=use_LT_kv_cache,
forbid_audio_eos=forbid_audio_eos,
)
elif self.local_transformer_type == LocalTransformerType.MASKGIT:
audio_codes_next = self.local_transformer_sample_maskgit(
dec_output=dec_out[:, -1, :],
temperature=temperature,
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
n_steps=maskgit_n_steps,
noise_scale=maskgit_noise_scale,
fixed_schedule=maskgit_fixed_schedule,
dynamic_cfg_scale=maskgit_dynamic_cfg_scale,
sampling_type=maskgit_sampling_type,
forbid_audio_eos=forbid_audio_eos,
)
else:
raise ValueError(
f"Local transformer inference requested but local transformer type is {self.local_transformer_type}"
)
else:
# Parallel sampling from all codebooks
audio_codes_next = self.sample_codes_from_logits(
all_code_logits_t,
temperature=temperature,
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
forbid_audio_eos=forbid_audio_eos,
) # (B, num_codebooks, frame_stacking_factor)
all_codes_next_argmax = self.sample_codes_from_logits(
all_code_logits_t,
temperature=0.01,
topk=1,
unfinished_items=unfinished_items,
finished_items=finished_items,
forbid_audio_eos=forbid_audio_eos,
) # (B, num_codebooks, frame_stacking_factor)
for item_idx in range(all_codes_next_argmax.size(0)):
if item_idx not in end_indices:
end_frame_index = self.detect_eos(
audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method
)
if end_frame_index != float('inf'):
global_index = idx * self.frame_stacking_factor + end_frame_index
end_indices[item_idx] = global_index
print(f"End detected for item {item_idx} at decoder timestep: {idx}")
all_predictions.append(audio_codes_next)
audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1) # (B, C, T')
audio_codes_lens = audio_codes_lens + 1 # already in stacked form
audio_codes_mask = get_mask_from_lengths(audio_codes_lens)
if len(end_indices) == text.size(0) and len(all_predictions) >= 4:
# The codec needs at least 4 frames to be decoded properly
print("All ends reached")
break
tts_generation_time = time.time() - start_time
tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor)
# Concatenate the list of predictions along the time dimension. Note that when frame stacking is on,
# this also undoes the stacking.
predicted_codes = torch.cat(all_predictions, dim=-1) # (B, num_codebooks, T')
predicted_lens = [
end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))
] # Items that never emitted EOS default to max_decoder_steps frames
predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long()
predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens)
end_time = time.time()
total_audio_duration_generated = (
predicted_audio_lens.max().item() * predicted_audio_lens.shape[0]
) / self.sample_rate
rtf = total_audio_duration_generated / (end_time - start_time)
rtf_metrics = {
'rtf': rtf,
'time_to_first_prediction': time_to_first_prediction,
'tts_generation_time': tts_generation_time,
'max_frames_generated': len(all_predictions),
'tts_generation_time_per_frame': tts_generation_time_per_frame,
'batch_size': text.size(0),
}
torch.cuda.empty_cache()
cross_attention_maps = None
headwise_cross_attention_maps = None
if return_cross_attn_probs:
cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots(
cross_attention_scores_all_timesteps,
all_heads_cross_attn_scores_all_timesteps,
context_tensors['text_lens'],
predicted_codes_lens,
text.size(0),
compute_all_heads_attn_maps,
last_attended_timesteps,
)
return InferBatchOutput(
predicted_audio=predicted_audio,
predicted_audio_lens=predicted_audio_lens,
predicted_codes=predicted_codes,
predicted_codes_lens=predicted_codes_lens,
rtf_metrics=rtf_metrics,
cross_attention_maps=cross_attention_maps,
headwise_cross_attention_maps=headwise_cross_attention_maps,
)
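# Minimal usage sketch for infer_batch (illustrative; `model` and `batch` are assumed to
# be a loaded MagpieTTS model in eval mode and a batch produced by one of its dataloaders):
#
#     output = model.infer_batch(batch, max_decoder_steps=500, temperature=0.7, topk=80)
#     audio = output.predicted_audio[0, : output.predicted_audio_lens[0]]
#     sf.write("sample.wav", audio.float().cpu().numpy(), model.sample_rate)
#
# `sf` is the soundfile module imported at the top of this file; "sample.wav" is a
# placeholder output path.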
def test_step(self, batch, batch_idx):
with torch.no_grad():
test_dl_batch_size = self._test_dl.batch_size
temperature = self.cfg.get('inference_temperature', 0.7)
topk = self.cfg.get('inference_topk', 80)
use_cfg = self.cfg.get('inference_use_cfg', False)
cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)
output = self.infer_batch(
batch,
max_decoder_steps=self.cfg.get('max_decoder_steps', 500),
temperature=temperature,
topk=topk,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
)
predicted_audio = output.predicted_audio
predicted_audio_lens = output.predicted_audio_lens
for logger in self.loggers:
is_wandb = isinstance(logger, WandbLogger)
is_tb = isinstance(logger, TensorBoardLogger)
if not is_wandb and not is_tb:
raise ValueError(
f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
)
for idx in range(predicted_audio.size(0)):
predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]]
item_idx = batch_idx * test_dl_batch_size + idx
if is_wandb:
log_dict = {
"test/predicted_audio": wandb.Audio(
predicted_audio_np, sample_rate=self.sample_rate, caption="Predicted Audio"
),
}
logger.experiment.log(log_dict, step=item_idx)
if is_tb:
logger.experiment.add_audio(
'test/predicted_audio',
predicted_audio_np,
global_step=item_idx,
sample_rate=self.sample_rate,
)
# Save the predicted audio
log_dir = logger.log_dir
audio_dir = os.path.join(log_dir, 'audios')
if not os.path.exists(audio_dir):
os.makedirs(audio_dir)
audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav')
sf.write(audio_path, predicted_audio_np, self.sample_rate)
def on_validation_epoch_end(self):
collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean()
val_loss = collect("val_loss")
val_codebook_loss = collect("val_codebook_loss")
val_alignment_loss = collect("val_alignment_loss")
val_aligner_encoder_loss = collect("val_aligner_encoder_loss")
# log val_loss in the same group as the other val metrics.
self.log("val/loss", val_loss, prog_bar=True, sync_dist=True)
# ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs.
self.log(
"val_loss",
val_loss,
prog_bar=False,
sync_dist=True,
on_step=False,
on_epoch=True,
logger=False,
enable_graph=False,
)
self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True)
self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True)
self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True)
if self.local_transformer_type != LocalTransformerType.NO_LT:
val_local_transformer_loss = collect("val_local_transformer_loss")
self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True)
self.validation_step_outputs.clear() # free memory
def get_dataset(self, dataset_cfg, dataset_type):
dataset = instantiate(
dataset_cfg.dataset,
sample_rate=self.sample_rate,
bos_id=self.bos_id,
eos_id=self.eos_id,
audio_bos_id=self.audio_bos_id,
audio_eos_id=self.audio_eos_id,
context_audio_bos_id=self.context_audio_bos_id,
context_audio_eos_id=self.context_audio_eos_id,
num_audio_codebooks=self.data_num_audio_codebooks,
codec_model_samples_per_frame=self.codec_model_samples_per_frame,
prior_scaling_factor=self.cfg.prior_scaling_factor,
load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
dataset_type=dataset_type, # 'train' or 'test'; used to set phone prob to 1.0 for the test dataset (see worker_init_fn)
use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder,
text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name,
pad_context_text_to_max_duration=self.pad_context_text_to_max_duration,
context_duration_min=self.cfg.context_duration_min,
context_duration_max=self.cfg.context_duration_max,
text_context_remapping=self.text_context_remapping,
text_context_remapping_prob=self.text_context_remapping_prob,
)
dataset.load_16khz_audio = False
dataset.tokenizer_config = (
self.cfg.text_tokenizers
) # This will be used in worker_init_fn for instantiating tokenizer
return dataset
def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader:
# TODO @xueyang: better to distinguish the configs: self.cfg is the model cfg, while dataset_cfg here is the train_ds cfg.
# Also, 'cfg' can be confused with the classifier-free guidance abbreviation.
dataset = MagpieTTSLhotseDataset(
sample_rate=self.sample_rate,
volume_norm=dataset_cfg.volume_norm,
codec_model_samples_per_frame=self.codec_model_samples_per_frame,
audio_bos_id=self.audio_bos_id,
audio_eos_id=self.audio_eos_id,
context_audio_bos_id=self.context_audio_bos_id,
context_audio_eos_id=self.context_audio_eos_id,
num_audio_codebooks=self.data_num_audio_codebooks,
prior_scaling_factor=self.cfg.prior_scaling_factor,
load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
dataset_type=mode, # 'train' or 'test'; used to set phone prob to 1.0 for the test dataset (see worker_init_fn)
load_16khz_audio=False,
pad_context_text_to_max_duration=self.pad_context_text_to_max_duration,
context_duration_min=self.cfg.context_duration_min,
context_duration_max=self.cfg.context_duration_max,
use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder,
text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name,
tokenizer_config=self.cfg.text_tokenizers,
text_context_remapping=self.text_context_remapping,
text_context_remapping_prob=self.text_context_remapping_prob,
)
data_loader = get_lhotse_dataloader_from_config(
config=dataset_cfg.dataset,
global_rank=self.global_rank,
world_size=self.world_size,
dataset=dataset,
)
return data_loader
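    # Illustrative sketch of the dataset config consumed by `setup_training_data` and
    # `_setup_test_dataloader` below. Only the keys actually read in this file are shown;
    # the group name and example values are assumptions, not a definitive schema:
    #
    #   train_ds:
    #     use_lhotse: false
    #     dataset:
    #       _target_: ...          # dataset class to instantiate
    #     dataloader_params:
    #       batch_size: 16
    #       num_workers: 4
    #     volume_norm: true        # only read on the lhotse path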
def setup_training_data(self, dataset_cfg):
if dataset_cfg.get("use_lhotse", False):
            # TODO @xueyang: disambiguate the name `cfg`: self.cfg is the model config, while the config here is
            # the train_ds config; `cfg` is also easily confused with classifier-free guidance (CFG).
            # Override the target sample rate to match the codec model, since the lhotse config defaults to 16_000.
if not isinstance(dataset_cfg, DictConfig):
dataset_cfg = OmegaConf.create(dataset_cfg)
OmegaConf.set_struct(dataset_cfg.dataset, False)
dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
OmegaConf.set_struct(dataset_cfg.dataset, True)
self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train')
else:
dataset = self.get_dataset(dataset_cfg, dataset_type='train')
sampler = dataset.get_sampler(dataset_cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
            # Persistent workers are only meaningful when num_workers > 0.
            persistent_workers = dataset_cfg.dataloader_params.num_workers > 0
            # For num_workers > 0, the tokenizer is assigned in worker_init_fn (since it is not picklable).
dataset.text_tokenizer = setup_tokenizers(
all_tokenizers_config=self.cfg.text_tokenizers,
mode='train',
)
self._train_dl = torch.utils.data.DataLoader(
dataset,
collate_fn=dataset.collate_fn,
sampler=sampler,
**dataset_cfg.dataloader_params,
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
)
def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader:
if dataset_cfg.get("use_lhotse", False):
            # Override the target sample rate to match the codec model, since the lhotse config defaults to 16_000.
if not isinstance(dataset_cfg, DictConfig):
dataset_cfg = OmegaConf.create(dataset_cfg)
OmegaConf.set_struct(dataset_cfg.dataset, False)
dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
OmegaConf.set_struct(dataset_cfg.dataset, True)
data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test')
else:
dataset = self.get_dataset(dataset_cfg, dataset_type='test')
            # Persistent workers are only meaningful when num_workers > 0.
            persistent_workers = dataset_cfg.dataloader_params.num_workers > 0
            # For num_workers > 0, the tokenizer is assigned in worker_init_fn (since it is not picklable).
dataset.text_tokenizer = setup_tokenizers(all_tokenizers_config=self.cfg.text_tokenizers, mode='test')
data_loader = torch.utils.data.DataLoader(
dataset,
collate_fn=dataset.collate_fn,
**dataset_cfg.dataloader_params,
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
)
return data_loader
def setup_validation_data(self, dataset_cfg):
self._validation_dl = self._setup_test_dataloader(dataset_cfg)
def setup_test_data(self, dataset_cfg):
self._test_dl = self._setup_test_dataloader(dataset_cfg)
    def setup_dummy_text_context_in_batch(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> None:
        """Populate `batch` in place with dummy text-context tensors for models that use text conditioning."""
# No text context provided - set up dummy if model requires text conditioning tensors
dummy_context_text = "[NO TEXT CONTEXT]"
dummy_tokens = self.tokenizer.encode(
text=dummy_context_text, tokenizer_name=self.text_conditioning_tokenizer_name
)
batch['context_text_tokens'] = torch.tensor([dummy_tokens], device=self.device, dtype=torch.long)
batch['context_text_tokens_lens'] = torch.tensor([len(dummy_tokens)], device=self.device, dtype=torch.long)
batch['has_text_context'] = torch.tensor([False], device=self.device, dtype=torch.bool)
    def setup_dummy_audio_context_in_batch(
        self,
        batch: Dict[str, torch.Tensor],
        context_audio: Optional[torch.Tensor] = None,
        context_audio_lens: Optional[torch.Tensor] = None,
    ) -> None:
        """Populate `batch` in place with minimal dummy audio-context tensors (BOS/EOS only).

        The optional audio arguments are currently unused in this method.
        """
# Model has baked context - create minimal dummy context tensors
# These will be ignored in prepare_context_tensors when baked embedding is used
dummy_context_codes = torch.zeros(
1, self.num_audio_codebooks, 2, device=self.device, dtype=torch.long
)
dummy_context_codes[:, :, 0] = self.context_audio_bos_id
dummy_context_codes[:, :, 1] = self.context_audio_eos_id
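        # Each codebook row is now [context_audio_bos_id, context_audio_eos_id], i.e. an empty audio context.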
batch['context_audio_codes'] = dummy_context_codes
batch['context_audio_codes_lens'] = torch.tensor([2], device=self.device, dtype=torch.long)
def do_tts(
self,
transcript: str,
language: str = "en",
apply_TN: bool = False,
temperature: float = 0.7,
topk: int = 80,
max_decoder_steps: int = 500,
use_cfg: bool = True,
cfg_scale: float = 2.5,
) -> tuple:
"""
Generate speech from raw text transcript.
This is a convenience method for single-utterance text-to-speech synthesis.
For batch processing, use `infer_batch` directly. Only supports baked context embedding
context injection, NO audio conditioning and text conditioning.
Custom voice generation is not supported by this method.
Args:
transcript: Raw text to synthesize.
language: Language code for text normalization and tokenization.
Supported values depend on model's tokenizer configuration.
Common: "en" (English), "de" (German), "es" (Spanish), etc.
apply_TN: Whether to apply text normalization to the transcript.
If True, uses nemo_text_processing for normalization.
temperature: Sampling temperature for token generation.
topk: Top-k sampling parameter.
max_decoder_steps: Maximum number of decoder steps.
use_cfg: Whether to use classifier-free guidance.
cfg_scale: Scale factor for classifier-free guidance.
Returns:
Tuple of (audio, audio_len) where:
audio: Generated audio waveform. Shape: (1, T_audio).
audio_len: Length of generated audio in samples. Shape: (1,).
Raises:
ValueError: If model does not have a baked context embedding.
ImportError: If apply_TN=True but nemo_text_processing is not installed.
Example:
>>> # If text does not need to be normalized
>>> audio, audio_len = model.do_tts("Hello, how are you today?")
>>>
>>> # If text needs to be normalized
>>> audio, audio_len = model.do_tts(
... "Hello, how are you today?",
... apply_TN=True,
... )
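            >>>
            >>> # Saving the generated audio (illustrative sketch; `sf` is the module-level
            >>> # soundfile import and `audio` has shape (1, T_audio)):
            >>> sf.write("output.wav", audio[0, : audio_len[0]].cpu().numpy(), model.sample_rate)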
"""
        if not self.has_baked_context_embedding:
            raise ValueError(
                "Model does not have a baked context embedding. "
                "Please use a checkpoint with a baked context embedding."
            )
# Apply text normalization if requested
normalized_text = transcript
if apply_TN:
try:
from nemo_text_processing.text_normalization.normalize import Normalizer
normalizer = Normalizer(input_case='cased', lang=language)
normalized_text = normalizer.normalize(transcript, verbose=False)
logging.debug(f"Text normalization: '{transcript}' -> '{normalized_text}'")
except ImportError:
logging.warning(
"nemo_text_processing not installed. Skipping text normalization. "
"Install with: pip install nemo_text_processing"
)
# Determine tokenizer name based on language
# Try to find a matching tokenizer, fallback to first available
tokenizer_name = None
available_tokenizers = list(self.tokenizer.tokenizers.keys())
print(f"Available tokenizers: {available_tokenizers}")
# Common mappings for tokenizer names
language_tokenizer_map = {
"en": ["english_phoneme", "english"],
"de": ["german_phoneme", "german"],
"es": ["spanish_phoneme", "spanish"],
"fr": ["french_phoneme", "french"],
"it": ["italian_phoneme", "italian"],
"vi": ["vietnamese_phoneme", "vietnamese"],
"zh": ["mandarin_phoneme", "mandarin", "chinese"],
}
# Find matching tokenizer
if language in language_tokenizer_map:
for candidate in language_tokenizer_map[language]:
if candidate in available_tokenizers:
tokenizer_name = candidate
break
# Fallback to first available tokenizer
if tokenizer_name is None:
tokenizer_name = available_tokenizers[0]
logging.info(
f"No tokenizer found for language '{language}'. "
f"Using '{tokenizer_name}'. Available: {available_tokenizers}"
)
# Tokenize the transcript text
tokens = self.tokenizer.encode(text=normalized_text, tokenizer_name=tokenizer_name)
tokens = tokens + [self.eos_id] # Add EOS token (BOS not used per dataset convention)
text_tensor = torch.tensor([tokens], device=self.device, dtype=torch.long)
text_lens = torch.tensor([len(tokens)], device=self.device, dtype=torch.long)
# Create batch dictionary
batch = {
'text': text_tensor,
'text_lens': text_lens,
}
# Setup context in batch
if self.use_text_conditioning_encoder:
self.setup_dummy_text_context_in_batch(batch)
self.setup_dummy_audio_context_in_batch(batch)
# Run inference
with torch.no_grad():
output = self.infer_batch(
batch,
max_decoder_steps=max_decoder_steps,
temperature=temperature,
topk=topk,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
)
return output.predicted_audio, output.predicted_audio_lens
@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
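        """Return the list of registered pretrained checkpoints (currently none, so an empty list)."""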
return []