# MagpieTTS_Internal_Demo / scripts / pseudo_labeling / write_transcribed_files.py
# (HuggingFace upload metadata: uploaded by subhankarg via huggingface_hub, commit 0558aa4 verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
from typing import List
from filelock import FileLock
from nemo.utils import logging
def create_transcribed_shard_manifests(prediction_filepaths: List[str]) -> List[str]:
    """
    Creates transcribed shard manifest files by processing predictions and organizing them by shard ID.

    This function reads a `predictions_all.json` file from each given directory, organizes the data by
    shard IDs, and writes the entries to separate shard manifest files
    (`transcribed_manifest_<shard_id>.json`). For each entry, the `pred_text` field becomes the main
    transcription (`text`), and the original transcription (`text`) is stored as `orig_text`.
    Entries whose `audio_filepath` does not end in ".wav" (e.g. duplicated files) are dropped.

    Args:
        prediction_filepaths (List[str]): A list of file paths to directories containing
            `predictions_all.json` files with prediction data, including shard IDs.

    Returns:
        List[str]: A list of file paths to the combined manifest files
            (`transcribed_manifest__OP_0..<max_shard_id>_CL_.json`) created for each directory.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        max_shard_id = 0
        shard_data = {}  # shard_id -> list of prediction entries
        full_path = os.path.join(prediction_filepath, "predictions_all.json")
        # Iterate the file lazily instead of materializing it via readlines();
        # explicit utf-8 for consistency with create_transcribed_manifests.
        with open(full_path, 'r', encoding='utf-8') as f:
            for line in f:
                data_entry = json.loads(line)
                shard_id = data_entry.get("shard_id")
                max_shard_id = max(max_shard_id, shard_id)
                shard_data.setdefault(shard_id, []).append(data_entry)
        for shard_id, entries in shard_data.items():
            output_filename = os.path.join(prediction_filepath, f"transcribed_manifest_{shard_id}.json")
            with open(output_filename, 'w', encoding='utf-8') as f:
                for data_entry in entries:
                    # Only keep real audio entries; non-".wav" paths (duplicates) are skipped.
                    if data_entry['audio_filepath'].endswith(".wav"):
                        # Preserve the original transcription before promoting the prediction.
                        if 'text' in data_entry:
                            data_entry['orig_text'] = data_entry.pop('text')
                        data_entry['text'] = data_entry.pop('pred_text')
                        json.dump(data_entry, f, ensure_ascii=False)
                        f.write("\n")
        shard_manifest_filepath = os.path.join(
            prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json"
        )
        all_manifest_filepaths.append(shard_manifest_filepath)
    return all_manifest_filepaths
def create_transcribed_manifests(prediction_filepaths: List[str]) -> List[str]:
    """
    Creates updated transcribed manifest files by processing predictions.

    This function reads prediction files (`predictions_all.json`) from the provided directories,
    updates the transcription data by renaming the `pred_text` field to `text`, and stores the
    original `text` field as `orig_text`. The updated data is written to new transcribed manifest
    files (`transcribed_manifest.json`) in each directory.

    Args:
        prediction_filepaths (List[str]): A list of file paths to directories containing
            prediction files (`predictions_all.json`).

    Returns:
        List[str]: A list of file paths to the newly created transcribed manifest files
            (`transcribed_manifest.json`).

    Raises:
        KeyError: If a prediction entry is missing the `pred_text` field.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        prediction_name = os.path.join(prediction_filepath, "predictions_all.json")
        transcripted_name = os.path.join(prediction_filepath, "transcribed_manifest.json")
        # Open the input first so a missing predictions file does not truncate
        # an existing transcribed manifest; stream line by line (no readlines()).
        with open(prediction_name, 'r', encoding='utf-8') as pred_f, open(
            transcripted_name, 'w', encoding='utf-8'
        ) as f:
            for line in pred_f:
                data_entry = json.loads(line)
                # Preserve the original transcription before promoting the prediction.
                if 'text' in data_entry:
                    data_entry['orig_text'] = data_entry.pop('text')
                data_entry['text'] = data_entry.pop('pred_text')
                json.dump(data_entry, f, ensure_ascii=False)
                f.write("\n")
        # Append the path of the new manifest file to the list
        all_manifest_filepaths.append(transcripted_name)
    return all_manifest_filepaths
def write_sampled_shard_transcriptions(manifest_filepaths: List[str]) -> List[List[str]]:
    """
    Updates transcriptions by merging predicted shard data and transcribed manifest data.

    This function processes prediction and transcribed manifest files, merging them by
    matching shard_id and audio file paths. Each `transcribed_manifest_<shard_id>.json`
    file is rewritten in place: entries that have a fresh prediction get `pred_text`
    promoted to `text` (with the old `text` saved as `orig_text`), other ".wav" entries
    are kept as-is, and entries whose path does not end in ".wav" (duplicates) are dropped.

    Args:
        manifest_filepaths (List[str]): A list of file paths to directories containing
            prediction and transcribed manifest files.

    Returns:
        List[List[str]]: A list of single-element lists containing the file paths to the
            combined shard manifest files (`transcribed_manifest__OP_0..<max>_CL_.json`).
    """
    all_manifest_filepaths = []
    # Process each prediction directory
    for prediction_filepath in manifest_filepaths:
        # shard_id -> {audio_filepath -> prediction entry}
        predicted_shard_data = {}
        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
        with open(prediction_path, 'r') as f:
            for line in f:
                data_entry = json.loads(line)
                shard_id = data_entry.get("shard_id")
                predicted_shard_data.setdefault(shard_id, {})[data_entry['audio_filepath']] = data_entry
        max_shard_id = 0
        for full_path in glob.glob(os.path.join(prediction_filepath, "transcribed_manifest_[0-9]*.json")):
            all_data_entries = []
            with open(full_path, 'r') as f:
                for line in f:
                    data_entry = json.loads(line)
                    max_shard_id = max(max_shard_id, data_entry.get("shard_id"))
                    all_data_entries.append(data_entry)
            # Rewrite the shard manifest in place, preferring fresh predictions.
            # (The original built the output name from the shard_id of the last line
            # read, which broke manifests containing mixed shard ids.)
            with open(full_path, 'w') as f:
                for data_entry in all_data_entries:
                    audio_filepath = data_entry['audio_filepath']
                    # Skip duplicated audio files that do not end with ".wav"
                    if not audio_filepath.endswith(".wav"):
                        continue
                    # Bug fix: match on this entry's own shard_id instead of the
                    # shard_id left over from the last line read above.
                    entry_shard_id = data_entry.get("shard_id")
                    predicted_data_entry = predicted_shard_data.get(entry_shard_id, {}).get(audio_filepath)
                    if predicted_data_entry is not None:
                        if 'text' in predicted_data_entry:
                            predicted_data_entry['orig_text'] = predicted_data_entry.pop('text')
                        if "pred_text" in predicted_data_entry:
                            predicted_data_entry['text'] = predicted_data_entry.pop('pred_text')
                        json.dump(predicted_data_entry, f, ensure_ascii=False)
                    else:
                        json.dump(data_entry, f, ensure_ascii=False)
                    f.write("\n")
        shard_manifest_filepath = os.path.join(
            prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json"
        )
        all_manifest_filepaths.append([shard_manifest_filepath])
    return all_manifest_filepaths
def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]:
    """
    Updates transcriptions by merging predicted data with transcribed manifest data.

    This function processes prediction and transcribed manifest files within given directories.
    It matches audio file paths to update transcriptions with predictions, rewriting the
    transcribed manifest in place. Entries with a fresh prediction get `pred_text` promoted
    to `text` (the old `text` is saved as `orig_text`); other ".wav" entries are kept as-is;
    entries whose path does not end in ".wav" (duplicates) are dropped.

    Args:
        manifest_filepaths (List[str]): A list of file paths to directories containing
            the prediction file (`predictions_all.json`) and the transcribed manifest file
            (`transcribed_manifest.json`).

    Returns:
        List[str]: A list of file paths to the updated transcribed manifest files.
    """
    all_manifest_filepaths = []
    for prediction_filepath in manifest_filepaths:
        # audio_filepath -> prediction entry
        predicted_data = {}
        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
        with open(prediction_path, 'r') as f:
            for line in f:
                data_entry = json.loads(line)
                predicted_data[data_entry['audio_filepath']] = data_entry
        # Read the whole manifest first, then rewrite the same file in place.
        manifest_path = os.path.join(prediction_filepath, "transcribed_manifest.json")
        with open(manifest_path, 'r') as f:
            all_data_entries = [json.loads(line) for line in f]
        with open(manifest_path, 'w') as f:
            for data_entry in all_data_entries:
                audio_filepath = data_entry['audio_filepath']
                # Skip duplicated audio files that do not end with ".wav"
                if not audio_filepath.endswith(".wav"):
                    continue
                predicted_data_entry = predicted_data.get(audio_filepath)
                if predicted_data_entry is not None:
                    if 'text' in predicted_data_entry:
                        predicted_data_entry['orig_text'] = predicted_data_entry.pop('text')
                    # Guard the pop for consistency with write_sampled_shard_transcriptions:
                    # an entry without `pred_text` no longer raises KeyError.
                    if 'pred_text' in predicted_data_entry:
                        predicted_data_entry['text'] = predicted_data_entry.pop('pred_text')
                    json.dump(predicted_data_entry, f, ensure_ascii=False)
                else:
                    json.dump(data_entry, f, ensure_ascii=False)
                f.write("\n")
        all_manifest_filepaths.append(manifest_path)
    return all_manifest_filepaths
if __name__ == "__main__":
    # Distributed rank; only rank 0 performs the manifest rewrite below.
    rank = int(os.environ.get("RANK", 0))  # Default to 0 if not set

    parser = argparse.ArgumentParser(description="Script to create or write transcriptions")
    parser.add_argument("--is_tarred", action="store_true", help="If true, processes tarred manifests")
    parser.add_argument("--full_pass", action="store_true", help="If true, processes full pass manifests")
    parser.add_argument(
        "--prediction_filepaths",
        type=str,
        nargs='+',  # Accepts one or more values as a list
        required=True,
        # Fixed help text: these are directories holding predictions_all.json,
        # not inference config YAML files.
        help="Paths to one or more directories containing predictions_all.json files.",
    )
    args = parser.parse_args()

    # Serialize concurrently launched processes with a file lock placed next to
    # the first prediction directory (os.path.join instead of string concat).
    lock_dir = os.path.dirname(args.prediction_filepaths[0])
    lock_file = os.path.join(lock_dir, "my_script.lock")
    with FileLock(lock_file):
        if rank == 0:
            if args.is_tarred:
                result = (
                    write_sampled_shard_transcriptions(args.prediction_filepaths)
                    if not args.full_pass
                    else create_transcribed_shard_manifests(args.prediction_filepaths)
                )
            else:
                result = (
                    write_sampled_transcriptions(args.prediction_filepaths)
                    if not args.full_pass
                    else create_transcribed_manifests(args.prediction_filepaths)
                )
    # Remove the lock file after the FileLock context is exited
    if os.path.exists(lock_file):
        os.remove(lock_file)