# subhankarg's picture
# Upload folder using huggingface_hub
# 0558aa4 verified
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Optional
import numpy as np
import pytest
import torch
from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures
from nemo.collections.audio.modules.masking import (
MaskBasedDereverbWPE,
MaskEstimatorFlexChannels,
MaskEstimatorGSS,
MaskReferenceChannel,
)
from nemo.collections.audio.modules.ssl_pretrain_masking import SSLPretrainWithMaskedPatch
from nemo.collections.audio.modules.transforms import AudioToSpectrogram
from nemo.collections.audio.parts.submodules.multichannel import WPEFilter
from nemo.collections.audio.parts.utils.audio import convmtx_mc_numpy
from nemo.utils import logging
class TestSpectrogramToMultichannelFeatures:
    """Unit tests for the SpectrogramToMultichannelFeatures module."""

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [128])
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('mag_reduction', [None, 'rms', 'abs_mean', 'mean_abs'])
    @pytest.mark.parametrize('mag_power', [None, 2])
    @pytest.mark.parametrize('mag_normalization', [None, 'mean', 'mean_var'])
    def test_magnitude(
        self,
        fft_length: int,
        num_channels: int,
        mag_reduction: Optional[str],
        mag_power: Optional[float],
        mag_normalization: Optional[str],
    ):
        """Test calculation of spatial features for multi-channel audio."""
        atol = 5e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction=mag_reduction,
            mag_power=mag_power,
            mag_normalization=mag_normalization,
            use_ipd=False,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            # convert to spectrogram
            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # UUT output
            feat, _ = spec2feat(input=spec, input_length=spec_len)
            feat_np = feat.cpu().detach().numpy()

            # Golden output: compute the same reduction with numpy
            spec_np = spec.cpu().detach().numpy()
            if mag_reduction is None:
                feat_golden = np.abs(spec_np)
            elif mag_reduction == 'rms':
                feat_golden = np.sqrt(np.mean(np.abs(spec_np) ** 2, axis=1, keepdims=True))
            elif mag_reduction == 'mean_abs':
                feat_golden = np.mean(np.abs(spec_np), axis=1, keepdims=True)
            elif mag_reduction == 'abs_mean':
                feat_golden = np.abs(np.mean(spec_np, axis=1, keepdims=True))
            else:
                raise NotImplementedError(f'Magnitude reduction {mag_reduction} not implemented')

            if mag_power is not None:
                feat_golden = np.power(feat_golden, mag_power)

            # normalization statistics are computed over channel and time axes
            if mag_normalization == 'mean':
                feat_golden = feat_golden - np.mean(feat_golden, axis=(1, 3), keepdims=True)
            elif mag_normalization == 'mean_var':
                feat_golden = feat_golden - np.mean(feat_golden, axis=(1, 3), keepdims=True)
                feat_golden = feat_golden / np.sqrt(np.mean(feat_golden**2, axis=(1, 3), keepdims=True))

            # Compare shape
            assert feat_np.shape == feat_golden.shape, f'Feature shape not matching for example {n}'
            # Compare values
            assert np.allclose(feat_np, feat_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [128])
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('ipd_normalization', [None, 'mean', 'mean_var'])
    @pytest.mark.parametrize('use_input_length', [True, False])
    def test_ipd(self, fft_length: int, num_channels: int, ipd_normalization: Optional[str], use_input_length: bool):
        """Test calculation of IPD spatial features for multi-channel audio."""
        atol = 5e-5
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10
        random_seed = 42

        _rng = np.random.default_rng(seed=random_seed)

        hop_length = fft_length // 4
        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=hop_length)

        spec2feat = SpectrogramToMultichannelFeatures(
            num_subbands=audio2spec.num_subbands,
            mag_reduction='rms',
            use_ipd=True,
            mag_normalization=None,
            ipd_normalization=ipd_normalization,
        )

        for n in range(num_examples):
            x = _rng.normal(size=(batch_size, num_channels, num_samples))

            spec, spec_len = audio2spec(input=torch.Tensor(x), input_length=torch.Tensor([num_samples] * batch_size))

            # UUT output
            feat, _ = spec2feat(input=spec, input_length=spec_len if use_input_length else None)
            feat_np = feat.cpu().detach().numpy()
            # IPD features are stacked after the magnitude subbands
            ipd = feat_np[..., audio2spec.num_subbands :, :]

            # Golden output: phase difference relative to the channel mean, wrapped to [-pi, pi)
            spec_np = spec.cpu().detach().numpy()
            spec_mean = np.mean(spec_np, axis=1, keepdims=True)
            ipd_golden = np.angle(spec_np) - np.angle(spec_mean)
            ipd_golden = np.remainder(ipd_golden + np.pi, 2 * np.pi) - np.pi

            if ipd_normalization == 'mean':
                ipd_golden = ipd_golden - np.mean(ipd_golden, axis=(1, 3), keepdims=True)
            elif ipd_normalization == 'mean_var':
                ipd_golden = ipd_golden - np.mean(ipd_golden, axis=(1, 3), keepdims=True)
                ipd_golden = ipd_golden / np.sqrt(
                    np.maximum(np.mean(ipd_golden**2, axis=(1, 3), keepdims=True), spec2feat.eps)
                )

            # Compare shape
            assert ipd.shape == ipd_golden.shape, f'Feature shape not matching for example {n}'
            # Compare values
            assert np.allclose(ipd, ipd_golden, atol=atol), f'Features not matching for example {n}'

    @pytest.mark.unit
    @pytest.mark.parametrize('use_ipd', [False, True])
    def test_num_channels(self, use_ipd: bool):
        """Test num channels property."""
        uut = SpectrogramToMultichannelFeatures(num_subbands=32, use_ipd=use_ipd)
        with pytest.raises(ValueError):
            # num_input_channels is not set, so the property must raise
            uut.num_channels

        for num_channels in [1, 2, 3, 4]:
            # num_input_channels is set
            uut = SpectrogramToMultichannelFeatures(num_subbands=32, num_input_channels=num_channels, use_ipd=use_ipd)
            assert uut.num_channels == num_channels

        for num_channels in [1, 2, 3, 4]:
            # num_input_channels is set, but magnitude will be reduced
            uut = SpectrogramToMultichannelFeatures(
                num_subbands=32, num_input_channels=num_channels, use_ipd=use_ipd, mag_reduction='rms'
            )
            if use_ipd:
                # IPD keeps per-channel features, so the channel count is preserved
                assert uut.num_channels == num_channels
            else:
                # magnitude reduction collapses all channels into one
                assert uut.num_channels == 1

    @pytest.mark.unit
    @pytest.mark.parametrize('use_ipd', [False, True])
    def test_num_features(self, use_ipd: bool):
        """Test num features property."""
        for num_subbands in [5, 10]:
            uut = SpectrogramToMultichannelFeatures(num_subbands=num_subbands, use_ipd=use_ipd)
            # NOTE: parentheses are required. The previous unparenthesized form
            # `assert a == b if use_ipd else c` parsed as `assert ((a == b) if use_ipd else c)`,
            # which degenerated to `assert num_subbands` (always true) when use_ipd was False.
            assert uut.num_features == (2 * num_subbands if use_ipd else num_subbands)

    @pytest.mark.unit
    def test_unsupported_norm(self):
        """Test initialization with unsupported normalization."""
        # test magnitude normalization
        with pytest.raises(NotImplementedError):
            SpectrogramToMultichannelFeatures(
                num_subbands=32,
                mag_reduction='rms',
                use_ipd=False,
                mag_normalization='not-implemented',
            )

        # test phase normalization
        with pytest.raises(NotImplementedError):
            SpectrogramToMultichannelFeatures(
                num_subbands=32,
                use_ipd=True,
                ipd_normalization='not-implemented',
            )

        # test magnitude reduction: the unsupported value is detected at call time
        uut = SpectrogramToMultichannelFeatures(
            num_subbands=32,
            mag_reduction='not-implemented',
        )
        spec = torch.randn(1, 3, 100, 100)  # renamed to avoid shadowing the builtin `input`
        with pytest.raises(ValueError):
            uut(input=spec, input_length=torch.Tensor([100]))
class TestMaskBasedProcessor:
    """Unit tests for mask-based processing of multi-channel spectrograms."""

    @pytest.mark.unit
    @pytest.mark.parametrize('fft_length', [256])
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_masks', [1, 2])
    def test_mask_reference_channel(self, fft_length: int, num_channels: int, num_masks: int):
        """Test masking of the reference channel."""
        # Single-channel signals have only channel 0 available; for
        # multi-channel signals, exercise both the first and the last channel.
        ref_channels = [0] if num_channels == 1 else [0, num_channels - 1]

        atol = 1e-6
        batch_size = 8
        num_samples = fft_length * 50
        num_examples = 10

        rng = np.random.default_rng(seed=42)

        audio2spec = AudioToSpectrogram(fft_length=fft_length, hop_length=fft_length // 4)

        for ref_channel in ref_channels:
            processor = MaskReferenceChannel(ref_channel=ref_channel)

            for n in range(num_examples):
                audio = rng.normal(size=(batch_size, num_channels, num_samples))
                spec, spec_len = audio2spec(
                    input=torch.Tensor(audio), input_length=torch.Tensor([num_samples] * batch_size)
                )

                # Randomly-generated mask
                mask = rng.uniform(
                    low=0.0, high=1.0, size=(batch_size, num_masks, audio2spec.num_subbands, spec.shape[-1])
                )

                # Output of the unit under test
                out, _ = processor(input=spec, input_length=spec_len, mask=torch.tensor(mask))
                out_np = out.cpu().detach().numpy()

                # Golden reference: the reference channel scaled by each mask
                spec_np = spec.cpu().detach().numpy()
                out_golden = np.zeros_like(mask, dtype=spec_np.dtype)
                for m in range(num_masks):
                    out_golden[:, m, ...] = spec_np[:, ref_channel, ...] * mask[:, m, ...]

                # Shapes must agree
                assert out_np.shape == out_golden.shape, f'Output shape not matching for example {n}'
                # Values must agree
                assert np.allclose(out_np, out_golden, atol=atol), f'Output not matching for example {n}'
class TestMaskBasedDereverb:
    """Unit tests for WPE-based dereverberation components."""

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_convtensor(self, num_channels: int, filter_length: int, delay: int):
        """Test construction of convolutional tensor in WPE. Compare against
        reference implementation convmtx_mc.
        """
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21

        _rng = np.random.default_rng(seed=random_seed)

        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            # complex-valued random input spectrogram
            X = _rng.normal(size=input_size) + 1j * _rng.normal(size=input_size)

            # Reference: per-(batch, subband) convolution matrix built with numpy
            tilde_X_ref = np.zeros((batch_size, num_subbands, num_frames, num_channels * filter_length), dtype=X.dtype)
            for b in range(batch_size):
                for f in range(num_subbands):
                    tilde_X_ref[b, f, :, :] = convmtx_mc_numpy(
                        X[b, :, f, :].transpose(), filter_length=filter_length, delay=delay
                    )

            # UUT
            tilde_X_uut = WPEFilter.convtensor(torch.tensor(X), filter_length=filter_length, delay=delay)

            # UUT has vectors arranged in a tensor shape with permuted columns
            # Reorganize to match the shape and column permutation
            tilde_X_uut = WPEFilter.permute_convtensor(tilde_X_uut)
            tilde_X_uut = tilde_X_uut.cpu().detach().numpy()

            assert np.allclose(tilde_X_uut, tilde_X_ref, atol=atol), f'Example {n}: comparison failed'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 3])
    @pytest.mark.parametrize('filter_length', [10])
    @pytest.mark.parametrize('delay', [0, 5])
    def test_wpe_filter(self, num_channels: int, filter_length: int, delay: int):
        """Test estimation of correlation matrices, filter and filtering."""
        atol = 1e-6
        random_seed = 42
        num_examples = 10
        batch_size = 4
        num_subbands = 15
        num_frames = 50

        wpe_filter = WPEFilter(filter_length=filter_length, prediction_delay=delay, diag_reg=None)

        _rng = np.random.default_rng(seed=random_seed)
        input_size = (batch_size, num_channels, num_subbands, num_frames)

        for n in range(num_examples):
            X = torch.tensor(_rng.normal(size=input_size) + 1j * _rng.normal(size=input_size))
            weight = torch.tensor(_rng.uniform(size=(batch_size, num_subbands, num_frames)))

            # Create convtensor (B, C, F, N, filter_length)
            tilde_X = wpe_filter.convtensor(X, filter_length=filter_length, delay=delay)

            # Test 1:
            # estimate_correlation

            # Reference
            # move channels to back
            X_golden = X.permute(0, 2, 3, 1)
            # move channels to back and reshape to (B, F, N, C*filter_length)
            tilde_X_golden = tilde_X.permute(0, 2, 3, 1, 4).reshape(
                batch_size, num_subbands, num_frames, num_channels * filter_length
            )
            # (B, F, C * filter_length, C * filter_length)
            Q_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * tilde_X_golden)
            # (B, F, C * filter_length, C)
            R_golden = torch.matmul(tilde_X_golden.transpose(-1, -2).conj(), weight[..., None] * X_golden)

            # UUT
            Q_uut, R_uut = wpe_filter.estimate_correlations(input=X, weight=weight, tilde_input=tilde_X)
            # Flatten (B, F, C, filter_length, C, filter_length) into (B, F, C*filter_length, C*filter_length)
            Q_uut_flattened = Q_uut.flatten(start_dim=-2, end_dim=-1).flatten(start_dim=-3, end_dim=-2)
            # Flatten (B, F, C, filter_length, C, filter_length) into (B, F, C*filter_length, C*filter_length)
            R_uut_flattened = R_uut.flatten(start_dim=-3, end_dim=-2)

            assert torch.allclose(Q_uut_flattened, Q_golden, atol=atol), f'Example {n}: comparison failed for Q'
            assert torch.allclose(R_uut_flattened, R_golden, atol=atol), f'Example {n}: comparison failed for R'

            # Test 2:
            # estimate_filter

            # Reference
            G_golden = torch.linalg.solve(Q_golden, R_golden)

            # UUT
            G_uut = wpe_filter.estimate_filter(Q_uut, R_uut)
            # Flatten and move output channels to back
            G_uut_flattened = G_uut.reshape(batch_size, num_channels, num_subbands, -1).permute(0, 2, 3, 1)

            assert torch.allclose(G_uut_flattened, G_golden, atol=atol), f'Example {n}: comparison failed for G'

            # Test 3:
            # apply_filter

            # Reference
            U_golden = torch.matmul(tilde_X_golden, G_golden)

            # UUT
            U_uut = wpe_filter.apply_filter(filter=G_uut, tilde_input=tilde_X)
            U_uut_ref = U_uut.permute(0, 2, 3, 1)

            assert torch.allclose(
                U_uut_ref, U_golden, atol=atol
            ), f'Example {n}: comparison failed for undesired output U'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [3])
    @pytest.mark.parametrize('filter_length', [5])
    @pytest.mark.parametrize('delay', [0, 2])
    def test_mask_based_dereverb_init(self, num_channels: int, filter_length: int, delay: int):
        """Test that dereverb can be initialized and can process audio."""
        num_examples = 10
        batch_size = 8
        num_subbands = 15
        num_frames = 21
        num_iterations = 2

        input_size = (batch_size, num_subbands, num_frames, num_channels)

        dereverb = MaskBasedDereverbWPE(
            filter_length=filter_length, prediction_delay=delay, num_iterations=num_iterations
        )

        for n in range(num_examples):
            # multi-channel input
            x = torch.randn(input_size) + 1j * torch.randn(input_size)
            # random input_length
            x_length = torch.randint(1, num_frames, (batch_size,))
            # multi-channel mask
            mask = torch.rand(input_size)

            # UUT
            y, y_length = dereverb(input=x, input_length=x_length, mask=mask)

            # Fix: these messages were plain strings missing the `f` prefix,
            # so `{n}` was never interpolated into the failure message.
            assert y.shape == x.shape, f'Output shape not matching, example {n}'
            assert torch.equal(y_length, x_length), f'Length not matching, example {n}'
class TestMaskEstimator:
    """Unit tests for mask estimator modules."""

    @pytest.mark.unit
    @pytest.mark.parametrize('channel_reduction_position', [0, 1, -1])
    @pytest.mark.parametrize('channel_reduction_type', ['average', 'attention'])
    @pytest.mark.parametrize('channel_block_type', ['transform_average_concatenate', 'transform_attend_concatenate'])
    def test_flex_channels(
        self, channel_reduction_position: int, channel_reduction_type: str, channel_block_type: str
    ):
        """Verify the flex-channel mask estimator can be instantiated and
        processes inputs with different channel counts.
        """
        # Input configuration
        batch_size = 4
        num_frames = 50

        # Sweep over several model configurations
        for num_subbands in (32, 65):
            for num_outputs in (1, 2):
                for num_blocks in (1, 5):
                    logging.debug(
                        'Instantiate with num_subbands=%d, num_outputs=%d, num_blocks=%d',
                        num_subbands,
                        num_outputs,
                        num_blocks,
                    )

                    # Instantiate the unit under test
                    uut = MaskEstimatorFlexChannels(
                        num_outputs=num_outputs,
                        num_subbands=num_subbands,
                        num_blocks=num_blocks,
                        channel_reduction_position=channel_reduction_position,
                        channel_reduction_type=channel_reduction_type,
                        channel_block_type=channel_block_type,
                    )

                    # The same model must handle different numbers of input channels
                    for num_channels in (1, 4):
                        logging.debug('Process num_channels=%d', num_channels)

                        # random multi-channel spectrogram with per-example lengths
                        spec = torch.randn(batch_size, num_channels, num_subbands, num_frames, dtype=torch.cfloat)
                        spec_length = torch.randint(1, num_frames, (batch_size,))

                        # Run the unit under test
                        mask, mask_length = uut(input=spec, input_length=spec_length)

                        # Output shape must match the expected mask shape
                        expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
                        assert (
                            mask.shape == expected_mask_shape
                        ), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'

                        # Lengths must be passed through unchanged
                        assert torch.all(
                            mask_length == spec_length
                        ), f'Output length mismatch: expected {spec_length}, got {mask_length}'

    @pytest.mark.unit
    @pytest.mark.parametrize('num_channels', [1, 4])
    @pytest.mark.parametrize('num_subbands', [32, 65])
    @pytest.mark.parametrize('num_outputs', [2, 3])
    @pytest.mark.parametrize('batch_size', [1, 4])
    def test_gss(self, num_channels: int, num_subbands: int, num_outputs: int, batch_size: int):
        """Verify the GSS mask estimator can be instantiated and produces an
        output of the expected shape. Output correctness is not checked here.
        """
        # Test vector length
        num_frames = 50

        # Instantiate the unit under test
        uut = MaskEstimatorGSS()

        logging.debug('Process num_channels=%d', num_channels)
        input_size = (batch_size, num_channels, num_subbands, num_frames)
        logging.debug('Input size: %s', input_size)

        # random multi-channel mixture and binary source activity pattern
        mixture_spec = torch.randn(input_size, dtype=torch.cfloat)
        source_activity = torch.randn(batch_size, num_outputs, num_frames) > 0

        # Run the unit under test
        mask = uut(input=mixture_spec, activity=source_activity)

        # Output shape must match the expected mask shape
        expected_mask_shape = (batch_size, num_outputs, num_subbands, num_frames)
        assert (
            mask.shape == expected_mask_shape
        ), f'Output shape mismatch: expected {expected_mask_shape}, got {mask.shape}'
class TestSSLPretrainMaskingWithPatch:
    """Unit tests for SSL pretraining masking with patches."""

    @pytest.mark.unit
    @pytest.mark.parametrize('patch_size', [1, 5, 10])
    @pytest.mark.parametrize('mask_fraction', [0.5, 1.0])
    @pytest.mark.parametrize('training', [True, False])
    def test_masking(self, patch_size: int, mask_fraction: float, training: bool):
        """Verify the masked fraction of the output matches the configured fraction."""
        # Input configuration
        batch_size = 8
        num_channels = 1
        num_subbands = 32
        num_frames = 5000
        abs_tol = 1e-2

        # Instantiate the unit under test and set the requested mode
        uut = SSLPretrainWithMaskedPatch(patch_size=patch_size, mask_fraction=mask_fraction)
        if training:
            uut.train()
        else:
            uut.eval()

        # Reproducible random input spectrogram and per-example lengths
        rng = torch.Generator()
        rng.manual_seed(0)
        input_spec = torch.randn(batch_size, num_channels, num_subbands, num_frames, dtype=torch.cfloat, generator=rng)
        input_length = torch.randint(num_frames // 2, num_frames, (batch_size,), generator=rng)

        # Zero out the part of each example beyond its valid length
        for b in range(batch_size):
            input_spec[b, :, :, input_length[b] :] = 0.0

        # Apply masking
        masked_spec = uut(input_spec=input_spec, length=input_length)

        # Shape is preserved
        assert masked_spec.shape == input_spec.shape

        # For every example, the fraction of zeroed bins inside the valid
        # region should be close to the requested mask fraction
        for b in range(batch_size):
            est_mask_fraction = torch.sum(masked_spec[b, :, :, : input_length[b]].abs() == 0.0) / (
                num_channels * num_subbands * input_length[b]
            )
            assert (
                abs(est_mask_fraction - mask_fraction) < abs_tol
            ), f'Example {b}: est_mask_fraction = {est_mask_fraction}, mask_fraction = {mask_fraction}'

    @pytest.mark.unit
    def test_unsupported_initialization(self):
        """Verify invalid constructor arguments are rejected."""
        # Each configuration below is out of the valid range and must raise
        invalid_configs = (
            {'patch_size': 0},
            {'patch_size': -1},
            {'mask_fraction': 1.1},
            {'mask_fraction': -0.1},
        )
        for bad_kwargs in invalid_configs:
            with pytest.raises(ValueError):
                SSLPretrainWithMaskedPatch(**bad_kwargs)