Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import pytest | |
| import torch | |
| from nemo.collections.audio.modules.projections import MixtureConsistencyProjection | |
class TestMixtureConsistencyProjection:
    """Tests for ``MixtureConsistencyProjection``.

    Covers three things: agreement of the module output with a reference
    computation written out in the test, rejection of unsupported
    ``weighting`` values, and rejection of inputs with unsupported shapes.
    """

    # Without parametrization pytest treats ``weighting`` and ``num_sources``
    # as fixture names and errors out at setup, so the values are supplied
    # here. ``None`` and ``'power'`` are the two weightings the reference
    # computation below knows how to reproduce.
    @pytest.mark.parametrize('weighting', [None, 'power'])
    @pytest.mark.parametrize('num_sources', [1, 3])
    def test_mixture_consistency(self, weighting: str, num_sources: int):
        """Compare the projection output against a reference computation.

        Args:
            weighting: Weighting scheme passed to the projection; either
                ``None`` (uniform ``1 / num_sources``) or ``'power'``.
            num_sources: Number of source estimates in dimension 1.

        Raises:
            ValueError: If ``weighting`` is a value this reference
                computation does not cover.
        """
        batch_size = 4
        num_subbands = 33
        num_samples = 100
        num_examples = 8
        atol = 1e-5

        # Dedicated, seeded generator keeps the random inputs reproducible.
        rng = torch.Generator()
        rng.manual_seed(42)

        # create projection
        uut = MixtureConsistencyProjection(weighting=weighting)

        for _ in range(num_examples):
            # single-channel mixture
            mixture = torch.randn(batch_size, 1, num_subbands, num_samples, generator=rng, dtype=torch.cfloat)
            # source estimates
            estimate = torch.randn(
                batch_size, num_sources, num_subbands, num_samples, generator=rng, dtype=torch.cfloat
            )

            # project
            uut_projected = uut(mixture=mixture, estimate=estimate)

            # estimated mixture: sum of the estimates over the source dimension
            estimated_mixture = torch.sum(estimate, dim=1, keepdim=True)

            if weighting is None:
                # uniform weights: every source gets an equal share of the residual
                weight = 1 / num_sources
            elif weighting == 'power':
                # weights proportional to the per-source power, normalized over
                # sources; the module's own eps is reused in the denominator
                weight = estimate.abs().pow(2)
                weight = weight / (weight.sum(dim=1, keepdim=True) + uut.eps)
            else:
                raise ValueError(f'Weighting {weighting} not implemented')

            # distribute the mixture residual across the estimates
            correction = weight * (mixture - estimated_mixture)
            ref_projected = estimate + correction

            # check consistency
            assert torch.allclose(
                uut_projected, ref_projected, atol=atol
            ), f'Projection mismatch for weighting={weighting}, num_sources={num_sources}'

    def test_unsupported_weighting(self):
        """Check that an unknown weighting raises ``NotImplementedError``.

        The error is expected both at construction time and when the
        attribute is changed to an unsupported value after construction.
        """
        # Initialize with unsupported weighting
        with pytest.raises(NotImplementedError):
            MixtureConsistencyProjection(weighting='not-implemented')

        # Initialize with None and change later
        uut = MixtureConsistencyProjection(weighting=None)
        uut.weighting = 'not-implemented'
        with pytest.raises(NotImplementedError):
            uut(
                mixture=torch.randn(1, 1, 1, 1, dtype=torch.cfloat),
                estimate=torch.randn(1, 1, 1, 1, dtype=torch.cfloat),
            )

    def test_unsupported_inputs(self):
        """Check that inputs with unsupported shapes are rejected."""
        # Multi-channel mixtures are not supported
        uut = MixtureConsistencyProjection(weighting=None)
        with pytest.raises(ValueError):
            uut(
                mixture=torch.randn(1, 2, 1, 1, dtype=torch.cfloat),
                estimate=torch.randn(1, 2, 1, 1, dtype=torch.cfloat),
            )

        # Consistency projection is applied in the time-frequency domain
        # It is expected that the mixture has a single channel, and shape (B, 1, F, N)
        with pytest.raises(TypeError):
            uut(mixture=torch.randn(1, 2, 1), estimate=torch.randn(1, 2, 1))

        # It is expected that the estimate has shape (B, num_sources, F, N)
        with pytest.raises(TypeError):
            uut(mixture=torch.randn(1, 1, 1, 1, dtype=torch.cfloat), estimate=torch.randn(1, 2, 1))