MagpieTTS_Internal_Demo / tests /collections /audio /test_audio_modules_projections.py
subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from nemo.collections.audio.modules.projections import MixtureConsistencyProjection
class TestMixtureConsistencyProjection:
@pytest.mark.unit
@pytest.mark.parametrize('weighting', [None, 'power'])
@pytest.mark.parametrize('num_sources', [1, 3])
def test_mixture_consistency(self, weighting: str, num_sources: int):
batch_size = 4
num_subbands = 33
num_samples = 100
num_examples = 8
atol = 1e-5
rng = torch.Generator()
rng.manual_seed(42)
# create projection
uut = MixtureConsistencyProjection(weighting=weighting)
for n in range(num_examples):
# single-channel mixture
mixture = torch.randn(batch_size, 1, num_subbands, num_samples, generator=rng, dtype=torch.cfloat)
# source estimates
estimate = torch.randn(
batch_size, num_sources, num_subbands, num_samples, generator=rng, dtype=torch.cfloat
)
# project
uut_projected = uut(mixture=mixture, estimate=estimate)
# estimated mixture
estimated_mixture = torch.sum(estimate, dim=1, keepdim=True)
if weighting is None:
weight = 1 / num_sources
elif weighting == 'power':
weight = estimate.abs().pow(2)
weight = weight / (weight.sum(dim=1, keepdim=True) + uut.eps)
else:
raise ValueError(f'Weighting {weighting} not implemented')
correction = weight * (mixture - estimated_mixture)
ref_projected = estimate + correction
# check consistency
assert torch.allclose(uut_projected, ref_projected, atol=atol)
@pytest.mark.unit
def test_unsupported_weighting(self):
# Initialize with unsupported weighting
with pytest.raises(NotImplementedError):
MixtureConsistencyProjection(weighting='not-implemented')
# Initialize with None and change later
uut = MixtureConsistencyProjection(weighting=None)
uut.weighting = 'not-implemented'
with pytest.raises(NotImplementedError):
uut(
mixture=torch.randn(1, 1, 1, 1, dtype=torch.cfloat),
estimate=torch.randn(1, 1, 1, 1, dtype=torch.cfloat),
)
@pytest.mark.unit
def test_unsupported_inputs(self):
# Multi-channel mixtures are not supported
uut = MixtureConsistencyProjection(weighting=None)
with pytest.raises(ValueError):
uut(
mixture=torch.randn(1, 2, 1, 1, dtype=torch.cfloat),
estimate=torch.randn(1, 2, 1, 1, dtype=torch.cfloat),
)
# Consistency projection is applied in the time-frequency domain
# It is expected that the mixture has a single channel, and shape (B, 1, F, N)
with pytest.raises(TypeError):
uut(mixture=torch.randn(1, 2, 1), estimate=torch.randn(1, 2, 1))
# It is expected that the estimate has shape (B, num_sources, F, N)
with pytest.raises(TypeError):
uut(mixture=torch.randn(1, 1, 1, 1, dtype=torch.cfloat), estimate=torch.randn(1, 2, 1))