# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile

import pytest
import torch.cuda

from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset
from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures, WaveformFeaturizer
from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line, read_rttm_lines


def is_rttm_length_too_long(rttm_file_path, wav_len_in_sec):
"""
Check if the maximum RTTM duration exceeds the length of the provided audio file.
Args:
rttm_file_path (str): Path to the RTTM file.
wav_len_in_sec (float): Length of the audio file in seconds.
Returns:
bool: True if the maximum RTTM duration is less than or equal to the length of the audio file, False otherwise.
"""
rttm_lines = read_rttm_lines(rttm_file_path)
max_rttm_sec = 0
for line in rttm_lines:
start, dur = get_vad_out_from_rttm_line(line)
max_rttm_sec = max(max_rttm_sec, start + dur)
return max_rttm_sec <= wav_len_in_sec
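

# Illustrative sketch of the helper above. The RTTM lines, session name, and durations are
# made-up examples (not part of the packaged NeMo test data); the test only demonstrates how
# is_rttm_length_too_long() is expected to behave.
class TestIsRttmLengthTooLong:
    @pytest.mark.unit
    def test_rttm_end_time_vs_audio_length(self):
        rttm_lines = [
            "SPEAKER session0 1 0.50 2.00 <NA> <NA> speaker_0 <NA> <NA>\n",
            "SPEAKER session0 1 2.50 2.00 <NA> <NA> speaker_1 <NA> <NA>\n",
        ]
        with tempfile.NamedTemporaryFile(mode='w', suffix='.rttm', delete=False) as f:
            f.writelines(rttm_lines)
            rttm_path = f.name
        try:
            # The last segment ends at 4.5 s: a 5.0 s recording fits, a 4.0 s one does not.
            assert is_rttm_length_too_long(rttm_path, wav_len_in_sec=5.0)
            assert not is_rttm_length_too_long(rttm_path, wav_len_in_sec=4.0)
        finally:
            os.remove(rttm_path)

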
class TestAudioToSpeechE2ESpkDiarDataset:
@pytest.mark.unit
def test_e2e_speaker_diar_dataset(self, test_data_dir):
manifest_path = os.path.abspath(os.path.join(test_data_dir, 'asr/diarizer/lsm_val.json'))
batch_size = 4
num_samples = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_dict_list = []
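        # Copy the first num_samples manifest entries into a temporary manifest, rewriting the
        # relative "tests/data/" paths so they point at the local test data directory.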
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
with open(manifest_path, 'r', encoding='utf-8') as mfile:
for ix, line in enumerate(mfile):
if ix >= num_samples:
break
line = line.replace("tests/data/", test_data_dir + "/").replace("\n", "")
f.write(f"{line}\n")
data_dict = json.loads(line)
data_dict_list.append(data_dict)
f.seek(0)
featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None)
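            # Frame the 16 kHz audio with a 25 ms window and a 10 ms hop (matching the
            # window_stride passed to the dataset below); dithering is disabled for determinism.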
fb_featurizer = FilterbankFeatures(
sample_rate=featurizer.sample_rate,
n_window_size=int(0.025 * featurizer.sample_rate),
n_window_stride=int(0.01 * featurizer.sample_rate),
dither=False,
)
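            # Build the end-to-end speaker diarization dataset: sessions capped at 90 s, at most
            # 4 speakers, and hard 0/1 speaker-activity targets (soft_targets=False) derived with
            # the 0.5 soft-label threshold.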
dataset = AudioToSpeechE2ESpkDiarDataset(
manifest_filepath=f.name,
soft_label_thres=0.5,
session_len_sec=90,
num_spks=4,
featurizer=featurizer,
window_stride=0.01,
global_rank=0,
soft_targets=False,
device=device,
fb_featurizer=fb_featurizer,
)
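            # Use the dataset's own collate function so that variable-length audio and
            # speaker-activity targets are batched together with their lengths.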
dataloader_instance = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
collate_fn=dataset.eesd_train_collate_fn,
drop_last=False,
shuffle=False,
num_workers=0,
pin_memory=False,
)
assert len(dataloader_instance) == (num_samples / batch_size) # Check if the number of batches is correct
batch_counts = len(dataloader_instance)
deviation_thres_rate = 0.01 # 1% deviation allowed
for batch_index, batch in enumerate(dataloader_instance):
                audio_signals, audio_signal_len, targets, target_lens = batch
                if batch_index != batch_counts - 1:
                    # Check the batch dimension of the collated tensors, not the length of the tuple.
                    assert audio_signals.shape[0] == batch_size, "Batch size does not match the expected value"
for sample_index in range(audio_signals.shape[0]):
                    # audio_signal_len is measured in samples, so convert the manifest duration
                    # (seconds) to samples before comparing.
                    dataloader_audio_len_samples = audio_signal_len[sample_index].item()
                    manifest_audio_len_samples = (
                        data_dict_list[batch_size * batch_index + sample_index]['duration'] * featurizer.sample_rate
                    )
                    duration_deviation = abs(manifest_audio_len_samples - dataloader_audio_len_samples)
                    assert (
                        duration_deviation <= deviation_thres_rate * dataloader_audio_len_samples
                    ), "Duration deviation exceeds 1%"
assert not torch.isnan(audio_signals).any(), "audio_signals tensor contains NaN values"
assert not torch.isnan(audio_signal_len).any(), "audio_signal_len tensor contains NaN values"
assert not torch.isnan(targets).any(), "targets tensor contains NaN values"
assert not torch.isnan(target_lens).any(), "target_lens tensor contains NaN values"