# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import subprocess
import tempfile
from pathlib import Path

from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.parts.utils.eval_utils import get_hydra_override_from_config
from nemo.utils import logging


def model_name_contains(model_name: str, *keywords) -> bool:
    """
    Check if any of the given keywords appear (case-insensitive) in the model name.

    Args:
        model_name (str): Model name.
        *keywords: Variable length argument list of keywords to check.

    Returns:
        bool: True if any of the keywords are found in the model name, False otherwise.
    """
    model_name_lower = model_name.lower()
    return any(kw.lower() in model_name_lower for kw in keywords)
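
# A quick illustration of model_name_contains; the model names below are hypothetical:
#   model_name_contains("stt_en_fastconformer_hybrid_large", "hybrid")       -> True
#   model_name_contains("stt_en_conformer_ctc_large", "rnnt", "transducer")  -> False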


def run_asr_inference(cfg: DictConfig) -> DictConfig:
    """
    Execute ASR inference based on the input mode and parameters.
    """
    if (cfg.model_path and cfg.pretrained_name) or (not cfg.model_path and not cfg.pretrained_name):
        raise ValueError("Please specify exactly one of cfg.model_path and cfg.pretrained_name!")

    if cfg.inference.decoder_type not in [None, 'ctc', 'rnnt', 'aed']:
        raise ValueError("decoder_type can only be null, ctc, rnnt or aed")
| if cfg.inference.mode == "offline": | |
| cfg = run_offline_inference(cfg) | |
| elif cfg.inference.mode == "chunked": | |
| if ( | |
| "total_buffer_in_secs" not in cfg.inference | |
| or "chunk_len_in_secs" not in cfg.inference | |
| or not cfg.inference.total_buffer_in_secs | |
| or not cfg.inference.chunk_len_in_secs | |
| ): | |
| raise ValueError(f"Please specify both total_buffer_in_secs and chunk_len_in_secs for chunked inference") | |
| cfg = run_chunked_inference(cfg) | |
| elif cfg.inference.mode == "offline_by_chunked": | |
| # When use Conformer to transcribe long audio sample, we could probably encounter CUDA out of memory issue. | |
| # Here we use offline_by_chunked mode to simulate offline mode for Conformer. | |
| # And we specify default total_buffer_in_secs=22 and chunk_len_in_secs=20 to avoid above problem. | |
| OmegaConf.set_struct(cfg, True) | |
| if 'total_buffer_in_secs' not in cfg.inference or not cfg.inference.total_buffer_in_secs: | |
| with open_dict(cfg): | |
| cfg.inference.total_buffer_in_secs = 22 | |
| logging.info( | |
| f"Does not provide total_buffer_in_secs required by {cfg.inference.mode} mode. Using default value {cfg.inference.total_buffer_in_secs}" | |
| ) | |
| if 'chunk_len_in_secs' not in cfg.inference or not cfg.inference.chunk_len_in_secs: | |
| with open_dict(cfg): | |
| cfg.inference.chunk_len_in_secs = 20 | |
| logging.info( | |
| f"Does not provide total_buffer_in_secs required by {cfg.inference.mode} mode. Using default value {cfg.inference.chunk_len_in_secs}" | |
| ) | |
| cfg = run_chunked_inference(cfg) | |
| else: | |
| raise ValueError(f"inference could only be offline or chunked, but got {cfg.inference.mode}") | |
| return cfg | |
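
# A minimal config sketch for run_asr_inference. The field names come from the code in
# this file; the concrete values are illustrative assumptions, not defaults:
#
#   model_path: null
#   pretrained_name: stt_en_conformer_transducer_large
#   random_seed: 42
#   test_ds:
#     manifest_filepath: /data/test_manifest.json
#     batch_size: 32
#     num_workers: 4
#   inference:
#     mode: offline_by_chunked  # offline | chunked | offline_by_chunked
#     decoder_type: null        # null | ctc | rnnt | aed
#     total_buffer_in_secs: 22
#     chunk_len_in_secs: 20
#     model_stride: 4
#     timestamps: false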


def run_chunked_inference(cfg: DictConfig) -> DictConfig:
    """
    Run chunked (buffered) inference by invoking the example script matching the model type.
    """
    if cfg.model_path:
        model_name = Path(cfg.model_path).stem
    else:
        model_name = cfg.pretrained_name

    if "output_filename" not in cfg or not cfg.output_filename:
        dataset_name = Path(cfg.test_ds.manifest_filepath).stem
        mode_name = (
            cfg.inference.mode
            + "B"
            + str(cfg.inference.total_buffer_in_secs)
            + "C"
            + str(cfg.inference.chunk_len_in_secs)
        )
        OmegaConf.set_struct(cfg, True)
        with open_dict(cfg):
            cfg.output_filename = f"{model_name}-{dataset_name}-{mode_name}.json"

    use_ctc_script = False
    use_rnnt_script = False
    use_aed_script = False
    # hybrid model
    if model_name_contains(model_name, "hybrid"):
        if cfg.inference.decoder_type:
            if cfg.inference.decoder_type == "rnnt":
                use_rnnt_script = True
            elif cfg.inference.decoder_type == "ctc":
                use_ctc_script = True
            else:
                raise ValueError(
                    f"Hybrid models only support rnnt or ctc decoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null, rnnt or ctc for hybrid models"
                )
        else:
            # By default, use RNNT for hybrid models
            use_rnnt_script = True
    # rnnt model
    elif model_name_contains(model_name, "rnnt", "transducer"):
        if cfg.inference.decoder_type and cfg.inference.decoder_type != 'rnnt':
            raise ValueError(
                f"rnnt models only support rnnt decoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or rnnt for rnnt models"
            )
        use_rnnt_script = True
    # ctc model
    elif model_name_contains(model_name, "ctc"):
        if cfg.inference.decoder_type and cfg.inference.decoder_type != 'ctc':
            raise ValueError(
                f"ctc models only support ctc decoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or ctc for ctc models"
            )
        use_ctc_script = True
    # aed model
    elif model_name_contains(model_name, "canary"):
        if cfg.inference.decoder_type and cfg.inference.decoder_type != 'aed':
            raise ValueError(
                f"Canary models only support aed decoding! Current decoder_type: {cfg.inference.decoder_type}! Change it to null or aed for Canary models"
            )
        use_aed_script = True
    else:
        raise ValueError(
            "Please make sure your pretrained_name or model_path contains \n\
            'hybrid' for EncDecHybridRNNTCTCModel, \n\
            'transducer/rnnt' for EncDecRNNTModel, \n\
            'ctc' for EncDecCTCModel, or \n\
            'canary' for EncDecMultiTaskModel."
        )
    script_path = None
    if use_rnnt_script:
        script_path = (
            Path(__file__).parents[2]
            / "examples"
            / "asr"
            / "asr_chunked_inference"
            / "rnnt"
            / "speech_to_text_buffered_infer_rnnt.py"
        )
    elif use_aed_script:
        script_path = (
            Path(__file__).parents[2]
            / "examples"
            / "asr"
            / "asr_chunked_inference"
            / "aed"
            / "speech_to_text_aed_chunked_infer.py"
        )
    elif use_ctc_script:
        raise ValueError("Evaluation of CTC models with chunked inference is not supported")
    else:
        raise ValueError(f"Unsupported model: {model_name}")
    # To change other configs such as the decoding strategy, you can either:
    # 1) modify TranscriptionConfig at the top of the executed script (e.g. speech_to_text_buffered_infer_rnnt.py), or
    # 2) append an override such as "decoding.strategy=greedy_batch" to the command below.
    base_cmd = f"python {script_path} \
        calculate_wer=False \
        model_path={cfg.model_path} \
        pretrained_name={cfg.pretrained_name} \
        dataset_manifest={cfg.test_ds.manifest_filepath} \
        output_filename={cfg.output_filename} \
        random_seed={cfg.random_seed} \
        batch_size={cfg.test_ds.batch_size} \
        ++num_workers={cfg.test_ds.num_workers} \
        chunk_len_in_secs={cfg.inference.chunk_len_in_secs} \
        ++total_buffer_in_secs={cfg.inference.total_buffer_in_secs} \
        model_stride={cfg.inference.model_stride} \
        ++timestamps={cfg.inference.timestamps}"

    subprocess.run(
        base_cmd,
        shell=True,
        check=True,
    )
    return cfg


def run_offline_inference(cfg: DictConfig) -> DictConfig:
    """
    Run offline inference by invoking examples/asr/transcribe_speech.py.
    """
    if "output_filename" not in cfg or not cfg.output_filename:
        if cfg.model_path:
            model_name = Path(cfg.model_path).stem
        else:
            model_name = cfg.pretrained_name
        dataset_name = Path(cfg.test_ds.manifest_filepath).stem
        mode_name = cfg.inference.mode
        OmegaConf.set_struct(cfg, True)
        with open_dict(cfg):
            cfg.output_filename = f"{model_name}-{dataset_name}-{mode_name}.json"

    with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
        OmegaConf.save(cfg, f)
        f.seek(0)  # reset file pointer
        script_path = Path(__file__).parents[2] / "examples" / "asr" / "transcribe_speech.py"
        # some keys to ignore when generating hydra overrides
        exclude_keys = [
            'calculate_wer',
            'model_path',
            'pretrained_name',
            'dataset_manifest',
            'output_filename',
            'batch_size',
            'num_workers',
            'random_seed',
            'eval_config_yaml',
            'decoder_type',
        ]
        hydra_overrides = get_hydra_override_from_config(cfg.get("transcribe_params", None), exclude_keys=exclude_keys)

        # To change other configs such as the decoding strategy, you can either:
        # 1) modify TranscriptionConfig at the top of the executed script (transcribe_speech.py in examples/asr), or
        # 2) append an override such as "rnnt_decoding.strategy=greedy_batch" to the command below.
        subprocess.run(
            f"python {script_path} "
            f"calculate_wer=False "
            f"model_path={cfg.model_path} "
            f"pretrained_name={cfg.pretrained_name} "
            f"dataset_manifest={cfg.test_ds.manifest_filepath} "
            f"output_filename={cfg.output_filename} "
            f"batch_size={cfg.test_ds.batch_size} "
            f"num_workers={cfg.test_ds.num_workers} "
            f"random_seed={cfg.random_seed} "
            f"eval_config_yaml={f.name} "
            f"decoder_type={cfg.inference.decoder_type} {hydra_overrides}",
            shell=True,
            check=True,
        )

    return cfg
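
# A sketch of how transcribe_params flows into the command above (the exact conversion
# lives in get_hydra_override_from_config; the override below is a hypothetical example):
#   transcribe_params entries not in exclude_keys become hydra overrides, so a
#   decoding-strategy entry would append something like
#   "rnnt_decoding.strategy=greedy_batch" to the transcribe_speech.py invocation.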


def cal_target_metadata_wer(
    manifest: str,
    target: str,
    meta_cfg: DictConfig,
    eval_metric: str = "wer",
) -> dict:
    """
    Calculate the number of samples (samples), number of words/characters/tokens (tokens),
    wer/cer, insertion error rate (ins_rate), deletion error rate (del_rate), and substitution
    error rate (sub_rate) for each group/slot of the target metadata.
    A group could be categorical, e.g. [female, male], or a slot group, e.g. [0-2s, 2-5s, >5s audios].

    Args:
        manifest (str): Filepath of the generated manifest containing the prediction and eval result for each sample.
        target (str): Target metadata, such as 'duration', 'speaker' or 'emotion'.
            Evaluated only if the field is present in the manifest.
        meta_cfg (DictConfig): Config for calculating the group eval_metric of the target metadata.
        eval_metric (str): Evaluation metric. Currently supports 'wer' and 'cer'.

    Returns:
        ret (dict): Dictionary containing all results regarding the target metadata.
    """
    if eval_metric not in ['wer', 'cer']:
        raise ValueError(
            "Currently only wer and cer are supported as eval_metric. Please extend cal_target_metadata_wer if a different eval_metric is needed"
        )
    wer_per_class = {}
    with open(manifest, 'r') as fp:
        for line in fp:
            sample = json.loads(line)
            if target in sample:
                target_class = sample[target]
                if target_class not in wer_per_class:
                    wer_per_class[target_class] = {
                        'samples': 0,
                        'tokens': 0,
                        "errors": 0,
                        "inss": 0,
                        "dels": 0,
                        "subs": 0,
                    }
                wer_per_class[target_class]['samples'] += 1
                tokens = sample["tokens"]
                wer_per_class[target_class]["tokens"] += tokens
                wer_per_class[target_class]["errors"] += tokens * sample[eval_metric]
                wer_per_class[target_class]["inss"] += tokens * sample["ins_rate"]
                wer_per_class[target_class]["dels"] += tokens * sample["del_rate"]
                wer_per_class[target_class]["subs"] += tokens * sample["sub_rate"]

    if len(wer_per_class) > 0:
        res_wer_per_class = {}
        for target_class in wer_per_class:
            res_wer_per_class[target_class] = {}
            res_wer_per_class[target_class]["samples"] = wer_per_class[target_class]["samples"]
            res_wer_per_class[target_class][eval_metric] = (
                wer_per_class[target_class]["errors"] / wer_per_class[target_class]["tokens"]
            )
            res_wer_per_class[target_class]["tokens"] = wer_per_class[target_class]["tokens"]
            res_wer_per_class[target_class]["ins_rate"] = (
                wer_per_class[target_class]["inss"] / wer_per_class[target_class]["tokens"]
            )
            res_wer_per_class[target_class]["del_rate"] = (
                wer_per_class[target_class]["dels"] / wer_per_class[target_class]["tokens"]
            )
            res_wer_per_class[target_class]["sub_rate"] = (
                wer_per_class[target_class]["subs"] / wer_per_class[target_class]["tokens"]
            )
    else:
        logging.info(f"metadata '{target}' is not present in the manifest. Skipping!")
        return None
    values = ['samples', 'tokens', 'errors', 'inss', 'dels', 'subs']
    slot_wer = {}
    if 'slot' in meta_cfg and meta_cfg.slot:
        for target_class in wer_per_class:
            for s in meta_cfg.slot:
                if isinstance(s[0], float) or isinstance(s[0], int):
                    if s[0] <= target_class < s[1]:
                        slot_key = "slot-" + ",".join(str(i) for i in s)
                        if slot_key not in slot_wer:
                            slot_wer[slot_key] = {
                                'samples': 0,
                                'tokens': 0,
                                "errors": 0,
                                "inss": 0,
                                "dels": 0,
                                "subs": 0,
                            }
                        for v in values:
                            slot_wer[slot_key][v] += wer_per_class[target_class][v]
                        break
                elif isinstance(s[0], str):
                    if target_class in s:
                        slot_key = "slot-" + ",".join(s)
                        if slot_key not in slot_wer:
                            slot_wer[slot_key] = {
                                'samples': 0,
                                'tokens': 0,
                                "errors": 0,
                                "inss": 0,
                                "dels": 0,
                                "subs": 0,
                            }
                        for v in values:
                            slot_wer[slot_key][v] += wer_per_class[target_class][v]
                        break
                else:
                    raise ValueError("Currently only numeric or string target metadata is supported")
        for slot_key in slot_wer:
            slot_wer[slot_key][eval_metric] = slot_wer[slot_key]['errors'] / slot_wer[slot_key]['tokens']
            slot_wer[slot_key]['ins_rate'] = slot_wer[slot_key]['inss'] / slot_wer[slot_key]['tokens']
            slot_wer[slot_key]['del_rate'] = slot_wer[slot_key]['dels'] / slot_wer[slot_key]['tokens']
            slot_wer[slot_key]['sub_rate'] = slot_wer[slot_key]['subs'] / slot_wer[slot_key]['tokens']
            slot_wer[slot_key].pop('errors')
            slot_wer[slot_key].pop('inss')
            slot_wer[slot_key].pop('dels')
            slot_wer[slot_key].pop('subs')
        res_wer_per_class.update(slot_wer)

    ret = None
    if meta_cfg.save_wer_per_class:
        ret = res_wer_per_class
    if (not meta_cfg.save_wer_per_class) and ('slot' in meta_cfg and meta_cfg.slot):
        ret = slot_wer
    return ret
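

# Illustrative usage of cal_target_metadata_wer; the manifest path and slot boundaries
# below are hypothetical:
#   meta_cfg = OmegaConf.create({"save_wer_per_class": False, "slot": [[0, 2], [2, 5], [5, 100000]]})
#   slot_wer = cal_target_metadata_wer(
#       manifest="predictions_with_wer.json", target="duration", meta_cfg=meta_cfg, eval_metric="wer"
#   )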