# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard library
import copy
import json
import os
import tempfile

# Third-party
import lightning.pytorch as pl
import numpy as np
import pytest
import soundfile as sf
import torch
from omegaconf import DictConfig, ListConfig

# Project-local (NeMo)
from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel, configs
from nemo.utils.config_utils import assert_dataclass_signature_match
@pytest.fixture()
def speech_classification_model():
    """Pytest fixture: a small EncDecClassificationModel for unit tests.

    Builds a single-block separable ConvASREncoder (64 mel features in,
    32 filters out) with a 30-way classification decoder and dummy labels.

    Fix: the test methods below receive this by fixture-argument name
    (e.g. ``test_constructor(self, speech_classification_model)``), so it
    must be registered with ``@pytest.fixture()`` — without the decorator
    pytest raises "fixture 'speech_classification_model' not found".
    """
    preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})}
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 32,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }
    decoder = {
        'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
        'params': {
            'feat_in': 32,
            'num_classes': 30,
        },
    }
    modelConfig = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'labels': ListConfig(["dummy_cls_{}".format(i + 1) for i in range(30)]),
        }
    )
    model = EncDecClassificationModel(cfg=modelConfig)
    return model
@pytest.fixture()
def frame_classification_model():
    """Pytest fixture: a small EncDecFrameClassificationModel for unit tests.

    Same single-block encoder as ``speech_classification_model``, but with an
    MLP decoder (5 classes), an SGD optimizer config, and binary labels.

    Fix: consumed by fixture-argument name in the tests below
    (e.g. ``test_reshape_labels(self, frame_classification_model, ...)``),
    so it must be registered with ``@pytest.fixture()``.
    """
    preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})}
    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': 32,
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }
    decoder = {
        'cls': 'nemo.collections.common.parts.MultiLayerPerceptron',
        'params': {
            'hidden_size': 32,
            'num_classes': 5,
        },
    }
    optim = {
        'name': 'sgd',
        'lr': 0.01,
        'weight_decay': 0.001,
        'momentum': 0.9,
    }
    modelConfig = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'optim': DictConfig(optim),
            'labels': ListConfig(["0", "1"]),
        }
    )
    model = EncDecFrameClassificationModel(cfg=modelConfig)
    return model
class TestEncDecClassificationModel:
    """Unit tests for EncDecClassificationModel (utterance-level audio classification).

    NOTE(review): ``speech_classification_model`` and ``test_data_dir`` are consumed
    as pytest fixtures — confirm both are registered in this file/conftest.
    """

    def test_constructor(self, speech_classification_model):
        """Model reports the expected weight count and round-trips through its config dict."""
        asr_model = speech_classification_model.train()

        # Hand-computed parameter count for the fixture's 1-block separable conv encoder
        # (feat_in=64, filters=32, kernel=[1]) plus the 30-class decoder.
        conv_cnt = (64 * 32 * 1 + 32) + (64 * 1 * 1 + 32)  # separable kernel + bias + pointwise kernel + bias
        bn_cnt = (4 * 32) * 2  # 2 * moving averages
        dec_cnt = 32 * 30 + 30  # fc + bias

        param_count = conv_cnt + bn_cnt + dec_cnt
        assert asr_model.num_weights == param_count

        # Check to/from config_dict:
        confdict = asr_model.to_config_dict()
        instance2 = EncDecClassificationModel.from_config_dict(confdict)
        assert isinstance(instance2, EncDecClassificationModel)

    def test_forward(self, speech_classification_model):
        """Batched forward must match per-sample forward (no cross-sample interaction)."""
        asr_model = speech_classification_model.eval()

        # Disable dither and padding so single-sample and batched passes are comparable.
        asr_model.preprocessor.featurizer.dither = 0.0
        asr_model.preprocessor.featurizer.pad_to = 0

        input_signal = torch.randn(size=(4, 512))
        length = torch.randint(low=321, high=500, size=[4])

        with torch.no_grad():
            # batch size 1
            logprobs_instance = []
            for i in range(input_signal.size(0)):
                logprobs_ins = asr_model.forward(
                    input_signal=input_signal[i : i + 1], input_signal_length=length[i : i + 1]
                )
                logprobs_instance.append(logprobs_ins)
            logprobs_instance = torch.cat(logprobs_instance, 0)

            # batch size 4
            logprobs_batch = asr_model.forward(input_signal=input_signal, input_signal_length=length)

        assert logprobs_instance.shape == logprobs_batch.shape
        diff = torch.mean(torch.abs(logprobs_instance - logprobs_batch))
        assert diff <= 1e-6
        diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
        assert diff <= 1e-6

    def test_vocab_change(self, speech_classification_model):
        """change_labels keeps weights for an identical label set and grows the decoder for new labels."""
        asr_model = speech_classification_model.train()

        old_labels = copy.deepcopy(asr_model._cfg.labels)
        nw1 = asr_model.num_weights
        asr_model.change_labels(new_labels=old_labels)
        # No change
        assert nw1 == asr_model.num_weights

        new_labels = copy.deepcopy(old_labels)
        new_labels.append('dummy_cls_31')
        new_labels.append('dummy_cls_32')
        new_labels.append('dummy_cls_33')
        asr_model.change_labels(new_labels=new_labels)

        # fully connected + bias
        assert asr_model.num_weights == nw1 + 3 * (asr_model.decoder._feat_in + 1)

    def test_transcription(self, speech_classification_model, test_data_dir):
        """transcribe() output shapes for top-1, top-5, combined top-k, and log-prob modes."""
        # Ground truth labels = ["yes", "no"]
        audio_filenames = ['an22-flrp-b.wav', 'an90-fbbh-b.wav']
        audio_paths = [os.path.join(test_data_dir, "asr", "train", "an4", "wav", fp) for fp in audio_filenames]

        model = speech_classification_model.eval()

        # Test Top 1 classification transcription
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([1])

        # Test Top 5 classification transcription
        model._accuracy.top_k = [5]  # set top k to 5 for accuracy calculation
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([5])

        # Test Top 1 and Top 5 classification transcription
        model._accuracy.top_k = [1, 5]
        results = model.transcribe(audio_paths, batch_size=2)
        assert len(results) == 2
        assert results[0].shape == torch.Size([2, 1])
        assert results[1].shape == torch.Size([2, 5])
        assert model._accuracy.top_k == [1, 5]

        # Test log probs extraction
        model._accuracy.top_k = [1]
        results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
        assert len(results) == 2
        assert results[0].shape == torch.Size([len(model.cfg.labels)])

        # Test log probs extraction remains same for any top_k
        model._accuracy.top_k = [5]
        results = model.transcribe(audio_paths, batch_size=2, logprobs=True)
        assert len(results) == 2
        assert results[0].shape == torch.Size([len(model.cfg.labels)])

    def test_EncDecClassificationDatasetConfig_for_AudioToSpeechLabelDataset(self):
        """Dataset class signature stays in sync with the EncDecClassificationDatasetConfig dataclass."""
        # ignore some additional arguments as dataclass is generic
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            # `featurizer` is supplied at runtime
            'featurizer',
            # additional ignored arguments
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
        ]

        # dataclass field `trim_silence` maps to dataset ctor arg `trim`
        REMAP_ARGS = {'trim_silence': 'trim'}

        result = assert_dataclass_signature_match(
            audio_to_label.AudioToSpeechLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )
        signatures_match, cls_subset, dataclass_subset = result

        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None
class TestEncDecFrameClassificationModel(TestEncDecClassificationModel):
    """Unit tests for EncDecFrameClassificationModel (frame-level classification).

    Subclasses TestEncDecClassificationModel, so the inherited utterance-level
    tests also run against this class's overrides where defined.
    """

    def test_reshape_labels(self, frame_classification_model, logits_len, labels_len):
        """reshape_labels aligns the labels' time dimension with the logits' time dimension.

        NOTE(review): the ``logits_len`` / ``labels_len`` parameters appear to come
        from a pytest parametrize/fixture not visible in this chunk — confirm. They
        only size the dummy tensors and are rebound to per-sample lengths below.
        """
        model = frame_classification_model.eval()

        logits = torch.ones(4, logits_len, 2)
        labels = torch.ones(4, labels_len)
        # Per-sample valid lengths actually passed to reshape_labels.
        logits_len = torch.tensor([6, 7, 8, 9])
        labels_len = torch.tensor([5, 6, 7, 8])

        labels_new, labels_len_new = model.reshape_labels(
            logits=logits, labels=labels, logits_len=logits_len, labels_len=labels_len
        )
        assert labels_new.size(1) == logits.size(1)
        # New label lengths are expected to match the logits lengths.
        assert torch.equal(labels_len_new, torch.tensor([6, 7, 8, 9]))

    def test_EncDecClassificationDatasetConfig_for_AudioToMultiSpeechLabelDataset(self):
        """Multi-label dataset signature stays in sync with EncDecClassificationDatasetConfig."""
        # ignore some additional arguments as dataclass is generic
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            # `featurizer` is supplied at runtime
            'featurizer',
            # additional ignored arguments
            'vad_stream',
            'int_values',
            'sample_rate',
            'normalize_audio',
            'augmentor',
            'bucketing_batch_size',
            'bucketing_strategy',
            'bucketing_weights',
            'delimiter',
            'normalize_audio_db',
            'normalize_audio_db_target',
            'window_length_in_sec',
            'shift_length_in_sec',
        ]

        # dataclass field `trim_silence` maps to dataset ctor arg `trim`
        REMAP_ARGS = {'trim_silence': 'trim'}

        result = assert_dataclass_signature_match(
            audio_to_label.AudioToMultiLabelDataset,
            configs.EncDecClassificationDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )
        signatures_match, cls_subset, dataclass_subset = result

        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None

    def test_frame_classification_model(self, frame_classification_model: EncDecFrameClassificationModel):
        """End-to-end smoke test: build a tiny dataset on disk and run one training epoch."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # generate random audio
            audio = np.random.randn(16000 * 1)

            # save the audio
            audio_path = os.path.join(temp_dir, "audio.wav")
            sf.write(audio_path, audio, 16000)

            # Frame-level labels: one token per frame of the 1-second clip.
            dummy_labels = "0 0 0 0 1 1 1 1 0 0 0 0"
            dummy_sample = {
                "audio_filepath": audio_path,
                "offset": 0.0,
                "duration": 1.0,
                "label": dummy_labels,
            }

            # create a manifest file (4 identical samples, one JSON object per line)
            manifest_path = os.path.join(temp_dir, "dummy_manifest.json")
            with open(manifest_path, "w") as f:
                for i in range(4):
                    f.write(json.dumps(dummy_sample) + "\n")

            dataloader_cfg = {
                "batch_size": 2,
                "manifest_filepath": manifest_path,
                "sample_rate": 16000,
                "num_workers": 0,
                "shuffle": False,
                "labels": ["0", "1"],
            }

            trainer_cfg = {
                "max_epochs": 1,
                "devices": 1,
                "accelerator": "auto",
            }

            # NOTE(review): duplicates the fixture's optim config; setup_optimization
            # below re-applies it explicitly for this trainer run.
            optim = {
                'name': 'sgd',
                'lr': 0.01,
                'weight_decay': 0.001,
                'momentum': 0.9,
            }

            trainer = pl.Trainer(**trainer_cfg)
            frame_classification_model.set_trainer(trainer)
            frame_classification_model.setup_optimization(DictConfig(optim))
            frame_classification_model.setup_training_data(dataloader_cfg)
            frame_classification_model.setup_validation_data(dataloader_cfg)

            trainer.fit(frame_classification_model)