# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass, field, is_dataclass
from typing import List, Optional, Tuple, Union

import lightning.pytorch as pl
import numpy as np
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.asr.parts.utils.transcribe_utils import (
    compute_output_filename,
    get_inference_dtype,
    prepare_audio_data,
    restore_transcription_order,
    setup_model,
    write_transcription,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.timers import SimpleTimer

"""
Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data.

# Arguments
  model_path: path to .nemo ASR checkpoint
  pretrained_name: name of pretrained ASR model (from NGC registry)
  audio_dir: path to directory with audio files
  dataset_manifest: path to dataset JSON manifest file (in NeMo format)
  compute_langs: Bool to request language ID information (if the model supports it)
  timestamps: Bool to request greedy time stamp information (if the model supports it) by default None

  (Optionally: You can limit the type of timestamp computations using below overrides)
  ctc_decoding.ctc_timestamp_type="all"  # (default all, can be [all, char, word, segment])
  rnnt_decoding.rnnt_timestamp_type="all"  # (default all, can be [all, char, word, segment])

  output_filename: Output filename where the transcriptions will be written
  batch_size: batch size during inference
  presort_manifest: sorts the provided manifest by audio length for faster inference (default: True)

  cuda: Optional int to enable or disable execution of model on certain CUDA device.
        A negative value forces CPU-only inference.
  allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available
  amp: Bool to decide if Automatic Mixed Precision should be used during inference
  audio_type: Str filetype of the audio. Supported = wav, flac, mp3

  overwrite_transcripts: Bool which when set allows repeated transcriptions to overwrite previous results.

  ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values.
  rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values.

  calculate_wer: Bool to decide whether to calculate wer/cer at end of this script
  clean_groundtruth_text: Bool to clean groundtruth text
  langid: Str used for convert_num_to_words during groundtruth cleaning
  use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER)

  calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset.

# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
append_pred - optional. Allows you to add more than one prediction to an existing .json
pred_name_postfix - optional. The name you want to be written for the current model
Results are returned in a JSON manifest file.

python transcribe_speech.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="" \
    dataset_manifest="" \
    output_filename="" \
    clean_groundtruth_text=True \
    langid='en' \
    batch_size=32 \
    timestamps=False \
    compute_langs=False \
    cuda=0 \
    amp=True \
    append_pred=False \
    pred_name_postfix=""
"""


@dataclass
class ModelChangeConfig:
    """
    Sub-config for changes specific to the Conformer Encoder
    """

    conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig)


@dataclass
class TranscriptionConfig:
    """
    Transcription Configuration for audio to text transcription.
    """

    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
    channel_selector: Optional[Union[int, str]] = (
        None  # Used to select a single channel from multichannel audio, or use average across channels
    )
    audio_key: str = 'audio_filepath'  # Used to override the default audio key in dataset_manifest
    eval_config_yaml: Optional[str] = None  # Path to a yaml file of config of evaluation
    presort_manifest: bool = True  # Significant inference speedup on short-form data due to padding reduction

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # Sets mode of work, if True it will add new field transcriptions.
    pred_name_postfix: Optional[str] = None  # If you need to use another model name, rather than standard one.
    random_seed: Optional[int] = None  # seed number going to be used in seed_everything()

    # Set to True to output greedy timestamp information (only supported models) and returns full alignment hypotheses
    timestamps: Optional[bool] = None

    # Set to True to return hypotheses instead of text from the transcribe function
    return_hypotheses: bool = False

    # Set to True to output language ID information
    compute_langs: bool = False

    # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
    # device anyway, and do inference on CPU only if CUDA device is not found.
    # If `cuda` is a negative number, inference will be on CPU only.
    cuda: Optional[int] = None
    allow_mps: bool = False  # allow to select MPS device (Apple Silicon M-series GPU)
    amp: bool = False
    amp_dtype: str = "float16"  # can be set to "float16" or "bfloat16" when using amp
    compute_dtype: Optional[str] = (
        None  # "float32", "bfloat16" or "float16"; if None (default): bfloat16 if available else float32
    )
    matmul_precision: str = "high"  # Literal["highest", "high", "medium"]
    audio_type: str = "wav"

    # Recompute model transcription, even if the output folder exists with scores.
    overwrite_transcripts: bool = True

    # Decoding strategy for CTC models
    ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig)

    # Decoding strategy for RNNT models
    # enable CUDA graphs for transcription
    rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1))

    # Decoding strategy for AED models
    multitask_decoding: MultiTaskDecodingConfig = field(default_factory=MultiTaskDecodingConfig)
    # Prompt slots for prompted models, e.g. Canary-1B. Examples of acceptable prompt inputs:
    # Implicit single-turn assuming default role='user' (works with Canary-1B)
    #  +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes
    # Explicit single-turn prompt:
    #  +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es
    #  +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes
    # Explicit multi-turn prompt:
    #  +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]'
    prompt: dict = field(default_factory=dict)

    # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
    decoder_type: Optional[str] = None
    # att_context_size can be set for cache-aware streaming models with multiple look-aheads
    att_context_size: Optional[list] = None

    # Use this for model-specific changes before transcription
    model_change: ModelChangeConfig = field(default_factory=ModelChangeConfig)

    # Config for word / character error rate calculation
    calculate_wer: bool = True
    clean_groundtruth_text: bool = False
    langid: str = "en"  # specify this for convert_num_to_words step in groundtruth cleaning
    use_cer: bool = False

    # can be set to True to return list of transcriptions instead of the config
    # if True, will also skip writing anything to the output file
    return_transcriptions: bool = False

    # key for groundtruth text in manifest
    gt_text_attr_name: str = "text"
    gt_lang_attr_name: str = "lang"

    extract_nbest: bool = False  # Extract n-best hypotheses from the model

    calculate_rtfx: bool = False
    warmup_steps: int = 0  # by default - no warmup
    run_steps: int = 1  # by default - single run


def _select_device(cfg) -> Tuple[Union[int, List[int]], str, torch.device]:
    """
    Resolve where inference runs from `cfg.cuda` and `cfg.allow_mps`.

    Returns:
        A `(devices, accelerator, map_location)` triple; the first two are passed to `pl.Trainer`,
        the last one is passed to `setup_model` when loading the checkpoint.
    """
    if cfg.cuda is None:
        # No explicit request: prefer CUDA, then (opt-in) MPS, then CPU.
        if torch.cuda.is_available():
            return [0], 'gpu', torch.device('cuda:0')  # use 0th CUDA device
        if cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            logging.warning(
                "MPS device (Apple Silicon M-series GPU) support is experimental."
                " Env variable `PYTORCH_ENABLE_MPS_FALLBACK=1` should be set in most cases to avoid failures."
            )
            return [0], 'mps', torch.device('mps')
        return 1, 'cpu', torch.device('cpu')

    if cfg.cuda < 0:
        # Documented contract of `TranscriptionConfig.cuda`: a negative index means CPU-only inference.
        return 1, 'cpu', torch.device('cpu')

    return [cfg.cuda], 'gpu', torch.device(f'cuda:{cfg.cuda}')


def _configure_decoding(asr_model, cfg) -> None:
    """
    Validate `cfg.decoder_type` against the model class and apply the matching decoding config.

    Mutates `cfg` (may force `return_hypotheses`, fills `cfg.decoding`) and `asr_model`
    (via `change_decoding_strategy`).

    Raises:
        ValueError: if the requested decoder type is not supported by the model class, or if
            `compute_langs` is requested with CTC decoding.
    """
    # Check whether model and decoder type match
    if isinstance(asr_model, EncDecCTCModel):
        if cfg.decoder_type and cfg.decoder_type != 'ctc':
            raise ValueError('CTC model only support ctc decoding!')
    elif isinstance(asr_model, EncDecHybridRNNTCTCModel):
        if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']:
            raise ValueError('Hybrid model only support ctc or rnnt decoding!')
    elif isinstance(asr_model, EncDecRNNTModel):
        if cfg.decoder_type and cfg.decoder_type != 'rnnt':
            raise ValueError('RNNT model only support rnnt decoding!')

    # NOTE(review): this is gated on `decoder_type`, not on `att_context_size`, so the encoder may be
    # called with `att_context_size=None` - confirm that this is intended.
    if cfg.decoder_type and hasattr(asr_model.encoder, 'set_default_att_context_size'):
        asr_model.encoder.set_default_att_context_size(cfg.att_context_size)

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'):
        if isinstance(asr_model.decoding, MultiTaskDecoding):
            cfg.multitask_decoding.compute_langs = cfg.compute_langs
            if cfg.extract_nbest:
                cfg.multitask_decoding.beam.return_best_hypothesis = False
                cfg.return_hypotheses = True
            asr_model.change_decoding_strategy(cfg.multitask_decoding)
        elif cfg.decoder_type is not None:
            # TODO: Support compute_langs in CTC eventually
            if cfg.compute_langs and cfg.decoder_type == 'ctc':
                raise ValueError("CTC models do not support `compute_langs` at the moment")

            decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
            if cfg.extract_nbest:
                decoding_cfg.beam.return_best_hypothesis = False
                cfg.return_hypotheses = True
            if 'compute_langs' in decoding_cfg:
                decoding_cfg.compute_langs = cfg.compute_langs
            if hasattr(asr_model, 'cur_decoder'):
                asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)
            else:
                asr_model.change_decoding_strategy(decoding_cfg)

        # Check if ctc or rnnt model
        elif hasattr(asr_model, 'joint'):  # RNNT model
            if cfg.extract_nbest:
                cfg.rnnt_decoding.beam.return_best_hypothesis = False
                cfg.return_hypotheses = True
            cfg.rnnt_decoding.fused_batch_size = -1
            cfg.rnnt_decoding.compute_langs = cfg.compute_langs
            asr_model.change_decoding_strategy(cfg.rnnt_decoding)
        else:
            if cfg.compute_langs:
                raise ValueError("CTC models do not support `compute_langs` at the moment.")
            if cfg.extract_nbest:
                cfg.ctc_decoding.beam.return_best_hypothesis = False
                cfg.return_hypotheses = True
            asr_model.change_decoding_strategy(cfg.ctc_decoding)

    # Setup decoding config based on model type and decoder_type
    with open_dict(cfg):
        if isinstance(asr_model, EncDecCTCModel) or (
            isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc"
        ):
            cfg.decoding = cfg.ctc_decoding
        elif isinstance(asr_model.decoding, MultiTaskDecoding):
            cfg.decoding = cfg.multitask_decoding
        else:
            cfg.decoding = cfg.rnnt_decoding


def _total_manifest_duration(manifest_path: str) -> float:
    """
    Sum the 'duration' field over all entries of a JSON-lines manifest.

    Raises:
        ValueError: if any non-empty manifest line lacks a 'duration' field.
    """
    total_duration = 0.0
    with open(manifest_path, "rt", encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                # tolerate blank (e.g. trailing) lines instead of failing in json.loads
                continue
            item = json.loads(line)
            if "duration" not in item:
                raise ValueError(
                    f"Requested calculate_rtfx=True, but line {line.strip()} in manifest {manifest_path} "
                    "lacks a 'duration' field."
                )
            total_duration += item["duration"]
    return total_duration


def _transcribe_and_report(cfg, asr_model, model_name, filepaths, augmentor, amp_dtype, compute_langs):
    """
    Run (optionally repeated and timed) transcription, then write predictions and report WER / RTFx.

    Returns:
        The list of transcriptions if `cfg.return_transcriptions` is set, otherwise the config with the
        resolved `output_filename`.

    Raises:
        ValueError: if RTFx is requested without a dataset manifest, or if no step would be executed.
    """
    # Compute output filename
    cfg = compute_output_filename(cfg, model_name)

    # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
    if not cfg.return_transcriptions and not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    total_duration = 0.0
    if cfg.calculate_rtfx:
        if cfg.dataset_manifest is None:
            # durations are only available from a manifest; audio_dir input cannot provide them
            raise ValueError("calculate_rtfx=True requires `dataset_manifest` with a 'duration' field per entry.")
        total_duration = _total_manifest_duration(cfg.dataset_manifest)

        if cfg.warmup_steps == 0:
            logging.warning(
                "RTFx measurement enabled, but warmup_steps=0. "
                "At least one warmup step is recommended to measure RTFx"
            )

    if cfg.warmup_steps + cfg.run_steps < 1:
        raise ValueError("At least one transcription step is required: warmup_steps + run_steps must be >= 1.")

    timer = SimpleTimer()
    model_measurements = []
    with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=amp_dtype, enabled=cfg.amp):
        with torch.no_grad():
            override_cfg = asr_model.get_transcribe_config()
            override_cfg.batch_size = cfg.batch_size
            override_cfg.num_workers = cfg.num_workers
            override_cfg.return_hypotheses = cfg.return_hypotheses
            override_cfg.channel_selector = cfg.channel_selector
            override_cfg.augmentor = augmentor
            override_cfg.text_field = cfg.gt_text_attr_name
            override_cfg.lang_field = cfg.gt_lang_attr_name
            override_cfg.timestamps = cfg.timestamps
            if hasattr(override_cfg, "prompt"):
                override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt))

            model_device = next(asr_model.parameters()).device
            for run_step in range(cfg.warmup_steps + cfg.run_steps):
                if run_step < cfg.warmup_steps:
                    logging.info(f"Running warmup step {run_step}")
                # reset timer
                timer.reset()
                timer.start(device=model_device)
                # call transcribe
                transcriptions = asr_model.transcribe(
                    audio=filepaths,
                    override_config=override_cfg,
                    timestamps=cfg.timestamps,
                )
                # stop timer, log time
                timer.stop(device=model_device)
                logging.info(f"Model time for iteration {run_step}: {timer.total_sec():.3f}")
                # warmup iterations are excluded from the reported statistics
                if run_step >= cfg.warmup_steps:
                    model_measurements.append(timer.total_sec())

    model_measurements_np = np.asarray(model_measurements)
    logging.info(
        f"Model time avg: {model_measurements_np.mean():.3f}"
        + (f" (std: {model_measurements_np.std():.3f})" if cfg.run_steps > 1 else "")
    )

    if cfg.dataset_manifest is not None:
        logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
        if cfg.presort_manifest:
            transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions)
    else:
        logging.info(f"Finished transcribing {len(filepaths)} files !")
    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # if transcriptions form a tuple of (best_hypotheses, all_hypotheses)
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        if cfg.extract_nbest:
            # extract all hypotheses if exists
            transcriptions = transcriptions[1]
        else:
            # extract just best hypothesis
            transcriptions = transcriptions[0]

    if cfg.return_transcriptions:
        return transcriptions

    # write audio transcriptions
    output_filename, pred_text_attr_name = write_transcription(
        transcriptions,
        cfg,
        model_name,
        filepaths=filepaths,
        compute_langs=compute_langs,
        timestamps=cfg.timestamps,
    )
    logging.info(f"Finished writing predictions to {output_filename}!")

    if cfg.calculate_wer:
        output_manifest_w_wer, total_res, _ = cal_write_wer(
            pred_manifest=output_filename,
            gt_text_attr_name=cfg.gt_text_attr_name,
            pred_text_attr_name=pred_text_attr_name,
            clean_groundtruth_text=cfg.clean_groundtruth_text,
            langid=cfg.langid,
            use_cer=cfg.use_cer,
            output_filename=None,
        )
        if output_manifest_w_wer:
            logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
            logging.info(f"{total_res}")

    if cfg.calculate_rtfx:
        rtfx_measurements = total_duration / model_measurements_np
        logging.info(
            f"Model RTFx on the dataset: {rtfx_measurements.mean():.3f}"
            + (f" (std: {rtfx_measurements.std():.3f})" if cfg.run_steps > 1 else "")
        )

    return cfg


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
    """
    Transcribes the input audio and can be used to infer with Encoder-Decoder models.

    Args:
        cfg: transcription configuration (see `TranscriptionConfig`).

    Returns:
        The transcriptions if `cfg.return_transcriptions` is True, otherwise the (updated) config.

    Raises:
        ValueError: on missing model / data arguments, on `amp` combined with a non-float32
            `compute_dtype`, on a decoder type unsupported by the model, or on invalid RTFx settings.
    """
    # Convert first: a plain dataclass instance cannot be iterated in the normalisation loop below.
    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    # The string 'None' (e.g. from the command line) is treated as a real None
    for key in cfg:
        cfg[key] = None if cfg[key] == 'None' else cfg[key]

    # `is not None` so that an explicit seed of 0 is honoured
    if cfg.random_seed is not None:
        pl.seed_everything(cfg.random_seed)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # Load augmentor from external yaml file which contains eval info, could be extended to other
    # features such as VAD, P&C
    augmentor = None
    if cfg.eval_config_yaml:
        eval_config = OmegaConf.load(cfg.eval_config_yaml)
        augmentor = eval_config.test_ds.get("augmentor")
        logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ")

    # setup GPU
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    device, accelerator, map_location = _select_device(cfg)
    logging.info(f"Inference will be done on device: {map_location}")

    asr_model, model_name = setup_model(cfg, map_location)

    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    if (cfg.compute_dtype is not None and cfg.compute_dtype != "float32") and cfg.amp:
        raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32")

    amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

    compute_dtype: torch.dtype
    if cfg.amp:
        # with amp model weights required to be in float32
        compute_dtype = torch.float32
    else:
        compute_dtype = get_inference_dtype(compute_dtype=cfg.compute_dtype, device=map_location)

    asr_model.to(compute_dtype)

    # we will adjust this flag if the model does not support it
    compute_langs = cfg.compute_langs

    if cfg.timestamps:
        cfg.return_hypotheses = True

    _configure_decoding(asr_model, cfg)

    filepaths, sorted_manifest_path = prepare_audio_data(cfg)
    try:
        # when a sorted copy of the manifest was produced, transcribe from it
        if sorted_manifest_path is not None:
            filepaths = sorted_manifest_path
        return _transcribe_and_report(cfg, asr_model, model_name, filepaths, augmentor, amp_dtype, compute_langs)
    finally:
        # clean-up: remove the sorted manifest on every exit path (early return or error included)
        if sorted_manifest_path is not None and os.path.exists(sorted_manifest_path):
            os.unlink(sorted_manifest_path)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter