# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shlex
import subprocess
import tempfile
from pathlib import Path

from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.parts.utils.eval_utils import get_hydra_override_from_config
from nemo.utils import logging


def model_name_contains(model_name: str, *keywords) -> bool:
    """
    Check if any of the given keywords appear (case-insensitive) in the model name.

    Args:
        model_name (str): Model name.
        *keywords: Variable length argument list of keywords to check.

    Returns:
        bool: True if any of the keywords are found in the model name, False otherwise.
    """
    model_name_lower = model_name.lower()
    return any(kw.lower() in model_name_lower for kw in keywords)


def _get_model_name(cfg: DictConfig) -> str:
    """Return the model identifier: the stem of ``cfg.model_path`` if set, else ``cfg.pretrained_name``."""
    if cfg.model_path:
        return Path(cfg.model_path).stem
    return cfg.pretrained_name


def run_asr_inference(cfg: DictConfig) -> DictConfig:
    """
    Execute ASR inference based on input mode and parameters.

    Dispatches on ``cfg.inference.mode``:
      * ``offline``: runs ``run_offline_inference``.
      * ``chunked``: runs ``run_chunked_inference``; both ``cfg.inference.total_buffer_in_secs`` and
        ``cfg.inference.chunk_len_in_secs`` must be set and non-empty.
      * ``offline_by_chunked``: runs ``run_chunked_inference`` after filling in defaults
        (``total_buffer_in_secs=22``, ``chunk_len_in_secs=20``) for whichever of the two is missing.
        This simulates offline decoding for models (e.g. Conformer) that may run out of CUDA memory
        on long audio.

    Args:
        cfg (DictConfig): Evaluation config. Exactly one of ``cfg.model_path`` / ``cfg.pretrained_name``
            must be set.

    Returns:
        DictConfig: The config, possibly updated in place (e.g. ``output_filename`` and default
        chunking parameters are filled in).

    Raises:
        ValueError: If the model selection, ``decoder_type`` or ``mode`` is invalid, or chunking
            parameters are missing in ``chunked`` mode.
    """
    # Exactly one of model_path / pretrained_name must be truthy.
    if bool(cfg.model_path) == bool(cfg.pretrained_name):
        raise ValueError("Please specify either cfg.model_path or cfg.pretrained_name!")

    if cfg.inference.decoder_type not in [None, 'ctc', 'rnnt', 'aed']:
        raise ValueError("decoder_type could only be null, ctc, rnnt or aed")

    if cfg.inference.mode == "offline":
        cfg = run_offline_inference(cfg)

    elif cfg.inference.mode == "chunked":
        if (
            "total_buffer_in_secs" not in cfg.inference
            or "chunk_len_in_secs" not in cfg.inference
            or not cfg.inference.total_buffer_in_secs
            or not cfg.inference.chunk_len_in_secs
        ):
            raise ValueError("Please specify both total_buffer_in_secs and chunk_len_in_secs for chunked inference")
        cfg = run_chunked_inference(cfg)

    elif cfg.inference.mode == "offline_by_chunked":
        # When using Conformer to transcribe long audio samples we could encounter CUDA out-of-memory issues.
        # offline_by_chunked simulates offline mode with large chunks; fill in defaults for any missing
        # chunking parameter.
        OmegaConf.set_struct(cfg, True)
        for key, default in (("total_buffer_in_secs", 22), ("chunk_len_in_secs", 20)):
            if key not in cfg.inference or not cfg.inference[key]:
                with open_dict(cfg):
                    cfg.inference[key] = default
                logging.info(
                    f"Does not provide {key} required by {cfg.inference.mode} mode. "
                    f"Using default value {cfg.inference[key]}"
                )
        cfg = run_chunked_inference(cfg)

    else:
        raise ValueError(
            f"inference mode could only be offline, chunked or offline_by_chunked, but got {cfg.inference.mode}"
        )

    return cfg


def _resolve_chunked_script(model_name: str, decoder_type) -> Path:
    """
    Select the chunked-inference example script for a model, validating ``decoder_type`` against it.

    The model family is inferred from keywords in ``model_name`` ("hybrid", "rnnt"/"transducer",
    "ctc", "canary"), checked in that order.

    Raises:
        ValueError: If ``decoder_type`` is incompatible with the model family, the family cannot be
            inferred, or the resolved family is CTC (unsupported for chunked inference here).
    """
    if model_name_contains(model_name, "hybrid"):
        if decoder_type and decoder_type not in ("rnnt", "ctc"):
            raise ValueError(
                f"Hybrid models only support rnnt or ctc decoding! Current decoder_type: {decoder_type}! "
                "Change it to null, rnnt or ctc for hybrid models"
            )
        # By default, use RNNT for hybrid models.
        script_kind = decoder_type or "rnnt"
    elif model_name_contains(model_name, "rnnt", "transducer"):
        if decoder_type and decoder_type != 'rnnt':
            raise ValueError(
                f"rnnt models only support rnnt decoding! Current decoder_type: {decoder_type}! "
                "Change it to null or rnnt for rnnt models"
            )
        script_kind = "rnnt"
    elif model_name_contains(model_name, "ctc"):
        if decoder_type and decoder_type != 'ctc':
            raise ValueError(
                f"ctc models only support ctc decoding! Current decoder_type: {decoder_type}! "
                "Change it to null or ctc for ctc models"
            )
        script_kind = "ctc"
    elif model_name_contains(model_name, "canary"):
        if decoder_type and decoder_type != 'aed':
            raise ValueError(
                f"Canary models only support aed decoding! Current decoder_type: {decoder_type}! "
                "Change it to null or aed for aed models"
            )
        script_kind = "aed"
    else:
        raise ValueError(
            "Please make sure your pretrained_name or model_path contains \n"
            "'hybrid' for EncDecHybridRNNTCTCModel model, \n"
            "'transducer/rnnt' for EncDecRNNTModel model, \n"
            "'ctc' for EncDecCTCModel, or \n"
            "'canary' for EncDecMultiTaskModel."
        )

    if script_kind == "ctc":
        raise ValueError("Evaluation of CTC models with chunked inference is not supported")

    chunked_dir = Path(__file__).parents[2] / "examples" / "asr" / "asr_chunked_inference"
    if script_kind == "rnnt":
        return chunked_dir / "rnnt" / "speech_to_text_buffered_infer_rnnt.py"
    return chunked_dir / "aed" / "speech_to_text_aed_chunked_infer.py"


def run_chunked_inference(cfg: DictConfig) -> DictConfig:
    """
    Run buffered/chunked inference by launching the matching example script in a subprocess.

    If ``cfg.output_filename`` is unset, it is filled in as
    ``{model_name}-{dataset_name}-{mode}B{total_buffer_in_secs}C{chunk_len_in_secs}.json``.

    Args:
        cfg (DictConfig): Evaluation config (see ``run_asr_inference``).

    Returns:
        DictConfig: The config, possibly with ``output_filename`` filled in.

    Raises:
        ValueError: If the model type / decoder type combination is unsupported.
        subprocess.CalledProcessError: If the inference script exits with a non-zero status.
    """
    model_name = _get_model_name(cfg)

    if "output_filename" not in cfg or not cfg.output_filename:
        dataset_name = Path(cfg.test_ds.manifest_filepath).stem
        mode_name = (
            cfg.inference.mode
            + "B"
            + str(cfg.inference.total_buffer_in_secs)
            + "C"
            + str(cfg.inference.chunk_len_in_secs)
        )
        OmegaConf.set_struct(cfg, True)
        with open_dict(cfg):
            cfg.output_filename = f"{model_name}-{dataset_name}-{mode_name}.json"

    script_path = _resolve_chunked_script(model_name, cfg.inference.decoder_type)

    # If you need to change other config such as the decoding strategy, either:
    # 1) change TranscriptionConfig at the top of the executed script (e.g. speech_to_text_buffered_infer_rnnt.py), or
    # 2) append an override such as "decoding.strategy=greedy_batch" to the command below.
    # Arguments are passed as a list (no shell) so paths/values containing spaces or shell
    # metacharacters reach the script intact.
    cmd = [
        "python",
        str(script_path),
        "calculate_wer=False",
        f"model_path={cfg.model_path}",
        f"pretrained_name={cfg.pretrained_name}",
        f"dataset_manifest={cfg.test_ds.manifest_filepath}",
        f"output_filename={cfg.output_filename}",
        f"random_seed={cfg.random_seed}",
        f"batch_size={cfg.test_ds.batch_size}",
        f"++num_workers={cfg.test_ds.num_workers}",
        f"chunk_len_in_secs={cfg.inference.chunk_len_in_secs}",
        f"++total_buffer_in_secs={cfg.inference.total_buffer_in_secs}",
        f"model_stride={cfg.inference.model_stride}",
        f"++timestamps={cfg.inference.timestamps}",
    ]
    subprocess.run(cmd, check=True)
    return cfg


def run_offline_inference(cfg: DictConfig) -> DictConfig:
    """
    Run offline inference by launching ``examples/asr/transcribe_speech.py`` in a subprocess.

    The full config is saved to a temporary YAML file and passed via ``eval_config_yaml``. Entries
    under ``cfg.transcribe_params`` (minus the keys passed explicitly) are forwarded as hydra overrides.
    If ``cfg.output_filename`` is unset, it is filled in as ``{model_name}-{dataset_name}-{mode}.json``.

    Args:
        cfg (DictConfig): Evaluation config (see ``run_asr_inference``).

    Returns:
        DictConfig: The config, possibly with ``output_filename`` filled in.

    Raises:
        subprocess.CalledProcessError: If the transcription script exits with a non-zero status.
    """
    if "output_filename" not in cfg or not cfg.output_filename:
        model_name = _get_model_name(cfg)
        dataset_name = Path(cfg.test_ds.manifest_filepath).stem
        mode_name = cfg.inference.mode

        OmegaConf.set_struct(cfg, True)
        with open_dict(cfg):
            cfg.output_filename = f"{model_name}-{dataset_name}-{mode_name}.json"

    script_path = Path(__file__).parents[2] / "examples" / "asr" / "transcribe_speech.py"

    # Keys to ignore when generating hydra overrides (they are passed explicitly below).
    exclude_keys = [
        'calculate_wer',
        'model_path',
        'pretrained_name',
        'dataset_manifest',
        'output_filename',
        'batch_size',
        'num_workers',
        'random_seed',
        'eval_config_yaml',
        'decoder_type',
    ]
    hydra_overrides = get_hydra_override_from_config(cfg.get("transcribe_params", None), exclude_keys=exclude_keys)
    # The overrides come back as a single space-separated string; split it with shell-like rules
    # since we no longer go through a shell.
    override_args = shlex.split(hydra_overrides) if hydra_overrides else []

    with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
        OmegaConf.save(cfg, f)
        # Make sure the YAML is on disk before the child process opens f.name.
        f.flush()

        # If you need to change other config such as the decoding strategy, either:
        # 1) change TranscriptionConfig at the top of examples/asr/transcribe_speech.py, or
        # 2) add an override such as "rnnt_decoding.strategy=greedy_batch" to the command below.
        cmd = [
            "python",
            str(script_path),
            "calculate_wer=False",
            f"model_path={cfg.model_path}",
            f"pretrained_name={cfg.pretrained_name}",
            f"dataset_manifest={cfg.test_ds.manifest_filepath}",
            f"output_filename={cfg.output_filename}",
            f"batch_size={cfg.test_ds.batch_size}",
            f"num_workers={cfg.test_ds.num_workers}",
            f"random_seed={cfg.random_seed}",
            f"eval_config_yaml={f.name}",
            f"decoder_type={cfg.inference.decoder_type}",
            *override_args,
        ]
        subprocess.run(cmd, check=True)

    return cfg


# Raw accumulators kept per metadata class / slot before rates are derived.
_COUNT_KEYS = ('samples', 'tokens', 'errors', 'inss', 'dels', 'subs')


def _new_error_counts() -> dict:
    """Return a zero-initialised accumulator with one entry per key in ``_COUNT_KEYS``."""
    return {key: 0 for key in _COUNT_KEYS}


def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is zero (a group with no tokens)."""
    return numerator / denominator if denominator else 0.0


def _counts_to_rates(counts: dict, eval_metric: str) -> dict:
    """Convert raw accumulated counts into the reported per-group metrics (token-weighted rates)."""
    tokens = counts['tokens']
    return {
        'samples': counts['samples'],
        eval_metric: _safe_ratio(counts['errors'], tokens),
        'tokens': tokens,
        'ins_rate': _safe_ratio(counts['inss'], tokens),
        'del_rate': _safe_ratio(counts['dels'], tokens),
        'sub_rate': _safe_ratio(counts['subs'], tokens),
    }


def cal_target_metadata_wer(
    manifest: str,
    target: str,
    meta_cfg: DictConfig,
    eval_metric: str = "wer",
) -> dict:
    """
    Calculate the number of samples (samples), the number of words/characters/tokens (tokens), wer/cer,
    insertion error rate (ins_rate), deletion error rate (del_rate) and substitution error rate (sub_rate)
    for each group/slot of a target metadata field.

    A group is a distinct value of the metadata field (e.g. [female, male]); a slot is a user-defined
    bucket of values (e.g. [0-2s, 2-5s, >5s audios]). Per-sample rates are aggregated weighted by each
    sample's ``tokens`` count. A group with zero total tokens reports 0.0 for every rate.

    Args:
        manifest (str): Filepath of the generated manifest (JSON lines) containing the prediction and
            eval result for each sample. Each sample that has the ``target`` field must also have
            ``tokens``, ``eval_metric``, ``ins_rate``, ``del_rate`` and ``sub_rate`` fields.
        target (str): Target metadata field, such as 'duration', 'speaker', 'emotion', etc. Samples
            without this field are skipped.
        meta_cfg (DictConfig): Config for grouping the target metadata. Reads ``save_wer_per_class`` and,
            optionally, ``slot``: a list of slots, each either a numeric ``[low, high)`` range or a list
            of string values.
        eval_metric (str): Evaluation metric. Currently supports 'wer' and 'cer'.

    Returns:
        dict or None: If ``meta_cfg.save_wer_per_class`` is true, results per class (plus per slot, if
        slots are configured); otherwise results per slot if slots are configured; otherwise None.
        Also None if no sample in the manifest has the ``target`` field.

    Raises:
        ValueError: If ``eval_metric`` is unsupported, or a slot's first element is neither numeric
            nor a string.
    """
    if eval_metric not in ['wer', 'cer']:
        raise ValueError(
            "Currently support wer and cer as eval_metric. Please implement it in cal_target_metadata_wer if using different eval_metric"
        )

    counts_per_class = {}
    with open(manifest, 'r', encoding='utf-8') as fp:
        for line in fp:
            if not line.strip():
                continue
            sample = json.loads(line)
            if target not in sample:
                continue
            target_class = sample[target]
            if target_class not in counts_per_class:
                counts_per_class[target_class] = _new_error_counts()
            counts = counts_per_class[target_class]
            tokens = sample["tokens"]
            counts['samples'] += 1
            counts['tokens'] += tokens
            # Per-sample rates are converted back to absolute counts so groups can be re-aggregated.
            counts['errors'] += tokens * sample[eval_metric]
            counts['inss'] += tokens * sample["ins_rate"]
            counts['dels'] += tokens * sample["del_rate"]
            counts['subs'] += tokens * sample["sub_rate"]

    if not counts_per_class:
        logging.info(f"metadata '{target}' does not present in manifest. Skipping! ")
        return None

    res_wer_per_class = {
        target_class: _counts_to_rates(counts, eval_metric) for target_class, counts in counts_per_class.items()
    }

    has_slots = 'slot' in meta_cfg and meta_cfg.slot
    slot_counts = {}
    if has_slots:
        for target_class, counts in counts_per_class.items():
            # Each class is assigned to the first slot that matches it.
            for s in meta_cfg.slot:
                if isinstance(s[0], (float, int)):
                    matched = s[0] <= target_class < s[1]
                elif isinstance(s[0], str):
                    matched = target_class in s
                else:
                    raise ValueError("Current only support target metadata belongs to numeric or string ")
                if matched:
                    slot_key = "slot-" + ",".join(str(i) for i in s)
                    if slot_key not in slot_counts:
                        slot_counts[slot_key] = _new_error_counts()
                    for key in _COUNT_KEYS:
                        slot_counts[slot_key][key] += counts[key]
                    break

    slot_wer = {slot_key: _counts_to_rates(counts, eval_metric) for slot_key, counts in slot_counts.items()}
    res_wer_per_class.update(slot_wer)

    if meta_cfg.save_wer_per_class:
        return res_wer_per_class
    if has_slots:
        return slot_wer
    return None