# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import json
import os
import re
from typing import Dict, Iterable, List

from filelock import FileLock

from nemo.utils import logging

# Matches per-shard manifests written by `create_transcribed_shard_manifests`
# (e.g. `transcribed_manifest_12.json`) and captures the numeric shard id.
# The combined `transcribed_manifest__OP_0..N_CL_.json` name does not match.
_SHARD_MANIFEST_RE = re.compile(r"^transcribed_manifest_(\d+)\.json$")


def _read_jsonl(path: str) -> List[dict]:
    """Read a JSON-lines manifest (UTF-8) into a list of dicts, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _write_jsonl(path: str, entries: Iterable[dict]) -> None:
    """Write ``entries`` to ``path`` as UTF-8 JSON lines, keeping non-ASCII text verbatim."""
    with open(path, 'w', encoding='utf-8') as f:
        for entry in entries:
            json.dump(entry, f, ensure_ascii=False)
            f.write("\n")


def _promote_prediction(data_entry: dict) -> dict:
    """
    Return a copy of a prediction entry with ``pred_text`` promoted to ``text``.

    The previous ``text`` (if present) is preserved as ``orig_text``. The input dict is
    never modified, so the same prediction can be reused without being converted twice.

    Raises:
        KeyError: if ``data_entry`` has no ``pred_text`` field.
    """
    entry = dict(data_entry)
    pred_text = entry.pop('pred_text')
    if 'text' in entry:
        entry['orig_text'] = entry.pop('text')
    entry['text'] = pred_text
    return entry


def create_transcribed_shard_manifests(prediction_filepaths: List[str]) -> List[str]:
    """
    Creates transcribed shard manifest files by processing predictions and organizing them by shard ID.

    For each directory, reads `predictions_all.json`, groups the entries by their `shard_id`, and writes
    each group to `transcribed_manifest_<shard_id>.json` in the same directory. Only entries whose
    `audio_filepath` ends with `.wav` are written (duplicated audio, e.g. paths ending in `dup`, is
    dropped). In every written entry `pred_text` becomes the main transcription (`text`) and the original
    `text`, if any, is stored as `orig_text`.

    Args:
        prediction_filepaths (List[str]): A list of paths to directories containing `predictions_all.json`
            files with prediction data, including shard IDs.

    Returns:
        List[str]: For each directory, the combined manifest path
        `transcribed_manifest__OP_0..<max_shard_id>_CL_.json`. `_OP_`/`_CL_` look like placeholders for
        `{`/`}` that are expanded into a shard range downstream (not verified here).

    Raises:
        ValueError: if a prediction entry has no `shard_id`.
        KeyError: if a written entry has no `pred_text`.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        predictions_path = os.path.join(prediction_filepath, "predictions_all.json")

        # Group entries by shard, preserving their original order within each shard.
        shard_data: Dict[int, List[dict]] = {}
        for data_entry in _read_jsonl(predictions_path):
            shard_id = data_entry.get("shard_id")
            if shard_id is None:
                raise ValueError(
                    f"Entry for {data_entry.get('audio_filepath')!r} in {predictions_path} has no 'shard_id'"
                )
            shard_data.setdefault(shard_id, []).append(data_entry)

        for shard_id, entries in shard_data.items():
            output_filename = os.path.join(prediction_filepath, f"transcribed_manifest_{shard_id}.json")
            _write_jsonl(
                output_filename,
                (_promote_prediction(e) for e in entries if e['audio_filepath'].endswith(".wav")),
            )

        max_shard_id = max([0, *shard_data])
        all_manifest_filepaths.append(
            os.path.join(prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json")
        )
    return all_manifest_filepaths


def create_transcribed_manifests(prediction_filepaths: List[str]) -> List[str]:
    """
    Creates updated transcribed manifest files by processing predictions.

    For each directory, reads `predictions_all.json`, renames `pred_text` to `text` (storing any original
    `text` as `orig_text`), and writes every entry to `transcribed_manifest.json` in the same directory.

    Args:
        prediction_filepaths (List[str]): A list of paths to directories containing prediction files
            (`predictions_all.json`).

    Returns:
        List[str]: A list of file paths to the newly created transcribed manifest files
        (`transcribed_manifest.json`).

    Raises:
        KeyError: if a prediction entry has no `pred_text`.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        prediction_name = os.path.join(prediction_filepath, "predictions_all.json")
        transcribed_name = os.path.join(prediction_filepath, "transcribed_manifest.json")
        # Read (and convert) everything before opening the output, so a missing or malformed
        # predictions file does not leave behind an empty/truncated manifest.
        # NOTE(review): unlike the shard variant, non-`.wav` entries are kept here — confirm intended.
        entries = [_promote_prediction(e) for e in _read_jsonl(prediction_name)]
        _write_jsonl(transcribed_name, entries)
        all_manifest_filepaths.append(transcribed_name)
    return all_manifest_filepaths


def write_sampled_shard_transcriptions(manifest_filepaths: List[str]) -> List[List[str]]:
    """
    Updates shard manifests in place by merging fresh predictions into existing transcriptions.

    For each directory, predictions from `predictions_all.json` are indexed by `(shard_id, audio_filepath)`.
    Every `transcribed_manifest_<N>.json` in the directory is then rewritten: entries whose
    `audio_filepath` does not end with `.wav` are dropped; `.wav` entries with a matching prediction that
    carries `pred_text` are replaced by that prediction (with `pred_text` promoted to `text` and the old
    `text` kept as `orig_text`); all other `.wav` entries are kept unchanged.

    Args:
        manifest_filepaths (List[str]): A list of paths to directories containing `predictions_all.json`
            and the per-shard transcribed manifest files.

    Returns:
        List[List[str]]: For each directory, a one-element list holding the combined manifest path
        `transcribed_manifest__OP_0..<max_shard_id>_CL_.json`.
    """
    all_manifest_filepaths = []
    for prediction_filepath in manifest_filepaths:
        # Convert predictions once, up front: reusing an already-promoted dict can then never
        # clobber `orig_text` or lose `text`.
        predicted: Dict[tuple, dict] = {
            (entry.get("shard_id"), entry['audio_filepath']): _promote_prediction(entry)
            for entry in _read_jsonl(os.path.join(prediction_filepath, "predictions_all.json"))
            if 'pred_text' in entry
        }

        max_shard_id = 0
        pattern = os.path.join(glob.escape(prediction_filepath), "transcribed_manifest_[0-9]*.json")
        for full_path in glob.glob(pattern):
            match = _SHARD_MANIFEST_RE.match(os.path.basename(full_path))
            if match is None:
                # The glob also matches names such as `transcribed_manifest_1_old.json`.
                continue
            # Take the shard id from the file name so that an empty manifest is still rewritten
            # to itself rather than to another shard's file.
            file_shard_id = int(match.group(1))
            max_shard_id = max(max_shard_id, file_shard_id)

            merged = []
            for data_entry in _read_jsonl(full_path):
                audio_filepath = data_entry['audio_filepath']
                # Escape duplicated audio files that end with *dup
                if not audio_filepath.endswith(".wav"):
                    continue
                key = (data_entry.get("shard_id", file_shard_id), audio_filepath)
                merged.append(predicted.get(key, data_entry))
            _write_jsonl(full_path, merged)

        shard_manifest_filepath = os.path.join(
            prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json"
        )
        all_manifest_filepaths.append([shard_manifest_filepath])
    return all_manifest_filepaths


def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]:
    """
    Updates `transcribed_manifest.json` in place by merging fresh predictions into it.

    For each directory, predictions from `predictions_all.json` are indexed by `audio_filepath` (the last
    occurrence wins). The transcribed manifest is then rewritten: entries whose `audio_filepath` does not
    end with `.wav` are dropped; `.wav` entries with a matching prediction that carries `pred_text` are
    replaced by that prediction (with `pred_text` promoted to `text` and the old `text` kept as
    `orig_text`); all other `.wav` entries are kept unchanged.

    Args:
        manifest_filepaths (List[str]): A list of paths to directories containing the prediction file
            (`predictions_all.json`) and the transcribed manifest file (`transcribed_manifest.json`).

    Returns:
        List[str]: A list of file paths to the updated transcribed manifest files.
    """
    all_manifest_filepaths = []
    for prediction_filepath in manifest_filepaths:
        predicted: Dict[str, dict] = {
            entry['audio_filepath']: _promote_prediction(entry)
            for entry in _read_jsonl(os.path.join(prediction_filepath, "predictions_all.json"))
            if 'pred_text' in entry
        }

        manifest_path = os.path.join(prediction_filepath, "transcribed_manifest.json")
        # The whole manifest is read before it is reopened for writing below.
        merged = [
            predicted.get(entry['audio_filepath'], entry)
            for entry in _read_jsonl(manifest_path)
            if entry['audio_filepath'].endswith(".wav")
        ]
        _write_jsonl(manifest_path, merged)
        all_manifest_filepaths.append(manifest_path)
    return all_manifest_filepaths


def main() -> None:
    """Command-line entry point: create or update transcribed manifests on rank 0 under a file lock."""
    rank = int(os.environ.get("RANK", 0))  # Default to 0 if not set
    parser = argparse.ArgumentParser(description="Script to create or write transcriptions")
    parser.add_argument("--is_tarred", action="store_true", help="If true, processes tarred manifests")
    parser.add_argument("--full_pass", action="store_true", help="If true, processes full pass manifests")
    parser.add_argument(
        "--prediction_filepaths",
        type=str,
        nargs='+',  # Accepts one or more values as a list
        required=True,
        help="Paths to one or more directories containing a predictions_all.json file.",
    )
    args = parser.parse_args()

    # os.path.join rather than string concatenation: for a bare relative directory the dirname
    # is "", which would otherwise produce "/my_script.lock" at the filesystem root.
    lock_file = os.path.join(os.path.dirname(args.prediction_filepaths[0]), "my_script.lock")
    with FileLock(lock_file):
        if rank == 0:
            if args.is_tarred:
                handler = create_transcribed_shard_manifests if args.full_pass else write_sampled_shard_transcriptions
            else:
                handler = create_transcribed_manifests if args.full_pass else write_sampled_transcriptions
            result = handler(args.prediction_filepaths)
            logging.info(f"Transcribed manifests: {result}")

    # Remove the lock file after the FileLock context is exited. Presumably every rank runs this
    # script, so another process may already have deleted it; EAFP avoids the exists()/remove() race.
    try:
        os.remove(lock_file)
    except FileNotFoundError:
        pass


if __name__ == "__main__":
    main()