# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from typing import Optional

import numpy as np
import pytest
import torch

from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures
from nemo.collections.audio.modules.masking import (
    MaskBasedDereverbWPE,
    MaskEstimatorFlexChannels,
    MaskEstimatorGSS,
    MaskReferenceChannel,
)
from nemo.collections.audio.modules.ssl_pretrain_masking import SSLPretrainWithMaskedPatch
from nemo.collections.audio.modules.transforms import AudioToSpectrogram
from nemo.collections.audio.parts.submodules.multichannel import WPEFilter
from nemo.collections.audio.parts.utils.audio import convmtx_mc_numpy
from nemo.utils import logging


class TestSpectrogramToMultichannelFeatures:
    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [128])
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('mag_reduction', [None, 'rms', 'abs_mean', 'mean_abs'])
    @pytest.mark.parametrize('mag_power', [None, 2])
    @pytest.mark.parametrize('mag_normalization', [None, 'mean', 'mean_var'])
    def test_magnitude(
        self,
        fft_length: int,
        num_channels: int,
        mag_reduction: Optional[str],
        mag_power: Optional[float],
        mag_normalization: Optional[str],
    ):
        """Test calculation of spatial features for multi-channel audio."""
        atol = 5e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction=mag_reduction,
            mag_power=mag_power,
            mag_normalization=mag_normalization,
            use_ipd=False,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            # convert to spectrogram
            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # UUT output
            feat, _ = spec2feat(input=spec, input_length=spec_len)
            feat_np = feat.cpu().detach().numpy()

            # Golden output
            spec_np = spec.cpu().detach().numpy()
            if mag_reduction is None:
                feat_golden = np.abs(spec_np)
            elif mag_reduction == 'rms':
                feat_golden = np.sqrt(np.mean(np.abs(spec_np) ** 2, axis=1, keepdims=True))
            elif mag_reduction == 'mean_abs':
                feat_golden = np.mean(np.abs(spec_np), axis=1, keepdims=True)
            elif mag_reduction == 'abs_mean':
                feat_golden = np.abs(np.mean(spec_np, axis=1, keepdims=True))
            else:
                raise NotImplementedError(f'Magnitude reduction {mag_reduction} not implemented')

            if mag_power is not None:
                feat_golden = np.power(feat_golden, mag_power)

            if mag_normalization == 'mean':
                feat_golden = feat_golden - np.mean(feat_golden, axis=(1, 3), keepdims=True)
            elif mag_normalization == 'mean_var':
                feat_golden = feat_golden - np.mean(feat_golden, axis=(1, 3), keepdims=True)
                feat_golden = feat_golden / np.sqrt(np.mean(feat_golden**2, axis=(1, 3), keepdims=True))

            # Compare shape
            assert feat_np.shape == feat_golden.shape, f'Feature shape not matching for example {n}'

            # Compare values
            assert np.allclose(feat_np, feat_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [128])
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('ipd_normalization', [None, 'mean', 'mean_var'])
    @pytest.mark.parametrize('use_input_length', [True, False])
    def test_ipd(self, fft_length: int, num_channels: int, ipd_normalization: Optional[str], use_input_length: bool):
        """Test calculation of IPD spatial features for multi-channel audio."""
        atol = 5e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction='rms',
            use_ipd=True,
            mag_normalization=None,
            ipd_normalization=ipd_normalization,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # UUT output
            feat, _ = spec2feat(input=spec, input_length=spec_len if use_input_length else None)
            feat_np = feat.cpu().detach().numpy()
            ipd = feat_np[..., audio2spec.num_subbands :, :]

            # Golden output
            spec_np = spec.cpu().detach().numpy()
            spec_mean = np.mean(spec_np, axis=1, keepdims=True)
            ipd_golden = np.angle(spec_np) - np.angle(spec_mean)
            ipd_golden = np.remainder(ipd_golden + np.pi, 2 * np.pi) - np.pi

            if ipd_normalization == 'mean':
                ipd_golden = ipd_golden - np.mean(ipd_golden, axis=(1, 3), keepdims=True)
            elif ipd_normalization == 'mean_var':
                ipd_golden = ipd_golden - np.mean(ipd_golden, axis=(1, 3), keepdims=True)
                ipd_golden = ipd_golden / np.sqrt(
                    np.maximum(np.mean(ipd_golden**2, axis=(1, 3), keepdims=True), spec2feat.eps)
                )

            # Compare shape
            assert ipd.shape == ipd_golden.shape, f'Feature shape not matching for example {n}'

            # Compare values
            assert np.allclose(ipd, ipd_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.parametrize('use_ipd', [False, True])
    def test_num_channels(self, use_ipd: bool):
        """Test num channels property."""
        uut = SpectrogramToMultichannelFeatures(num_subbands=32, use_ipd=use_ipd)
        with pytest.raises(ValueError):
            # num_input_channels is not set
            uut.num_channels

        for num_channels in [1, 2, 3, 4]:
            # num_input_channels is set
            uut = SpectrogramToMultichannelFeatures(num_subbands=32, num_input_channels=num_channels, use_ipd=use_ipd)
            assert uut.num_channels == num_channels

        for num_channels in [1, 2, 3, 4]:
            # num_input_channels is set, but magnitude will be reduced
            uut = SpectrogramToMultichannelFeatures(
                num_subbands=32, num_input_channels=num_channels, use_ipd=use_ipd, mag_reduction='rms'
            )
            if use_ipd:
                assert uut.num_channels == num_channels
            else:
                assert uut.num_channels == 1

    @pytest.mark.unit
    @pytest.mark.parametrize('use_ipd', [False, True])
    def test_num_features(self, use_ipd: bool):
        """Test num features property."""
        for num_subbands in [5, 10]:
            uut = SpectrogramToMultichannelFeatures(num_subbands=num_subbands, use_ipd=use_ipd)
            # NOTE: the conditional expression must be parenthesized. The unparenthesized form
            # `assert a == b if c else d` parses as `assert ((a == b) if c else d)`, so with
            # use_ipd=False the assert evaluated the truthy integer `num_subbands` and always passed.
            assert uut.num_features == (2 * num_subbands if use_ipd else num_subbands)

    @pytest.mark.unit
    def test_unsupported_norm(self):
        """Test initialization with unsupported normalization."""
        # test magnitude normalization
        with pytest.raises(NotImplementedError):
            SpectrogramToMultichannelFeatures(
                num_subbands=32,
                mag_reduction='rms',
                use_ipd=False,
                mag_normalization='not-implemented',
            )

        # test phase normalization
        with pytest.raises(NotImplementedError):
            SpectrogramToMultichannelFeatures(
                num_subbands=32,
                use_ipd=True,
                ipd_normalization='not-implemented',
            )

        # test magnitude reduction
        uut = SpectrogramToMultichannelFeatures(
            num_subbands=32,
            mag_reduction='not-implemented',
        )
        input = torch.randn(1, 3, 100, 100)
        with pytest.raises(ValueError):
            uut(input=input, input_length=torch.Tensor([100]))


class TestMaskBasedProcessor:
    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [256])
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_masks', [1, 2])
    def test_mask_reference_channel(self, fft_length: int, num_channels: int, num_masks: int):
        """Test masking of the reference channel."""
        if num_channels == 1:
            # Only one channel available
            ref_channels = [0]
        else:
            # Use first or last channel for MC signals
            ref_channels = [0, num_channels - 1]

        atol = 1e-6
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        for ref_channel in ref_channels:
            mask_processor = MaskReferenceChannel(ref_channel=ref_channel)

            for n in range(num_examples):
                x = _rng.normal(size=(batch_size, num_channels, num_samples))

                spec, spec_len = audio2spec(
                    input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size)
                )

                # Randomly-generated mask
                mask = _rng.uniform(
                    low=0.0, high=1.0, size=(batch_size, num_masks, audio2spec.num_subbands, spec.shape[-1])
                )

                # UUT output
                out, _ = mask_processor(input=spec, input_length=spec_len, mask=torch.tensor(mask))
                out_np = out.cpu().detach().numpy()

                # Golden output
                spec_np = spec.cpu().detach().numpy()
                out_golden = np.zeros_like(mask, dtype=spec_np.dtype)
                for m in range(num_masks):
                    out_golden[:, m, ...] = spec_np[:, ref_channel, ...] * mask[:, m, ...]

                # Compare shape
                assert out_np.shape == out_golden.shape, f'Output shape not matching for example {n}'

                # Compare values
                assert np.allclose(out_np, out_golden, atol=atol), f'Output not matching for example {n}'


class TestMaskBasedDereverb:
    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_convtensor(self, num_channels: int, filter_length: int, delay: int):
        """Test construction of convolutional tensor in WPE. Compare against
        reference implementation convmtx_mc.
        """
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21

        _rng = np.random.default_rng(seed=random_seed)

        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            X = _rng.normal(size=input_size) + 1j * _rng.normal(size=input_size)

            # Reference
            tilde_X_ref = np.zeros((batch_size, num_subbands, num_frames, num_channels * filter_length), dtype=X.dtype)
            for b in range(batch_size):
                for f in range(num_subbands):
                    tilde_X_ref[b, f, :, :] = convmtx_mc_numpy(
                        X[b, :, f, :].transpose(), filter_length=filter_length, delay=delay
                    )

            # UUT
            tilde_X_uut = WPEFilter.convtensor(torch.tensor(X), filter_length=filter_length, delay=delay)

            # UUT has vectors arranged in a tensor shape with permuted columns
            # Reorganize to match the shape and column permutation
            tilde_X_uut = WPEFilter.permute_convtensor(tilde_X_uut)
            tilde_X_uut = tilde_X_uut.cpu().detach().numpy()

            assert np.allclose(tilde_X_uut, tilde_X_ref, atol=atol), f'Example {n}: comparison failed'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_filter(self, num_channels: int, filter_length: int, delay: int):
        """Test estimation of correlation matrices, filter and filtering."""
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 4
        num_subbands = 15
        num_frames = 50

        wpe_filter = WPEFilter(filter_length=filter_length, prediction_delay=delay, diag_reg=None)

        _rng = np.random.default_rng(seed=random_seed)

        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            X = torch.tensor(_rng.normal(size=input_size) + 1j * _rng.normal(size=input_size))
            weight = torch.tensor(_rng.uniform(size=(batch_size, num_subbands, num_frames)))

            # Create convtensor (B, C, F, N, filter_length)
            tilde_X = wpe_filter.convtensor(X, filter_length=filter_length, delay=delay)

            # Test 1:
            # estimate_correlation

            # Reference
            # move channels to back
            X_golden = X.permute(0, 2, 3, 1)
            # move channels to back and reshape to (B, F, N, C*filter_length)
            tilde_X_golden = tilde_X.permute(0, 2, 3, 1, 4).reshape(
                batch_size, num_subbands, num_frames, num_channels * filter_length
            )
            # (B, F, C * filter_length, C * filter_length)
            Q_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * tilde_X_golden)
            # (B, F, C * filter_length, C)
            R_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * X_golden)

            # UUT
            Q_uut, R_uut = wpe_filter.estimate_correlations(input=X, weight=weight, tilde_input=tilde_X)

            # Flatten (B, F, C, filter_length, C, filter_length) into (B, F, C*filter_length, C*filter_length)
            Q_uut_flattened = Q_uut.flatten(start_dim=-2, end_dim=-1).flatten(start_dim=-3, end_dim=-2)
            # Flatten (B, F, C, filter_length, C, filter_length) into (B, F, C*filter_length, C*filter_length)
            R_uut_flattened = R_uut.flatten(start_dim=-3, end_dim=-2)

            assert torch.allclose(Q_uut_flattened, Q_golden, atol=atol), f'Example {n}: comparison failed for Q'
            assert torch.allclose(R_uut_flattened, R_golden, atol=atol), f'Example {n}: comparison failed for R'

            # Test 2:
            # estimate_filter

            # Reference
            G_golden = torch.linalg.solve(Q_golden, R_golden)

            # UUT
            G_uut = wpe_filter.estimate_filter(Q_uut, R_uut)

            # Flatten and move output channels to back
            G_uut_flattened = G_uut.reshape(batch_size, num_channels, num_subbands, -1).permute(0, 2, 3, 1)

            assert torch.allclose(G_uut_flattened, G_golden, atol=atol), f'Example {n}: comparison failed for G'

            # Test 3:
            # apply_filter

            # Reference
            U_golden = torch.matmul(tilde_X_golden, G_golden)

            # UUT
            U_uut = wpe_filter.apply_filter(filter=G_uut, tilde_input=tilde_X)
            U_uut_ref = U_uut.permute(0, 2, 3, 1)

            assert torch.allclose(
                U_uut_ref, U_golden, atol=atol
            ), f'Example {n}: comparison failed for undesired output U'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [3])
    @pytest.mark.parametrize('filter_length', [5])
    @pytest.mark.parametrize('delay', [0, 2])
    def test_mask_based_dereverb_init(self, num_channels: int, filter_length: int, delay: int):
        """Test that dereverb can be initialized and can process audio."""
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21
        num_iterations = 2

        input_size = (batch_size, num_subbands, num_frames, num_channels)

        dereverb = MaskBasedDereverbWPE(
            filter_length=filter_length, prediction_delay=delay, num_iterations=num_iterations
        )

        for n in range(num_examples):
            # multi-channel input
            x = torch.randn(input_size) + 1j * torch.randn(input_size)
            # random input_length
            x_length = torch.randint(1, num_frames, (batch_size,))
            # multi-channel mask
            mask = torch.rand(input_size)

            # UUT
            y, y_length = dereverb(input=x, input_length=x_length, mask=mask)

            # FIX: these messages were plain strings with a literal '{n}'; they must be
            # f-strings so the failing example index is actually interpolated.
            assert y.shape == x.shape, f'Output shape not matching, example {n}'
            assert torch.equal(y_length, x_length), f'Length not matching, example {n}'


class TestMaskEstimator:
    @pytest.mark.unit
    @pytest.mark.parametrize('channel_reduction_position', [0, 1, -1])
    @pytest.mark.parametrize('channel_reduction_type', ['average', 'attention'])
    @pytest.mark.parametrize('channel_block_type', ['transform_average_concatenate', 'transform_attend_concatenate'])
    def test_flex_channels(
        self, channel_reduction_position: int, channel_reduction_type: str, channel_block_type: str
    ):
        """Test initialization of the mask estimator and make sure it can process input tensor."""
        # Model parameters
        num_subbands_tests = [32, 65]
        num_outputs_tests = [1, 2]
        num_blocks_tests = [1, 5]

        # Input configuration
        num_channels_tests = [1, 4]
        batch_size = 4
        num_frames = 50

        for num_subbands in num_subbands_tests:
            for num_outputs in num_outputs_tests:
                for num_blocks in num_blocks_tests:
                    logging.debug(
                        'Instantiate with num_subbands=%d, num_outputs=%d, num_blocks=%d',
                        num_subbands,
                        num_outputs,
                        num_blocks,
                    )

                    # Instantiate
                    uut = MaskEstimatorFlexChannels(
                        num_outputs=num_outputs,
                        num_subbands=num_subbands,
                        num_blocks=num_blocks,
                        channel_reduction_position=channel_reduction_position,
                        channel_reduction_type=channel_reduction_type,
                        channel_block_type=channel_block_type,
                    )

                    # Process different channel configurations
                    for num_channels in num_channels_tests:
                        logging.debug('Process num_channels=%d', num_channels)
                        input_size = (batch_size, num_channels, num_subbands, num_frames)

                        # multi-channel input
                        spec = torch.randn(input_size, dtype=torch.cfloat)
                        spec_length = torch.randint(1, num_frames, (batch_size,))

                        # UUT
                        mask, mask_length = uut(input=spec, input_length=spec_length)

                        # Check output dimensions match
                        expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
                        assert (
                            mask.shape == expected_mask_shape
                        ), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'

                        # Check output lengths match
                        assert torch.all(
                            mask_length == spec_length
                        ), f'Output length mismatch: expected {spec_length}, got {mask_length}'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_subbands', [32, 65])
    @pytest.mark.parametrize('num_outputs', [2, 3])
    @pytest.mark.parametrize('batch_size', [1, 4])
    def test_gss(self, num_channels: int, num_subbands: int, num_outputs: int, batch_size: int):
        """Test initialization of the GSS mask estimator and make sure it can process an input tensor.

        This tests initialization and the output shape. It does not test correctness of the output.
        """
        # Test vector length
        num_frames = 50

        # Instantiate UUT
        uut = MaskEstimatorGSS()

        # Process the current configuration
        logging.debug('Process num_channels=%d', num_channels)
        input_size = (batch_size, num_channels, num_subbands, num_frames)
        logging.debug('Input size: %s', input_size)

        # multi-channel input
        mixture_spec = torch.randn(input_size, dtype=torch.cfloat)
        source_activity = torch.randn(batch_size, num_outputs, num_frames) > 0

        # UUT
        mask = uut(input=mixture_spec, activity=source_activity)

        # Check output dimensions match
        expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
        assert (
            mask.shape == expected_mask_shape
        ), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'


class TestSSLPretrainMaskingWithPatch:
    @pytest.mark.unit
    @pytest.mark.parametrize('patch_size', [1, 5, 10])
    @pytest.mark.parametrize('mask_fraction', [0.5, 1.0])
    @pytest.mark.parametrize('training', [True, False])
    def test_masking(self, patch_size: int, mask_fraction: float, training: bool):
        """Test SSL pretrain masking."""
        num_subbands = 32
        num_frames = 5000
        num_channels = 1
        batch_size = 8
        abs_tol = 1e-2

        # Instantiate
        uut = SSLPretrainWithMaskedPatch(patch_size=patch_size, mask_fraction=mask_fraction)

        # Set training mode
        if training:
            uut.train()
        else:
            uut.eval()

        # Generate random input spec and length
        rng = torch.Generator()
        rng.manual_seed(0)
        input_spec = torch.randn(batch_size, num_channels, num_subbands, num_frames, dtype=torch.cfloat, generator=rng)
        input_length = torch.randint(num_frames // 2, num_frames, (batch_size,), generator=rng)
        for b in range(batch_size):
            input_spec[b, :, :, input_length[b] :] = 0.0

        # Apply masking
        masked_spec = uut(input_spec=input_spec, length=input_length)

        # Check output dimensions match
        assert masked_spec.shape == input_spec.shape

        # Check output values are masked for each example in the batch
        for b in range(batch_size):
            # Estimate mask fraction
            est_mask_fraction = torch.sum(masked_spec[b, :, :, : input_length[b]].abs() == 0.0) / (
                num_channels * num_subbands * input_length[b]
            )

            # Check if the estimated mask fraction is close to the expected mask fraction
            assert (
                abs(est_mask_fraction - mask_fraction) < abs_tol
            ), f'Example {b}: est_mask_fraction = {est_mask_fraction}, mask_fraction = {mask_fraction}'

    @pytest.mark.unit
    def test_unsupported_initialization(self):
        """Test SSL pretrain masking."""
        with pytest.raises(ValueError):
            SSLPretrainWithMaskedPatch(patch_size=0)

        with pytest.raises(ValueError):
            SSLPretrainWithMaskedPatch(patch_size=-1)

        with pytest.raises(ValueError):
            SSLPretrainWithMaskedPatch(mask_fraction=1.1)

        with pytest.raises(ValueError):
            SSLPretrainWithMaskedPatch(mask_fraction=-0.1)