# tests/collections/asr/test_asr_classification_model.py
# (originally uploaded to MagpieTTS_Internal_Demo via huggingface_hub, commit 0558aa4)
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import tempfile
import lightning.pytorch as pl
import numpy as np
import pytest
import soundfile as sf
import torch
from omegaconf import DictConfig, ListConfig
from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel, configs
from nemo.utils.config_utils import assert_dataclass_signature_match
@pytest.fixture()
def speech_classification_model():
    """Build a tiny EncDecClassificationModel (1 separable conv block, 30 dummy
    classes) used as the system-under-test throughout this module."""
    jasper_block = {
        'filters': 32,
        'repeat': 1,
        'kernel': [1],
        'stride': [1],
        'dilation': [1],
        'dropout': 0.0,
        'residual': False,
        'separable': True,
        'se': True,
        'se_context_size': -1,
    }
    model_cfg = DictConfig(
        {
            'preprocessor': DictConfig(
                {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': {}}
            ),
            'encoder': DictConfig(
                {
                    'cls': 'nemo.collections.asr.modules.ConvASREncoder',
                    'params': {
                        'feat_in': 64,
                        'activation': 'relu',
                        'conv_mask': True,
                        'jasper': [jasper_block],
                    },
                }
            ),
            'decoder': DictConfig(
                {
                    'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
                    'params': {'feat_in': 32, 'num_classes': 30},
                }
            ),
            # 30 placeholder class labels: dummy_cls_1 .. dummy_cls_30
            'labels': ListConfig([f'dummy_cls_{i}' for i in range(1, 31)]),
        }
    )
    return EncDecClassificationModel(cfg=model_cfg)
@pytest.fixture()
def frame_classification_model():
    """Build a tiny EncDecFrameClassificationModel (MLP decoder, binary frame
    labels, SGD optimizer config) for the frame-level tests."""
    jasper_block = {
        'filters': 32,
        'repeat': 1,
        'kernel': [1],
        'stride': [1],
        'dilation': [1],
        'dropout': 0.0,
        'residual': False,
        'separable': True,
        'se': True,
        'se_context_size': -1,
    }
    model_cfg = DictConfig(
        {
            'preprocessor': DictConfig(
                {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': {}}
            ),
            'encoder': DictConfig(
                {
                    'cls': 'nemo.collections.asr.modules.ConvASREncoder',
                    'params': {
                        'feat_in': 64,
                        'activation': 'relu',
                        'conv_mask': True,
                        'jasper': [jasper_block],
                    },
                }
            ),
            'decoder': DictConfig(
                {
                    'cls': 'nemo.collections.common.parts.MultiLayerPerceptron',
                    'params': {'hidden_size': 32, 'num_classes': 5},
                }
            ),
            'optim': DictConfig({'name': 'sgd', 'lr': 0.01, 'weight_decay': 0.001, 'momentum': 0.9}),
            'labels': ListConfig(['0', '1']),
        }
    )
    return EncDecFrameClassificationModel(cfg=model_cfg)
class TestEncDecClassificationModel:
    """Unit tests for EncDecClassificationModel (utterance-level audio classification)."""

    @pytest.mark.unit
    def test_constructor(self, speech_classification_model):
        # Verify the trainable-parameter count matches a hand-derived expectation for the
        # fixture config, and that the model round-trips through its config dict.
        asr_model = speech_classification_model.train()
        conv_cnt = (64 * 32 * 1 + 32) + (64 * 1 * 1 + 32)  # separable kernel + bias + pointwise kernel + bias
        bn_cnt = (4 * 32) * 2  # 2 * moving averages
        dec_cnt = 32 * 30 + 30  # fc + bias
        param_count = conv_cnt + bn_cnt + dec_cnt
        assert asr_model.num_weights == param_count
        # Check to/from config_dict:
        confdict = asr_model.to_config_dict()
        instance2 = EncDecClassificationModel.from_config_dict(confdict)
        assert isinstance(instance2, EncDecClassificationModel)

    @pytest.mark.unit
    def test_forward(self, speech_classification_model):
        # Forward pass should be batch-size invariant: per-utterance results concatenated
        # must match the full-batch result. Dither and pad_to are zeroed so the
        # preprocessor is deterministic and padding-independent.
        asr_model = speech_classification_model.eval()
        asr_model.preprocessor.featurizer.dither = 0.0
        asr_model.preprocessor.featurizer.pad_to = 0
        input_signal = torch.randn(size=(4, 512))
        length = torch.randint(low=321, high=500, size=[4])
        with torch.no_grad():
            # batch size 1
            logprobs_instance = []
            for i in range(input_signal.size(0)):
                logprobs_ins = asr_model.forward(
                    input_signal=input_signal[i : i + 1], input_signal_length=length[i : i + 1]
                )
                logprobs_instance.append(logprobs_ins)
            logprobs_instance = torch.cat(logprobs_instance, 0)
            # batch size 4
            logprobs_batch = asr_model.forward(input_signal=input_signal, input_signal_length=length)
        assert logprobs_instance.shape == logprobs_batch.shape
        # Mean and max absolute deviation must both be within float tolerance.
        diff = torch.mean(torch.abs(logprobs_instance - logprobs_batch))
        assert diff <= 1e-6
        diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
        assert diff <= 1e-6

    @pytest.mark.unit
    def test_vocab_change(self, speech_classification_model):
        # change_labels with identical labels must be a no-op; appending 3 new labels
        # should only grow the final classifier layer (3 weight rows + 3 biases).
        asr_model = speech_classification_model.train()
        old_labels = copy.deepcopy(asr_model._cfg.labels)
        nw1 = asr_model.num_weights
        asr_model.change_labels(new_labels=old_labels)
        # No change
        assert nw1 == asr_model.num_weights
        new_labels = copy.deepcopy(old_labels)
        new_labels.append('dummy_cls_31')
        new_labels.append('dummy_cls_32')
        new_labels.append('dummy_cls_33')
        asr_model.change_labels(new_labels=new_labels)
        # fully connected + bias
        assert asr_model.num_weights == nw1 + 3 * (asr_model.decoder._feat_in + 1)

    @pytest.mark.unit
    def test_transcription(self, speech_classification_model, test_data_dir):
        # Exercise transcribe() in its different output modes: top-1 labels,
        # top-5 labels, combined [1, 5], and raw log-prob extraction.
        # Ground truth labels = ["yes", "no"]
        audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
        audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]
        model = speech_classification_model.eval()
        # Test Top 1 classification transcription
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([1])
        # Test Top 5 classification transcription
        model._accuracy.top_k = [5]  # set top k to 5 for accuracy calculation
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([5])
        # Test Top 1 and Top 5 classification transcription
        model._accuracy.top_k = [1, 5]
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([2, 1])
        assert results[1].shape == torch.Size([2, 5])
        assert model._accuracy.top_k == [1, 5]
        # Test log probs extraction
        model._accuracy.top_k = [1]
        results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
        assert len(results) == 2
        assert results[0].shape == torch.Size([len(model.cfg.labels)])
        # Test log probs extraction remains same for any top_k
        model._accuracy.top_k = [5]
        results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
        assert len(results) == 2
        assert results[0].shape == torch.Size([len(model.cfg.labels)])

    @pytest.mark.unit
    def test_EncDecClassificationDatasetConfig_for_AudioToSpeechLabelDataset(self):
        # The dataset class constructor and its dataclass config must stay in sync,
        # modulo runtime-only / generic arguments listed below.
        # ignore some additional arguments as dataclass is generic
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            # `featurizer` is supplied at runtime
            'featurizer',
            # additional ignored arguments
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
        ]
        # dataclass field `trim_silence` maps to constructor arg `trim`
        REMAP_ARGS = {'trim_silence': 'trim'}
        result = assert_dataclass_signature_match(
            audio_to_label.AudioToSpeechLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )
        signatures_match, cls_subset, dataclass_subset = result
        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None
class TestEncDecFrameClassificationModel(TestEncDecClassificationModel):
    """Unit tests for EncDecFrameClassificationModel (frame-level classification).

    NOTE: inheriting from TestEncDecClassificationModel re-runs all of the
    utterance-level tests against the same fixtures in this class's context.
    """

    @pytest.mark.parametrize(["logits_len", "labels_len"], [(20, 10), (21, 10), (19, 10), (20, 9), (20, 11)])
    @pytest.mark.unit
    def test_reshape_labels(self, frame_classification_model, logits_len, labels_len):
        # The parametrized ints size the padded tensors below; the same names are
        # then deliberately rebound to per-example length tensors before the call.
        model = frame_classification_model.eval()
        logits = torch.ones(4, logits_len, 2)
        labels = torch.ones(4, labels_len)
        logits_len = torch.tensor([6, 7, 8, 9])
        labels_len = torch.tensor([5, 6, 7, 8])
        labels_new, labels_len_new = model.reshape_labels(
            logits=logits, labels=labels, logits_len=logits_len, labels_len=labels_len
        )
        # Reshaped labels must align with the logits time axis, and the returned
        # lengths must match the logits lengths.
        assert labels_new.size(1) == logits.size(1)
        assert torch.equal(labels_len_new, torch.tensor([6, 7, 8, 9]))

    @pytest.mark.unit
    def test_EncDecClassificationDatasetConfig_for_AudioToMultiSpeechLabelDataset(self):
        # Same signature-sync check as the parent class, but against the
        # multi-label dataset, which has extra runtime-only arguments to ignore.
        # ignore some additional arguments as dataclass is generic
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            # `featurizer` is supplied at runtime
            'featurizer',
            # additional ignored arguments
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
            'delimiter',
            'normalize_audio_db',
            'normalize_audio_db_target',
            'window_length_in_sec',
            'shift_length_in_sec',
        ]
        # dataclass field `trim_silence` maps to constructor arg `trim`
        REMAP_ARGS = {'trim_silence': 'trim'}
        result = assert_dataclass_signature_match(
            audio_to_label.AudioToMultiLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )
        signatures_match, cls_subset, dataclass_subset = result
        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None

    @pytest.mark.unit
    def test_frame_classification_model(self, frame_classification_model: EncDecFrameClassificationModel):
        # End-to-end smoke test: synthesize a 1-second random waveform, write a
        # 4-sample manifest pointing at it, then run one training epoch.
        with tempfile.TemporaryDirectory() as temp_dir:
            # generate random audio
            audio = np.random.randn(16000 * 1)
            # save the audio
            audio_path = os.path.join(temp_dir, "audio.wav")
            sf.write(audio_path, audio, 16000)
            # frame-level label string: one label token per frame window
            dummy_labels = "0 0 0 0 1 1 1 1 0 0 0 0"
            dummy_sample = {
                "audio_filepath": audio_path,
                "offset": 0.0,
                "duration": 1.0,
                "label": dummy_labels,
            }
            # create a manifest file
            manifest_path = os.path.join(temp_dir, "dummy_manifest.json")
            with open(manifest_path, "w") as f:
                for i in range(4):
                    f.write(json.dumps(dummy_sample) + "\n")
            dataloader_cfg = {
                "batch_size": 2,
                "manifest_filepath": manifest_path,
                "sample_rate": 16000,
                "num_workers": 0,
                "shuffle": False,
                "labels": ["0", "1"],
            }
            trainer_cfg = {
                "max_epochs": 1,
                "devices": 1,
                "accelerator": "auto",
            }
            optim = {
                'name': 'sgd',
                'lr': 0.01,
                'weight_decay': 0.001,
                'momentum': 0.9,
            }
            trainer = pl.Trainer(**trainer_cfg)
            frame_classification_model.set_trainer(trainer)
            frame_classification_model.setup_optimization(DictConfig(optim))
            frame_classification_model.setup_training_data(dataloader_cfg)
            frame_classification_model.setup_validation_data(dataloader_cfg)
            trainer.fit(frame_classification_model)