# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import random
import string
from typing import Optional

import librosa
import numpy as np
import soundfile as sf
import torch
from lightning.pytorch import Trainer
from omegaconf import DictConfig, open_dict

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors
from nemo.utils import logging

try:
    import torchaudio
    from torchaudio.pipelines import SQUIM_OBJECTIVE

    HAVE_TORCHAUDIO = True
except ImportError:
    HAVE_TORCHAUDIO = False

try:
    from nemo_text_processing.text_normalization.normalize import Normalizer

    PYNINI_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    Normalizer = None
    PYNINI_AVAILABLE = False

from nemo.collections.tts.models import MagpieTTSModel


class MagpieTTSModelOfflinePODataGen(MagpieTTSModel):
    """Small override of MagpieTTSModel for parallel multi-GPU inference and metrics calculation.

    This class is used in 'test' mode and leverages trainer.test() for multi-GPU/multi-node inference.
    Saves the predicted audio files and logs the CER/WER metrics as individual json files for each audio.
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg, trainer)
        if cfg.get('pref_set_language', "en") == "en":
            self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
                model_name="nvidia/parakeet-ctc-0.6b"
            )
            self.eval_asr_model.freeze()
        self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
            model_name='titanet_large'
        )
        self.eval_speaker_verification_model.freeze()
        if cfg.get('load_whisper_model', False):
            from transformers import WhisperForConditionalGeneration, WhisperProcessor

            self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
            self.whisper_model.eval()
        self._normalize_whisper_transcript = cfg.get('normalize_whisper_transcript', True)
        if self._normalize_whisper_transcript and PYNINI_AVAILABLE:
            self._normalizer_cache = {}
            # Pre-create the normalizer for the configured language
            lang = cfg.get('pref_set_language', 'en')
            self._get_cached_normalizer(lang)

    def _get_cached_normalizer(self, lang_key):
        """Get or create a cached normalizer for the given language."""
        if not PYNINI_AVAILABLE:
            return None
        lang_key = lang_key if lang_key else "en"
        if lang_key not in self._normalizer_cache:
            logging.info(f"Creating normalizer for language: {lang_key}")
            self._normalizer_cache[lang_key] = Normalizer(input_case="cased", lang=lang_key)
        return self._normalizer_cache[lang_key]
    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            test_dl_batch_size = self._test_dl.batch_size
            temperature = self.cfg.get('inference_temperature', 0.7)
            topk = self.cfg.get('inference_topk', 80)
            use_cfg = self.cfg.get('inference_use_cfg', False)
            cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)
            output = self.infer_batch(
                batch,
                max_decoder_steps=self.cfg.get('max_decoder_steps', 500),
                temperature=temperature,
                topk=topk,
                use_cfg=use_cfg,
                cfg_scale=cfg_scale,
            )
            predicted_audio = output.predicted_audio
            predicted_audio_lens = output.predicted_audio_lens
            predicted_codes = output.predicted_codes
            predicted_codes_lens = output.predicted_codes_lens

            predicted_audio_paths = []
            audio_durations = []
            batch_invalid = False
            for idx in range(predicted_audio.size(0)):
                predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
                predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]]
                item_idx = batch_idx * test_dl_batch_size + idx
                # Save the predicted audio
                log_dir = self.logger.log_dir
                audio_dir = os.path.join(log_dir, 'audios')
                os.makedirs(audio_dir, exist_ok=True)  # exist_ok avoids a race between ranks
                audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav')
                audio_durations.append(len(predicted_audio_np) / self.sample_rate)
                sf.write(audio_path, predicted_audio_np, self.sample_rate)
                predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16)
                predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]]
                torch.save(
                    predicted_codes_torch,
                    os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'),
                )
                predicted_audio_paths.append(audio_path)

                if not batch_invalid:
                    with torch.no_grad():
                        try:
                            if self.cfg.get("pref_set_language", "en") == "en":
                                pred_transcripts = self.eval_asr_model.transcribe(
                                    predicted_audio_paths, batch_size=len(predicted_audio_paths)
                                )
                                pred_transcripts = [
                                    process_text_for_cer(transcript.text) for transcript in pred_transcripts
                                ]
                            else:
                                pred_transcripts = []
                                for audio_path in predicted_audio_paths:
                                    normalizer = (
                                        self._get_cached_normalizer(self.cfg.pref_set_language)
                                        if self._normalize_whisper_transcript
                                        else None
                                    )
                                    transcript = transcribe_with_whisper(
                                        audio_path,
                                        self.cfg.pref_set_language,
                                        self.whisper_processor,
                                        self.whisper_model,
                                        self.device,
                                        normalizer,
                                    )
                                    pred_transcripts.append(transcript)
                                pred_transcripts = [
                                    process_text_for_cer(transcript) for transcript in pred_transcripts
                                ]
                        except Exception as e:
                            assert (
                                predicted_audio_lens[idx] < 1000
                            ).any(), f"Expected a short audio file to be the only cause of ASR errors, but got an error with lengths {predicted_audio_lens}"
                            logging.warning(f"Exception during ASR transcription: {e}")
                            logging.warning(
                                "Skipping processing of the batch; generating metrics indicating a WER of 100% and Speaker Similarity of 0.0"
                            )
                            batch_invalid = True
                            continue  # don't break since we want to continue building the audio durations list

                    pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(
                        predicted_audio_paths, self.eval_speaker_verification_model, self.device
                    )
                    gt_speaker_embeddings = get_speaker_embeddings_from_filepaths(
                        batch['audio_filepaths'], self.eval_speaker_verification_model, self.device
                    )

            for idx in range(predicted_audio.size(0)):
                # Computed outside the if/else so the invalid branch also writes to the correct file
                item_idx = batch_idx * test_dl_batch_size + idx
                if not batch_invalid:
                    pred_transcript = pred_transcripts[idx]
                    gt_transcript = process_text_for_cer(batch['raw_texts'][idx])
                    cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True)
                    wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False)
                    spk_embedding_pred = pred_speaker_embeddings[idx].cpu().numpy()
                    spk_embedding_gt = gt_speaker_embeddings[idx].cpu().numpy()
                    spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / (
                        np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt)
                    )
                else:
                    # Create an entry indicating invalid metrics
                    cer_gt = 1.0
                    wer_gt = 1.0
                    spk_similarity = 0.0
                    pred_transcript = "<INVALID>"  # do not change this string; subsequent processing relies on it
                    gt_transcript = process_text_for_cer(batch['raw_texts'][idx])
                item_metrics = {
                    'cer_gt': float(cer_gt),
                    'wer_gt': float(wer_gt),
                    'duration': audio_durations[idx],
                    'spk_similarity': float(spk_similarity),
                    'pred_transcript': pred_transcript,
                    'gt_transcript': gt_transcript,
                }
                with open(
                    os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w'
                ) as f:
                    json.dump(item_metrics, f)


class MagpieTTSModelOfflinePO(MagpieTTSModel):
    """
    MagpieTTSModelOfflinePO extends MagpieTTSModel to support offline preference
    optimization (DPO, IPO, RPO).
    Set cfg.model.dpo_loss_type to 'dpo', 'ipo', or 'rpo' to use the corresponding loss.
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg, trainer)
        ref_model_cfg = copy.deepcopy(cfg)
        with open_dict(ref_model_cfg):
            ref_model_cfg.train_ds = None
            ref_model_cfg.validation_ds = None
        self._reference_model = MagpieTTSModel(cfg=ref_model_cfg)
        print("Loading reference model from checkpoint")
        self._reference_model.load_state_dict(
            torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']
        )
        self._reference_model.freeze()
        self._reference_model._no_state_dict = True
        print("Reference model loaded and frozen")

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state_dict = super().state_dict(destination, prefix, keep_vars)
        keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model', '_reference_model']
        for key in list(state_dict.keys()):
            if any(substring in key for substring in keys_substrings_to_exclude):
                del state_dict[key]
        return state_dict

    def _get_batch_logps(self, logits, labels, loss_mask, average_log_prob=False):
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Shape: (batch_size, sequence_length)
            loss_mask: Mask of the same shape as labels; positions with 0 are excluded from the result.
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)
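
    # A minimal sketch of the gather above, with made-up shapes (illustrative only):
    #
    #   logits = torch.randn(2, 5, 10)         # (batch, time, vocab)
    #   labels = torch.randint(0, 10, (2, 5))  # (batch, time)
    #   mask = torch.ones(2, 5)
    #   per_tok = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    #   seq_logp = (per_tok * mask).sum(-1)    # equals _get_batch_logps(..., average_log_prob=False)
    #
    # `log_softmax(-1)` normalizes over the vocabulary, and `gather` selects the
    # log-probability of each target token along that axis.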

    def preference_loss(
        self,
        policy_chosen_logps,
        policy_rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        chosen_gt_rewards=None,
        rejected_gt_rewards=None,
        beta=0.2,
        gt_reward_scale=1.0,
        label_smoothing=0,
        loss_type="dpo",
        reference_free=False,
    ):
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        Referenced from: https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            chosen_gt_rewards, rejected_gt_rewards: Ground-truth rewards for the chosen/rejected responses; only used by the 'rpo' and 'rpo_sq' losses.
            beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
            gt_reward_scale: Scale applied to the ground-truth reward difference in the 'rpo' and 'rpo_sq' losses.
            label_smoothing: Conservativeness for the DPO loss, which assumes that preferences are noisy (flipped with probability label_smoothing).
            loss_type: One of 'dpo', 'ipo', 'rpo', or 'rpo_sq'.
            reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the preference loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        if reference_free:
            ref_logratios = 0
        logits = pi_logratios - ref_logratios  # also known as h_{\pi_\theta}^{y_w,y_l}
        # logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps)
        #        = (policy_chosen_logps - reference_chosen_logps) - (policy_rejected_logps - reference_rejected_logps)
        # logits is the same as rewards_delta in NeMo-Aligner: https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241
        if loss_type == "ipo":
            losses = (logits - 1 / (2 * beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        elif loss_type == "rpo":
            # https://github.com/NVIDIA/NeMo-Aligner/blob/0b5bffeb78a8316dd57e0816a2a9544540f0c8dd/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py#L241
            logbeta_hat_chosen = torch.nn.functional.logsigmoid(beta * logits)
            logbeta_hat_rejected = torch.nn.functional.logsigmoid(-beta * logits)
            gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards)
            logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta)
            logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta)
            losses = torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + torch.exp(
                logalpha_hat_rejected
            ) * (logalpha_hat_rejected - logbeta_hat_rejected)
        elif loss_type == "rpo_sq":
            gt_rewards_delta = gt_reward_scale * (chosen_gt_rewards - rejected_gt_rewards)
            losses = (beta * logits - gt_rewards_delta) ** 2
        elif loss_type == "dpo":
            # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives the original DPO loss (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            F = torch.nn.functional
            losses = (
                -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
            )
        else:
            raise NotImplementedError(f"loss type {loss_type} is not implemented")
        chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards
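
    # Worked example for the loss above (illustrative numbers, not from any run):
    # suppose the policy's preference margin exceeds the reference's by
    #   logits = (policy_chosen_logps - policy_rejected_logps)
    #          - (reference_chosen_logps - reference_rejected_logps) = 1.0.
    # With beta=0.2 and label_smoothing=0, the 'dpo' branch gives
    #   -logsigmoid(0.2 * 1.0) = -log(0.5498) ~= 0.598 per example,
    # shrinking toward 0 as the margin grows. The 'ipo' branch instead gives
    #   (1.0 - 1 / (2 * 0.2)) ** 2 = (1.0 - 2.5) ** 2 = 2.25,
    # which is minimized when the margin equals 1 / (2 * beta).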

    def process_batch_dpo(self, batch_chosen_rejected):
        batch_chosen = batch_chosen_rejected['chosen']
        batch_rejected = batch_chosen_rejected['rejected']
        model_output_chosen = self.process_batch(batch_chosen)
        model_output_rejected = self.process_batch(batch_rejected)
        with torch.no_grad():
            reference_model_output_chosen = self._reference_model.process_batch(batch_chosen)
            reference_model_output_rejected = self._reference_model.process_batch(batch_rejected)

        chosen_policy_logprobs = None
        rejected_policy_logprobs = None
        chosen_ref_logprobs = None
        rejected_ref_logprobs = None
        for codebook_idx in range(self.num_audio_codebooks):
            si = codebook_idx * self.num_all_tokens_per_codebook
            ei = si + self.num_all_tokens_per_codebook
            codebook_logits_chosen = model_output_chosen['logits'][:, :, si:ei]
            codebook_logits_rejected = model_output_rejected['logits'][:, :, si:ei]
            ref_codebook_logits_chosen = reference_model_output_chosen['logits'][:, :, si:ei]
            ref_codebook_logits_rejected = reference_model_output_rejected['logits'][:, :, si:ei]
            codebook_labels_chosen = model_output_chosen['audio_codes_target'][:, codebook_idx]
            codebook_labels_rejected = model_output_rejected['audio_codes_target'][:, codebook_idx]
            codebook_log_probs_chosen = self._get_batch_logps(
                codebook_logits_chosen, codebook_labels_chosen, model_output_chosen['loss_mask'][:, codebook_idx]
            )
            codebook_log_probs_rejected = self._get_batch_logps(
                codebook_logits_rejected, codebook_labels_rejected, model_output_rejected['loss_mask'][:, codebook_idx]
            )
            with torch.no_grad():
                ref_codebook_log_probs_chosen = self._get_batch_logps(
                    ref_codebook_logits_chosen,
                    codebook_labels_chosen,
                    reference_model_output_chosen['loss_mask'][:, codebook_idx],
                )
                ref_codebook_log_probs_rejected = self._get_batch_logps(
                    ref_codebook_logits_rejected,
                    codebook_labels_rejected,
                    reference_model_output_rejected['loss_mask'][:, codebook_idx],
                )
            if chosen_policy_logprobs is None:
                chosen_policy_logprobs = codebook_log_probs_chosen
                rejected_policy_logprobs = codebook_log_probs_rejected
                chosen_ref_logprobs = ref_codebook_log_probs_chosen
                rejected_ref_logprobs = ref_codebook_log_probs_rejected
            else:
                chosen_policy_logprobs += codebook_log_probs_chosen
                rejected_policy_logprobs += codebook_log_probs_rejected
                chosen_ref_logprobs += ref_codebook_log_probs_chosen
                rejected_ref_logprobs += ref_codebook_log_probs_rejected

        rewards_chosen = batch_chosen['rewards']
        rewards_rejected = batch_rejected['rewards']
        assert torch.all(rewards_chosen == 1)
        assert torch.all(rewards_rejected < 1)
        pref_loss, chosen_rewards, rejected_rewards = self.preference_loss(
            chosen_policy_logprobs,
            rejected_policy_logprobs,
            chosen_ref_logprobs,
            rejected_ref_logprobs,
            chosen_gt_rewards=rewards_chosen,
            rejected_gt_rewards=rewards_rejected,
            beta=self.cfg.get('dpo_beta', 0.01),
            loss_type=self.cfg.get('dpo_loss_type', 'dpo'),
        )
        pref_loss = pref_loss.mean()
        sft_loss = -chosen_policy_logprobs.mean()
        pref_loss_weight = self.cfg.get('dpo_pref_loss_weight', 1.0)
        sft_loss_weight = self.cfg.get('dpo_sft_loss_weight', 0.0)
        loss = pref_loss_weight * pref_loss + sft_loss_weight * sft_loss
        alignment_loss = model_output_chosen['alignment_loss']
        if alignment_loss is not None:
            loss += alignment_loss
        return {
            'loss': loss,
            'pref_loss': pref_loss,
            'sft_loss': sft_loss,
            'alignment_loss': alignment_loss,
        }
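
    # The loop above sums sequence-level log-probs across codebooks before applying the
    # preference loss. Under the model's per-codebook factorization (one softmax head per
    # codebook), log p(codes) = sum_k log p(codebook_k), so the summed value is the joint
    # log-probability of the full multi-codebook sequence: the sequence-level quantity
    # that DPO expects for each response.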

    def training_step(self, batch, batch_idx):
        dpo_outputs = self.process_batch_dpo(batch)
        self.log('train_loss', dpo_outputs['loss'], prog_bar=True, sync_dist=True)
        self.log('train_pref_loss', dpo_outputs['pref_loss'], prog_bar=True, sync_dist=True)
        self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True)
        return dpo_outputs['loss']

    def validation_step(self, batch, batch_idx):
        dpo_outputs = self.process_batch_dpo(batch)
        val_loss = dpo_outputs['loss']
        val_pref_loss = dpo_outputs['pref_loss']
        val_sft_loss = dpo_outputs['sft_loss']
        val_alignment_loss = dpo_outputs['alignment_loss']
        self.validation_step_outputs.append(
            {
                'val_loss': val_loss,
                'val_pref_loss': val_pref_loss,
                'val_sft_loss': val_sft_loss,
                'val_alignment_loss': val_alignment_loss,
            }
        )

    def on_validation_epoch_end(self):
        def collect(key):
            values = []
            for x in self.validation_step_outputs:
                if x[key] is not None:
                    values.append(x[key])
                else:
                    values.append(torch.tensor(0.0, device=self.device))
            return torch.stack(values).mean()

        val_loss = collect("val_loss")
        val_pref_loss = collect("val_pref_loss")
        val_sft_loss = collect("val_sft_loss")
        val_alignment_loss = collect("val_alignment_loss")
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
        self.log("val_pref_loss", val_pref_loss, prog_bar=True, sync_dist=True)
        self.log("val_sft_loss", val_sft_loss, prog_bar=True, sync_dist=True)
        if val_alignment_loss is not None:
            self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True)
        self.validation_step_outputs.clear()


class MagpieTTSModelOnlinePO(MagpieTTSModel):
    """
    MagpieTTSModelOnlinePO extends MagpieTTSModel to support online preference
    optimization (GRPO).
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg, trainer)
        # Copy cfg
        ref_model_cfg = copy.deepcopy(cfg)
        with open_dict(ref_model_cfg):
            ref_model_cfg.train_ds = None
            ref_model_cfg.validation_ds = None
        self.reference_free = self.cfg.get('reference_free', False)  # True means we don't use the reference model
        if not self.reference_free:
            self._reference_model = MagpieTTSModel(cfg=ref_model_cfg)
            print("Loading reference model from checkpoint")
            self._reference_model.load_state_dict(
                torch.load(cfg.reference_model_ckpt_path, map_location="cpu", weights_only=False)['state_dict']
            )
            self._reference_model.freeze()
            self._reference_model._no_state_dict = True
            print("Reference model loaded and frozen")

        if cfg.get('reward_asr_model', "nemo") == "nemo":
            self.eval_asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
                model_name="nvidia/parakeet-ctc-0.6b"
            )
            self.eval_asr_model.freeze()
        elif cfg.get('reward_asr_model', "nemo") == "whisper":
            from transformers import WhisperForConditionalGeneration, WhisperProcessor

            self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
            self.whisper_model.eval()
        else:
            raise ValueError(f"Unknown reward_asr_model: {cfg.reward_asr_model}")

        self.eval_speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
            model_name='titanet_large'
        )
        self.eval_speaker_verification_model.freeze()

        if cfg.get('load_whisper_model', False):
            from transformers import WhisperForConditionalGeneration, WhisperProcessor

            self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
            self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
            self.whisper_model.eval()

        use_pesq = self.cfg.get('use_pesq', False)
        if use_pesq:
            assert HAVE_TORCHAUDIO, "torchaudio is required for the PESQ reward"
            self.squim_objective_model = SQUIM_OBJECTIVE.get_model()

        self.loss_type = self.cfg.get('loss_type', 'grpo')
        if self.loss_type not in ['grpo', 'dr_grpo']:
            raise ValueError(
                f"Received loss_type of {self.loss_type}, but the model only accepts one of ['grpo', 'dr_grpo']"
            )
        self.scale_rewards = self.cfg.get('scale_rewards', True)
        self.max_decoder_steps = self.cfg.get('max_decoder_steps', 430)
        self._normalize_whisper_transcript = self.cfg.get('normalize_whisper_transcript', True)
        if cfg.get('reward_asr_model', "nemo") == "whisper" and self._normalize_whisper_transcript:
            self._normalizer_cache = {}
        # If the best record in a group has a CER above this threshold, the group is not used for training.
        # Set to 1.0 because the ASR rewards are clamped to [0, 1] for OnlinePO.
        self.best_cer_threshold = self.cfg.get('best_cer_threshold', 1.0)
        # If the worst record in a group has a CER above this threshold, the group is not used for training.
        # Set to 1.0 because the ASR rewards are clamped to [0, 1] for OnlinePO.
        self.worst_cer_threshold = self.cfg.get('worst_cer_threshold', 1.0)

    def _get_cached_normalizer(self, lang_key):
        """Get or create a cached normalizer for the given language."""
        if not PYNINI_AVAILABLE:
            return None
        lang_key = lang_key if lang_key else "en"
        if lang_key not in self._normalizer_cache:
            logging.info(f"Creating normalizer for language: {lang_key}")
            self._normalizer_cache[lang_key] = Normalizer(input_case="cased", lang=lang_key)
        return self._normalizer_cache[lang_key]
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state_dict = super().state_dict(destination, prefix, keep_vars)
        keys_substrings_to_exclude = [
            '_speaker_verification_model',
            '_codec_model',
            '_reference_model',
            'eval_asr_model',
            'eval_speaker_verification_model',
            'whisper_model',
        ]
        for key in list(state_dict.keys()):
            if any(substring in key for substring in keys_substrings_to_exclude):
                del state_dict[key]
        return state_dict

    def _get_per_token_logps(self, logits, labels, loss_mask):
        """Compute the per-token log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels: Labels for which to compute the log probabilities. Shape: (batch_size, sequence_length)
            loss_mask: Mask of the same shape as labels; positions with 0 are zeroed out in the result.
        """
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        per_token_logps = per_token_logps * loss_mask
        return per_token_logps
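
    # Unlike `_get_batch_logps` in the offline DPO class, this keeps the time dimension:
    # GRPO applies the sequence-level advantage to every token and reduces over time only
    # at the end (see `process_batch_online_po`), so the per-token log-probs are returned
    # masked but unsummed.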

    def repeat_items_in_batch(self, batch, num_repeats):
        repeated_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                repeated_value = value.repeat_interleave(num_repeats, dim=0)
            elif isinstance(value, list):
                repeated_value = []
                for item in value:
                    repeated_value.extend([item] * num_repeats)
            else:
                repeated_value = value
            repeated_batch[key] = repeated_value
        return repeated_batch
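
    # A minimal sketch of the repeat layout (hypothetical key and values): with
    # num_repeats=3,
    #   {'texts': ['a', 'b']}  ->  {'texts': ['a', 'a', 'a', 'b', 'b', 'b']}
    # Both the tensor and list branches keep each item's copies contiguous, which is
    # what the grouped reward/advantage logic in `generate_and_reward` assumes
    # (group i spans indices [i * n, (i + 1) * n)).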

    def generate_and_reward(
        self, batch, num_generations_per_item, mode='train', use_local_transformer_for_inference=False
    ):
        batch_repeated = self.repeat_items_in_batch(batch, num_generations_per_item)
        temperature = self.cfg.get('inference_temperature', 0.7)
        topk = self.cfg.get('inference_topk', 80)
        use_cfg = False
        cfg_scale = 1.0
        use_pesq = self.cfg.get('use_pesq', False)
        inference_cfg_prob = self.cfg.get('inference_cfg_prob', 0.0)
        if (inference_cfg_prob == 1.0) or (inference_cfg_prob > 0.0 and mode == 'train'):
            # Randomly enable classifier-free guidance with the given probability
            use_cfg = random.random() < inference_cfg_prob
            cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)
        output = self.infer_batch(
            batch_repeated,
            max_decoder_steps=self.max_decoder_steps,
            temperature=temperature,
            topk=topk,
            use_cfg=use_cfg,
            cfg_scale=cfg_scale,
            use_local_transformer_for_inference=use_local_transformer_for_inference,
            use_LT_kv_cache=False,  # We don't use KV caching for the local transformer in GRPO due to issues.
        )
        predicted_audio = output.predicted_audio
        predicted_audio_lens = output.predicted_audio_lens
        predicted_codes = output.predicted_codes
        predicted_codes_lens = output.predicted_codes_lens

        predicted_audio_paths = []
        audio_durations = []
        for idx in range(predicted_audio.size(0)):
            predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
            predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]]
            if predicted_audio_np.shape[0] < 1000:
                # Corner case to handle very short audio files
                predicted_audio_np = np.pad(predicted_audio_np, (0, 1000 - predicted_audio_np.shape[0]))
            item_idx = idx
            # Save the predicted audio
            log_dir = self.logger.log_dir
            audio_dir = os.path.join(log_dir, 'audios')
            os.makedirs(audio_dir, exist_ok=True)
            audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav')
            audio_durations.append(len(predicted_audio_np) / self.sample_rate)
            sf.write(audio_path, predicted_audio_np, self.sample_rate)
            predicted_codes_torch = predicted_codes[idx].cpu().type(torch.int16)
            predicted_codes_torch = predicted_codes_torch[:, : predicted_codes_lens[idx]]  # C, T
            torch.save(
                predicted_codes_torch,
                os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_codes.pt'),
            )
            predicted_audio_paths.append(audio_path)

        with torch.no_grad():
            if self.cfg.get("reward_asr_model", "nemo") == "nemo":
                pred_transcripts = self.eval_asr_model.transcribe(
                    predicted_audio_paths, batch_size=len(predicted_audio_paths)
                )
                pred_transcripts = [process_text_for_cer(transcript.text) for transcript in pred_transcripts]
            elif self.cfg.get("reward_asr_model", "nemo") == "whisper":
                pred_transcripts = []
                for item_idx, audio_path in enumerate(predicted_audio_paths):
                    language = batch_repeated['languages'][item_idx]
                    normalizer = self._get_cached_normalizer(language) if self._normalize_whisper_transcript else None
                    transcript = transcribe_with_whisper(
                        audio_path, language, self.whisper_processor, self.whisper_model, self.device, normalizer
                    )
                    pred_transcripts.append(transcript)
                pred_transcripts = [process_text_for_cer(transcript) for transcript in pred_transcripts]
            else:
                # Guard against pred_transcripts being undefined if new ASR options are added (CodeQL)
                raise ValueError(
                    f"{self} received a value of {self.cfg.get('reward_asr_model', 'nemo')} in cfg.reward_asr_model "
                    "but this class only supports 'nemo' or 'whisper'."
                )

        pred_speaker_embeddings = get_speaker_embeddings_from_filepaths(
            predicted_audio_paths, self.eval_speaker_verification_model, self.device
        )
        gt_speaker_embeddings = get_speaker_embeddings_from_filepaths(
            batch_repeated['audio_filepaths'], self.eval_speaker_verification_model, self.device
        )

        batch_metrics = []
        cer_reward_weight = self.cfg.get('cer_reward_weight', 0.5)
        ssim_reward_weight = self.cfg.get('ssim_reward_weight', 0.5)
        pesq_reward_weight = self.cfg.get('pesq_reward_weight', 0.0)
        for idx in range(predicted_audio.size(0)):
            audio_path = predicted_audio_paths[idx]
            item_idx = idx
            pred_transcript = pred_transcripts[idx]
            gt_transcript = process_text_for_cer(batch_repeated['raw_texts'][idx])
            cer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=True)
            wer_gt = word_error_rate([pred_transcript], [gt_transcript], use_cer=False)
            cer_gt = min(max(cer_gt, 0.0), 1.0)  # Ensure CER is in [0, 1]
            wer_gt = min(max(wer_gt, 0.0), 1.0)  # Ensure WER is in [0, 1]
            spk_embedding_pred = pred_speaker_embeddings[idx].cpu().float().numpy()
            spk_embedding_gt = gt_speaker_embeddings[idx].cpu().float().numpy()
            spk_similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / (
                np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt)
            )
            if use_pesq:
                sample_audio, sr = torchaudio.load(audio_path)
                sample_audio = sample_audio.to(self.device)
                if sr != 16000:
                    sample_audio = torchaudio.functional.resample(sample_audio, sr, 16000)
                _, pesq_hyp, _ = self.squim_objective_model(sample_audio)
                pesq_hyp = pesq_hyp.item()
            item_metrics = {
                'cer_gt': float(cer_gt),
                'wer_gt': float(wer_gt),
                'duration': audio_durations[idx],
                'spk_similarity': float(spk_similarity),
                'pred_transcript': pred_transcript,
                'gt_transcript': gt_transcript,
                'codes_len': predicted_codes_lens[idx].item(),
                'pesq': pesq_hyp if use_pesq else 0.0,
            }
            with open(
                os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}_metrics.json'), 'w'
            ) as f:
                json.dump(item_metrics, f)
            batch_metrics.append(item_metrics)

        num_groups = len(batch['audio_filepaths'])
        # Examples with this speaker similarity or higher get the maximum SSIM reward of 1
        best_ssim_achievable = self.cfg.get("best_ssim_achievable", 0.9)
        mean_cer_dataset = self.cfg.get("mean_cer_dataset", 0.1)  # A CER equal to this value gets a reward of 0.5
        mean_ssim_dataset = self.cfg.get("mean_ssim_dataset", 0.6)  # An SSIM equal to this value gets a reward of 0.5
        all_groups_mean_reward = 0.0
        all_groups_std_reward = 0.0
        group_validities = []
        for group_idx in range(num_groups):
            group_start_idx = group_idx * num_generations_per_item
            group_end_idx = group_start_idx + num_generations_per_item
            group_rewards = []
            mean_reward = 0
            is_group_valid = True
            group_best_cer = 1.0
            group_worst_cer = 0.0
            for idx in range(group_start_idx, group_end_idx):
                # Lower CER, higher speaker similarity, and higher PESQ all mean a higher reward.
                # The reward for the best CER and best speaker similarity should be 1.
                item_cer = batch_metrics[idx]['cer_gt']
                item_ssim = batch_metrics[idx]['spk_similarity']
                item_cer = min(max(item_cer, 0.0), 1.0)
                item_ssim = max(min(item_ssim, best_ssim_achievable), 0.0)
                item_pesq = batch_metrics[idx]['pesq']
                group_best_cer = min(group_best_cer, item_cer)
                group_worst_cer = max(group_worst_cer, item_cer)
                if item_cer <= mean_cer_dataset:
                    cer_reward = 0.5 + 0.5 * (mean_cer_dataset - item_cer) / mean_cer_dataset  # 0.5 to 1
                else:
                    cer_reward = 0.5 - 0.5 * (item_cer - mean_cer_dataset) / (1 - mean_cer_dataset)  # 0 to 0.5
                if item_ssim >= mean_ssim_dataset:
                    spk_similarity_reward = 0.5 + 0.5 * (item_ssim - mean_ssim_dataset) / (
                        best_ssim_achievable - mean_ssim_dataset
                    )
                else:
                    spk_similarity_reward = 0.5 - 0.5 * (mean_ssim_dataset - item_ssim) / mean_ssim_dataset
                if use_pesq:
                    pesq_reward = item_pesq / 4.5
                else:
                    pesq_reward = 0.0
                batch_metrics[idx]['reward'] = (
                    cer_reward * cer_reward_weight
                    + spk_similarity_reward * ssim_reward_weight
                    + pesq_reward * pesq_reward_weight
                )
                if (batch_metrics[idx]['codes_len'] >= 425) or (
                    batch_metrics[idx]['codes_len'] <= 3
                ):  # TODO: Remove hardcoded lengths
                    # The generation either did not complete the sentence or is extremely short
                    batch_metrics[idx]['reward'] = 0.0
                print(
                    f"Item idx: {idx} CER: {item_cer} SSIM: {item_ssim} "
                    f"Reward: {batch_metrics[idx]['reward']} Codes len: {batch_metrics[idx]['codes_len']}"
                )
                batch_metrics[idx]['cer_reward'] = cer_reward
                batch_metrics[idx]['spk_similarity_reward'] = spk_similarity_reward
                batch_metrics[idx]['pesq_reward'] = pesq_reward
                mean_reward += batch_metrics[idx]['reward']
                group_rewards.append(batch_metrics[idx]['reward'])
            if group_best_cer > self.best_cer_threshold:
                is_group_valid = False
                print(
                    f"Group {group_idx} has best CER {group_best_cer}, which is above the threshold "
                    f"{self.best_cer_threshold}. Group is invalid."
                )
            if group_worst_cer > self.worst_cer_threshold:
                is_group_valid = False
                print(
                    f"Group {group_idx} has worst CER {group_worst_cer}, which is above the threshold "
                    f"{self.worst_cer_threshold}. Group is invalid."
                )
            for _ in range(num_generations_per_item):
                group_validities.append(is_group_valid)
            mean_reward /= num_generations_per_item
            std_reward = np.std(group_rewards)
            all_groups_mean_reward += mean_reward
            all_groups_std_reward += std_reward
            for idx in range(group_start_idx, group_end_idx):
                batch_metrics[idx]['advantage'] = batch_metrics[idx]['reward'] - mean_reward
                if self.scale_rewards:
                    batch_metrics[idx]['advantage'] = batch_metrics[idx]['advantage'] / (std_reward + 1e-4)

        all_groups_mean_reward = all_groups_mean_reward / num_groups
        all_groups_std_reward = all_groups_std_reward / num_groups
        advantages = [x['advantage'] for x in batch_metrics]
        advantages = torch.tensor(advantages, device=self.device)
        print("Mean reward: ", all_groups_mean_reward)
        group_validities = torch.tensor(group_validities, device=self.device)
        return {
            'mean_reward': torch.tensor(all_groups_mean_reward, device=self.device),
            'std_reward': torch.tensor(all_groups_std_reward, device=self.device),
            'batch_repeated': batch_repeated,
            'metrics': batch_metrics,
            'predicted_codes': predicted_codes,
            'predicted_codes_lens': predicted_codes_lens,
            'advantages': advantages,
            'group_validities': group_validities,
        }
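
    # Worked example of the reward shaping above, using the default knobs
    # (mean_cer_dataset=0.1, mean_ssim_dataset=0.6, best_ssim_achievable=0.9,
    # cer/ssim weights 0.5/0.5, PESQ disabled); numbers are illustrative only:
    #   item_cer = 0.05  ->  cer_reward  = 0.5 + 0.5 * (0.1 - 0.05) / 0.1         = 0.75
    #   item_ssim = 0.75 ->  ssim_reward = 0.5 + 0.5 * (0.75 - 0.6) / (0.9 - 0.6) = 0.75
    #   reward = 0.5 * 0.75 + 0.5 * 0.75 = 0.75
    # Both maps are piecewise linear and equal 0.5 exactly at the dataset means, so a
    # generation at "average" CER/SSIM receives a neutral reward.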

    def process_batch_online_po(self, batch, n_generations_per_item, mode='train'):
        use_kv_cache_during_online_po = self.cfg.get("use_kv_cache_during_online_po", False)
        if use_kv_cache_during_online_po:
            self.use_kv_cache_for_inference = True
            self.decoder.reset_cache(use_cache=True)

        use_local_transformer_for_inference = False
        logits_key = 'logits'
        use_local_transformer_prob = self.cfg.get('use_local_transformer_prob', 0.0)
        if use_local_transformer_prob > 0.0 and mode == 'train':
            use_local_transformer_for_inference = random.random() < use_local_transformer_prob
            logits_key = 'local_transformer_logits'

        with torch.no_grad():
            self.eval()
            generated_codes_and_metrics = self.generate_and_reward(
                batch,
                n_generations_per_item,
                mode,
                use_local_transformer_for_inference=use_local_transformer_for_inference,
            )
            self.train()

        if use_kv_cache_during_online_po:
            self.use_kv_cache_for_inference = False
            self.decoder.reset_cache(use_cache=False)

        batch_repeated = generated_codes_and_metrics['batch_repeated']
        predicted_codes = generated_codes_and_metrics['predicted_codes']  # B, 8, T
        predicted_codes_lens = generated_codes_and_metrics['predicted_codes_lens']  # B
        predicted_codes = predicted_codes[:, :, : predicted_codes_lens.max()]
        advantages = generated_codes_and_metrics['advantages']  # B

        # Add extra tokens for BOS and EOS
        bos_tensor = torch.full(
            (predicted_codes.size(0), predicted_codes.size(1), 1),
            self.audio_bos_id,
            dtype=predicted_codes.dtype,
            device=predicted_codes.device,
        )
        padding_tensor = torch.full(
            (predicted_codes.size(0), predicted_codes.size(1), 1),
            0,
            dtype=predicted_codes.dtype,
            device=predicted_codes.device,
        )
        predicted_codes = torch.cat([bos_tensor, predicted_codes, padding_tensor], dim=2)
        for idx in range(predicted_codes.size(0)):
            predicted_codes[idx, :, predicted_codes_lens[idx] + 1] = self.audio_eos_id  # Accounts for BOS
        batch_repeated['audio_codes'] = predicted_codes
        batch_repeated['audio_codes_lens'] = predicted_codes_lens + 2  # Accounts for BOS and EOS
        if 'audio' in batch_repeated:
            del batch_repeated['audio']
        if 'audio_lens' in batch_repeated:
            del batch_repeated['audio_lens']

        policy_model_outputs = self.process_batch(batch_repeated)
        # Initialize to None to address a CodeQL warning, even though this variable is
        # only used when not self.reference_free.
        reference_model_output = None
        if not self.reference_free:
            with torch.no_grad():
                reference_model_output = self._reference_model.process_batch(batch_repeated)

        total_loss = None
        total_kl = None
        for codebook_idx in range(self.num_audio_codebooks):
            policy_codebook_loss_mask = policy_model_outputs['loss_mask'][:, codebook_idx, :]
            reference_codebook_loss_mask = (
                reference_model_output['loss_mask'][:, codebook_idx, :] if not self.reference_free else None
            )
            si = codebook_idx * self.num_all_tokens_per_codebook
            ei = si + self.num_all_tokens_per_codebook
            codebook_logits = policy_model_outputs[logits_key][:, :, si:ei]  # B, T, C
            codebook_labels = batch_repeated['audio_codes'][:, codebook_idx, 1:]
            per_token_codebook_log_probs = self._get_per_token_logps(
                codebook_logits, codebook_labels, policy_codebook_loss_mask
            )
            per_token_loss = -(
                torch.exp(per_token_codebook_log_probs - per_token_codebook_log_probs.detach())
                * advantages.unsqueeze(1)
            )
            group_validities = generated_codes_and_metrics['group_validities']  # B * n_generations_per_item
            per_token_loss = per_token_loss * group_validities.unsqueeze(1)  # B, T
            if not self.reference_free:
                with torch.no_grad():
                    ref_codebook_logits = reference_model_output[logits_key][:, :, si:ei]
                    per_token_ref_codebook_log_probs = self._get_per_token_logps(
                        ref_codebook_logits, codebook_labels, reference_codebook_loss_mask
                    )
                # https://github.com/huggingface/trl/blob/ffcb9f4aee725a2bd072d0387afe68a4b1c7967c/trl/trainer/grpo_trainer.py#L703
                per_token_codebook_kl = (
                    torch.exp(per_token_ref_codebook_log_probs - per_token_codebook_log_probs)
                    - (per_token_ref_codebook_log_probs - per_token_codebook_log_probs)
                    - 1
                )
                per_token_loss = per_token_loss + self.cfg.grpo_beta * per_token_codebook_kl
                codebook_kl_loss_mean = (
                    (per_token_codebook_kl * policy_codebook_loss_mask).sum(dim=1)
                    / policy_codebook_loss_mask.sum(dim=1)
                ).mean()
            else:
                codebook_kl_loss_mean = torch.tensor(0.0, device=self.device)
            if self.loss_type == "grpo":
                codebook_loss = (
                    (per_token_loss * policy_codebook_loss_mask).sum(dim=1) / policy_codebook_loss_mask.sum(dim=1)
                ).mean()
            elif self.loss_type == "dr_grpo":
                # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
                total_tokens = per_token_loss.shape[0] * self.max_decoder_steps
                codebook_loss = (per_token_loss * policy_codebook_loss_mask).sum() / total_tokens
            else:
                raise ValueError(f"Unknown loss function: {self.loss_type}")
            if total_loss is None:
                total_loss = codebook_loss
                total_kl = codebook_kl_loss_mean
            else:
                total_loss += codebook_loss
                total_kl += codebook_kl_loss_mean
        total_loss /= self.num_audio_codebooks
        return {
            'mean_reward': generated_codes_and_metrics['mean_reward'],
            'std_reward': generated_codes_and_metrics['std_reward'],
            'loss': total_loss,
            'kl_loss': total_kl,
            'batch_metrics': generated_codes_and_metrics['metrics'],
        }
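
    # Note on the GRPO surrogate above: since exp(x - x.detach()) == 1 in value, each
    # per-token loss evaluates to -advantage, but its gradient is
    # -advantage * grad(log p(token)), i.e. the standard policy-gradient estimator.
    # This mirrors the single-update trick in TRL's GRPOTrainer (linked above).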

    def training_step(self, batch, batch_idx):
        torch.cuda.empty_cache()
        n_generations_per_item = self.cfg.get('n_generations_per_item', 6)
        po_outputs = self.process_batch_online_po(batch, n_generations_per_item)
        self.log('train_loss', po_outputs['loss'], prog_bar=True, sync_dist=True)
        self.log('train_kl_loss', po_outputs['kl_loss'], prog_bar=True, sync_dist=True)
        self.log('train_mean_reward', po_outputs['mean_reward'], prog_bar=True, sync_dist=True)
        self.log('train_std_reward', po_outputs['std_reward'], prog_bar=True, sync_dist=True)
        return po_outputs['loss']

    def validation_step(self, batch, batch_idx):
        po_outputs = self.process_batch_online_po(batch, 1, mode='val')
        batch_metrics = po_outputs['batch_metrics']
        mean_reward = po_outputs['mean_reward']
        val_loss = po_outputs['loss']
        val_kl_loss = po_outputs['kl_loss']
        self.validation_step_outputs.append(
            {
                'mean_reward': mean_reward,
                'std_reward': po_outputs['std_reward'],
                'val_loss': val_loss,
                'val_kl_loss': val_kl_loss,
                'batch_metrics': batch_metrics,
            }
        )

    def on_validation_epoch_end(self):
        def collect(key):
            values = []
            for x in self.validation_step_outputs:
                if x[key] is not None:
                    values.append(x[key])
                else:
                    values.append(torch.tensor(0.0, device=self.device))
            return torch.stack(values).mean()

        val_loss = collect("val_loss")
        val_kl_loss = collect("val_kl_loss")
        mean_reward = collect("mean_reward")
        std_reward = collect("std_reward")
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
        self.log("val_kl_loss", val_kl_loss, prog_bar=True, sync_dist=True)
        self.log("val_mean_reward", mean_reward, prog_bar=True, sync_dist=True)
        self.log("val_std_reward", std_reward, prog_bar=True, sync_dist=True)

        # Average the per-item metrics (except transcripts) across the validation set and log them
        mean_metrics = {}
        for val_output in self.validation_step_outputs:
            batch_metrics = val_output['batch_metrics']
            for item_metrics in batch_metrics:
                for key, value in item_metrics.items():
                    if "transcript" not in key:
                        if key not in mean_metrics:
                            mean_metrics[key] = []
                        mean_metrics[key].append(value)
        for key, values in mean_metrics.items():
            mean_metrics[key] = np.mean(values)
            self.log(f"val_{key}", mean_metrics[key], prog_bar=True, sync_dist=True)
        self.validation_step_outputs.clear()


# Utility functions
def process_text_for_cer(input_text):
    """
    Normalizes text for CER/WER calculation.
    Taken from hallucination_eval.py
    """
    # Convert text to lowercase
    lower_case_text = input_text.lower()
    # Remove commas
    no_comma_text = lower_case_text.replace(",", "")
    # Replace "-" with spaces, and drop apostrophes, semicolons, and periods
    no_dash_text = no_comma_text.replace("-", " ")
    no_dash_text = no_dash_text.replace("'", "")
    no_dash_text = no_dash_text.replace(";", "")
    no_dash_text = no_dash_text.replace(".", "")
    # Collapse repeated whitespace into single spaces, then strip any remaining punctuation
    single_space_text = " ".join(no_dash_text.split())
    single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation))
    # @shehzeen: Added this to handle some common errors in ASR transcripts
    single_space_text = single_space_text.replace("h t t p", "http")
    single_space_text = single_space_text.replace("w w w", "www")
    return single_space_text
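
# A quick illustration of the normalization above (hypothetical input):
#   process_text_for_cer("Hello, World - it's A.I.;")  ->  "hello world its ai"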


def get_speaker_embeddings_from_filepaths(filepaths, speaker_verification_model, device):
    audio_batch = []
    audio_lengths = []
    for filepath in filepaths:
        audio, sr = sf.read(filepath)
        if sr != 16000:
            audio = librosa.core.resample(audio, orig_sr=sr, target_sr=16000)
        audio_tensor = torch.tensor(audio, dtype=torch.float32, device=device)
        audio_batch.append(audio_tensor)
        audio_lengths.append(audio_tensor.size(0))
    batch_audio_lens = torch.tensor(audio_lengths, device=device).long()
    max_audio_len = int(batch_audio_lens.max().item())
    audio_batch = stack_tensors(audio_batch, max_lens=[max_audio_len])
    _, speaker_embeddings = speaker_verification_model.forward(
        input_signal=audio_batch, input_signal_length=batch_audio_lens
    )
    return speaker_embeddings


def transcribe_with_whisper(
    audio_filepath, language, whisper_processor, whisper_model, device, normalizer: Optional[Normalizer] = None
):
    speech_array, sampling_rate = librosa.load(audio_filepath, sr=16000)
    forced_decoder_ids = (
        whisper_processor.get_decoder_prompt_ids(language=language, task="transcribe") if language else None
    )
    inputs = whisper_processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_features
    inputs = inputs.to(device)
    with torch.no_grad():
        predicted_ids = whisper_model.generate(inputs, forced_decoder_ids=forced_decoder_ids)
    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
    result = transcription[0]
    if normalizer is not None:
        result = normalizer.normalize(result)
    return result
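
# Minimal usage sketch for the helper above (hypothetical paths and objects; the
# model names mirror the ones loaded by the classes in this file):
#
#   from transformers import WhisperForConditionalGeneration, WhisperProcessor
#   processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
#   model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3").eval()
#   text = transcribe_with_whisper("sample.wav", "en", processor, model, "cpu", normalizer=None)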