File size: 12,160 Bytes
0558aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import json
import os
from typing import List


from filelock import FileLock

from nemo.utils import logging


def create_transcribed_shard_manifests(prediction_filepaths: List[str]) -> List[str]:
    """
    Creates transcribed shard manifest files by processing predictions and organizing them by shard ID.

    For each directory, reads `predictions_all.json` (one JSON object per line, blank lines ignored),
    groups the entries by their `shard_id` field, and writes each group to
    `transcribed_manifest_<shard_id>.json` in the same directory. In every written entry the predicted
    transcription (`pred_text`) becomes `text`, and the original `text` (if present) is kept as
    `orig_text`. Only entries whose `audio_filepath` ends with ".wav" are written; other entries
    (presumably duplicated audio files, e.g. with a "dup" suffix) are skipped but still count towards
    the maximum shard ID.

    Args:
        prediction_filepaths (List[str]): A list of file paths to directories containing
            `predictions_all.json` files with prediction data, including shard IDs.

    Returns:
        List[str]: One path per input directory, of the form
        `<dir>/transcribed_manifest__OP_0..<max_shard_id>_CL_.json`, where `max_shard_id` is the
        largest shard ID seen (at least 0). This is a sharded-path pattern, not a file written here
        (`_OP_`/`_CL_` presumably stand in for `{`/`}` — confirm against the consumer).

    Raises:
        KeyError: If a ".wav" entry has no `pred_text` or any entry has no `audio_filepath`.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        shard_data = {}
        full_path = os.path.join(prediction_filepath, "predictions_all.json")
        with open(full_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                data_entry = json.loads(line)
                shard_data.setdefault(data_entry.get("shard_id"), []).append(data_entry)

        # Include 0 so the pattern always starts at shard 0 (and an empty file yields 0).
        max_shard_id = max([0, *shard_data])

        for shard_id, entries in shard_data.items():
            output_filename = os.path.join(prediction_filepath, f"transcribed_manifest_{shard_id}.json")
            # utf-8 explicitly: ensure_ascii=False emits raw non-ASCII characters.
            with open(output_filename, 'w', encoding='utf-8') as f:
                for data_entry in entries:
                    if not data_entry['audio_filepath'].endswith(".wav"):
                        continue
                    if 'text' in data_entry:
                        data_entry['orig_text'] = data_entry.pop('text')
                    data_entry['text'] = data_entry.pop('pred_text')
                    f.write(json.dumps(data_entry, ensure_ascii=False) + "\n")

        all_manifest_filepaths.append(
            os.path.join(prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json")
        )
    return all_manifest_filepaths


def create_transcribed_manifests(prediction_filepaths: List[str]) -> List[str]:
    """
    Creates updated transcribed manifest files by processing predictions.

    For each directory, reads `predictions_all.json` (one JSON object per line, blank lines ignored)
    and writes every entry to `transcribed_manifest.json` in the same directory, with `pred_text`
    renamed to `text` and the original `text` (if present) kept as `orig_text`. Unlike the sharded
    variant, no entries are filtered out.

    Args:
        prediction_filepaths (List[str]): A list of file paths to directories containing
            prediction files (`predictions_all.json`).

    Returns:
        List[str]: A list of file paths to the newly created transcribed manifest files
        (`transcribed_manifest.json`), one per input directory.

    Raises:
        FileNotFoundError: If a directory has no `predictions_all.json`; in that case no
            `transcribed_manifest.json` is created for it.
        KeyError: If an entry has no `pred_text`.
    """
    all_manifest_filepaths = []
    for prediction_filepath in prediction_filepaths:
        prediction_name = os.path.join(prediction_filepath, "predictions_all.json")
        transcripted_name = os.path.join(prediction_filepath, "transcribed_manifest.json")

        # Open the input before the output so a missing predictions file does not leave an
        # empty (truncated) transcribed manifest behind.
        with open(prediction_name, 'r', encoding='utf-8') as pred_f, open(
            transcripted_name, 'w', encoding='utf-8'
        ) as f:
            for line in pred_f:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing empty line
                data_entry = json.loads(line)
                if 'text' in data_entry:
                    data_entry['orig_text'] = data_entry.pop('text')
                data_entry['text'] = data_entry.pop('pred_text')
                f.write(json.dumps(data_entry, ensure_ascii=False) + "\n")

        all_manifest_filepaths.append(transcripted_name)

    return all_manifest_filepaths


def write_sampled_shard_transcriptions(manifest_filepaths: List[str]) -> List[List[str]]:
    """
    Updates sharded transcribed manifests in place by merging in fresh predictions.

    For each directory, entries of `predictions_all.json` are indexed by (`shard_id`,
    `audio_filepath`). Every existing `transcribed_manifest_<N>.json` in that directory is then
    rewritten in place:

    - An entry with a matching prediction is replaced by a copy of that prediction, with
      `pred_text` moved to `text` and the prediction's `text` (if any) kept as `orig_text`.
    - An unmatched entry is kept unchanged.
    - Entries whose `audio_filepath` does not end with ".wav" are dropped (the original code
      describes these as duplicated audio files ending with "dup").

    Args:
        manifest_filepaths (List[str]): A list of file paths to directories containing
            `predictions_all.json` and `transcribed_manifest_<N>.json` files.

    Returns:
        List[List[str]]: One single-element list per input directory, containing the pattern
            `<dir>/transcribed_manifest__OP_0..<max_shard_id>_CL_.json`. Here `max_shard_id` is the
            largest `shard_id` found in that directory's shard manifests (0 if none).
    """
    all_manifest_filepaths = []

    for prediction_filepath in manifest_filepaths:
        # Index this directory's predictions by shard id, then by audio file path.
        predicted_shard_data = {}
        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
        with open(prediction_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                data_entry = json.loads(line)
                shard_id = data_entry.get("shard_id")
                predicted_shard_data.setdefault(shard_id, {})[data_entry['audio_filepath']] = data_entry

        max_shard_id = 0
        shard_glob = os.path.join(prediction_filepath, "transcribed_manifest_[0-9]*.json")
        for shard_manifest_path in glob.glob(shard_glob):
            with open(shard_manifest_path, 'r', encoding='utf-8') as f:
                all_data_entries = [json.loads(line) for line in f if line.strip()]

            # Rewrite the file that was just read, rather than a name derived from its contents,
            # so an empty shard manifest can never clobber a different shard's file.
            with open(shard_manifest_path, 'w', encoding='utf-8') as f:
                for data_entry in all_data_entries:
                    shard_id = data_entry.get("shard_id")
                    max_shard_id = max(max_shard_id, shard_id)
                    audio_filepath = data_entry['audio_filepath']
                    # Skip duplicated audio files (non-".wav" paths, e.g. ending with "dup").
                    if not audio_filepath.endswith(".wav"):
                        continue
                    predicted_data_entry = predicted_shard_data.get(shard_id, {}).get(audio_filepath)
                    if predicted_data_entry is None:
                        f.write(json.dumps(data_entry, ensure_ascii=False) + "\n")
                        continue
                    # Work on a copy so a path listed twice is transformed correctly both times.
                    updated_entry = dict(predicted_data_entry)
                    if 'text' in updated_entry:
                        updated_entry['orig_text'] = updated_entry.pop('text')
                    if "pred_text" in updated_entry:
                        updated_entry['text'] = updated_entry.pop('pred_text')
                    f.write(json.dumps(updated_entry, ensure_ascii=False) + "\n")

        all_manifest_filepaths.append(
            [os.path.join(prediction_filepath, f"transcribed_manifest__OP_0..{max_shard_id}_CL_.json")]
        )

    return all_manifest_filepaths


def write_sampled_transcriptions(manifest_filepaths: List[str]) -> List[str]:
    """
    Updates transcriptions by merging predicted data with transcribed manifest data.

    For each directory, predictions from `predictions_all.json` are indexed by `audio_filepath`.
    `transcribed_manifest.json` in the same directory is then rewritten in place:

    - An entry with a matching prediction is replaced by a copy of that prediction, with
      `pred_text` moved to `text` and the prediction's `text` (if any) kept as `orig_text`.
    - Unmatched entries are kept unchanged.
    - Entries whose `audio_filepath` does not end with ".wav" are dropped.

    Blank lines in either input file are ignored.

    Args:
        manifest_filepaths (List[str]): A list of file paths to directories containing
            the prediction file (`predictions_all.json`) and the transcribed manifest file
            (`transcribed_manifest.json`).

    Returns:
        List[str]: A list of file paths to the updated transcribed manifest files.

    Raises:
        KeyError: If a matching prediction has no `pred_text`.
    """
    all_manifest_filepaths = []
    for prediction_filepath in manifest_filepaths:
        predicted_data = {}
        prediction_path = os.path.join(prediction_filepath, "predictions_all.json")
        with open(prediction_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                data_entry = json.loads(line)
                predicted_data[data_entry['audio_filepath']] = data_entry

        output_filename = os.path.join(prediction_filepath, "transcribed_manifest.json")
        # The manifest is rewritten in place, so read it fully before reopening it for writing.
        with open(output_filename, 'r', encoding='utf-8') as f:
            all_data_entries = [json.loads(line) for line in f if line.strip()]

        with open(output_filename, 'w', encoding='utf-8') as f:
            for data_entry in all_data_entries:
                audio_filepath = data_entry['audio_filepath']
                if not audio_filepath.endswith(".wav"):
                    continue
                predicted_data_entry = predicted_data.get(audio_filepath)
                if predicted_data_entry is None:
                    f.write(json.dumps(data_entry, ensure_ascii=False) + "\n")
                    continue
                # Work on a copy: mutating the shared prediction would make a second manifest
                # entry with the same path fail on the missing 'pred_text'.
                updated_entry = dict(predicted_data_entry)
                if 'text' in updated_entry:
                    updated_entry['orig_text'] = updated_entry.pop('text')
                updated_entry['text'] = updated_entry.pop('pred_text')
                f.write(json.dumps(updated_entry, ensure_ascii=False) + "\n")
        all_manifest_filepaths.append(output_filename)
    return all_manifest_filepaths


if __name__ == "__main__":
    rank = int(os.environ.get("RANK", 0))  # Default to 0 if not set

    parser = argparse.ArgumentParser(description="Script to create or write transcriptions")
    parser.add_argument("--is_tarred", action="store_true", help="If true, processes tarred manifests")
    parser.add_argument("--full_pass", action="store_true", help="If true, processes full pass manifests")
    parser.add_argument(
        "--prediction_filepaths",
        type=str,
        nargs='+',  # Accepts one or more values as a list
        required=True,
        help="Paths to one or more directories containing predictions_all.json.",
    )

    args = parser.parse_args()

    # The lock file lives in the parent of the first prediction directory.
    # - normpath strips a trailing separator; otherwise dirname would return the directory itself.
    # - os.path.join avoids "/my_script.lock" at the filesystem root when the parent is "".
    lock_dir = os.path.dirname(os.path.normpath(args.prediction_filepaths[0]))
    lock_file = os.path.join(lock_dir, "my_script.lock")

    with FileLock(lock_file):
        # Only rank 0 rewrites manifests; other ranks merely acquire and release the lock.
        if rank == 0:
            if args.is_tarred:
                if args.full_pass:
                    create_transcribed_shard_manifests(args.prediction_filepaths)
                else:
                    write_sampled_shard_transcriptions(args.prediction_filepaths)
            else:
                if args.full_pass:
                    create_transcribed_manifests(args.prediction_filepaths)
                else:
                    write_sampled_transcriptions(args.prediction_filepaths)

    # Best-effort removal of the lock file after the FileLock context is exited. Another rank may
    # already have removed it, so a missing file is not an error (avoids the exists/remove race).
    try:
        os.remove(lock_file)
    except FileNotFoundError:
        pass