# MagpieTTS_Internal_Demo/tests/utils/test_import_utils.py
# subhankarg: "Upload folder using huggingface_hub" (0558aa4, verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
from nemo.utils.import_utils import UnavailableError, UnavailableMeta, is_unavailable, safe_import, safe_import_from
class TestUnavailableMeta:
    """Checks for the placeholder classes produced by the UnavailableMeta metaclass."""

    def test_metaclass_creation(self):
        """A freshly built placeholder carries the prefixed name and default message."""
        placeholder = UnavailableMeta("TestClass", (), {})

        # The requested name is stored with a "MISSING" prefix.
        assert placeholder.__name__ == "MISSINGTestClass"
        # With no "_msg" entry supplied, a default message is filled in.
        assert placeholder._msg == "TestClass could not be imported"

    def test_custom_error_message(self):
        """A "_msg" entry in the namespace overrides the default message."""
        message = "Custom error message"
        placeholder = UnavailableMeta("TestClass", (), {"_msg": message})

        assert placeholder._msg == message

        # The same text must surface in the error raised on use.
        with pytest.raises(UnavailableError, match=message):
            placeholder()

    def test_call_raises_error(self):
        """Calling the placeholder raises UnavailableError, with or without arguments."""
        placeholder = UnavailableMeta("TestClass", (), {})

        call_cases = (
            ((), {}),
            ((1, 2, 3), {"key": "value"}),
        )
        for args, kwargs in call_cases:
            with pytest.raises(UnavailableError):
                placeholder(*args, **kwargs)

    def test_attribute_access_raises_error(self):
        """Reading an attribute that is not defined on the placeholder raises UnavailableError."""
        placeholder = UnavailableMeta("TestClass", (), {})

        with pytest.raises(UnavailableError):
            placeholder.some_attribute

    def test_arithmetic_operations_raise_error(self):
        """Every arithmetic operator, forward and reflected, raises UnavailableError."""
        placeholder = UnavailableMeta("TestClass", (), {})

        arithmetic_cases = (
            lambda obj: obj + 1,
            lambda obj: 1 + obj,  # __radd__
            lambda obj: obj - 1,
            lambda obj: 1 - obj,  # __rsub__
            lambda obj: obj * 2,
            lambda obj: 2 * obj,  # __rmul__
            lambda obj: obj / 2,
            lambda obj: 2 / obj,  # __rtruediv__
            lambda obj: obj // 2,
            lambda obj: 2 // obj,  # __rfloordiv__
            lambda obj: obj**2,
            lambda obj: 2**obj,  # __rpow__
            lambda obj: -obj,  # __neg__
            abs,  # __abs__
        )
        for apply in arithmetic_cases:
            with pytest.raises(UnavailableError):
                apply(placeholder)

    def test_comparison_operations_raise_error(self):
        """Comparing two placeholders with any rich comparison raises UnavailableError."""
        placeholder = UnavailableMeta("TestClass", (), {})
        other = UnavailableMeta("AnotherClass", (), {})

        comparison_cases = (
            lambda obj: obj == other,
            lambda obj: obj != other,
            lambda obj: obj < other,
            lambda obj: obj <= other,
            lambda obj: obj > other,
            lambda obj: obj >= other,
        )
        for compare in comparison_cases:
            with pytest.raises(UnavailableError):
                compare(placeholder)

    def test_container_operations_raise_error(self):
        """len(), indexing, item assignment, item deletion and iter() all raise UnavailableError."""
        placeholder = UnavailableMeta("TestClass", (), {})

        with pytest.raises(UnavailableError):
            len(placeholder)

        with pytest.raises(UnavailableError):
            placeholder[0]

        with pytest.raises(UnavailableError):
            placeholder[0] = 1

        with pytest.raises(UnavailableError):
            del placeholder[0]

        with pytest.raises(UnavailableError):
            iter(placeholder)

    def test_descriptor_operations_raise_error(self):
        """Direct calls to the descriptor hooks __get__ and __delete__ raise UnavailableError."""
        placeholder = UnavailableMeta("TestClass", (), {})

        # Defining and instantiating a class that stores the placeholder as a
        # class attribute happens outside pytest.raises, so it must not raise.
        class DummyClass:
            prop = placeholder

        dummy = DummyClass()

        with pytest.raises(UnavailableError):
            placeholder.__get__(None, None)

        with pytest.raises(UnavailableError):
            placeholder.__delete__(None)
class TestSafeImport:
    """Checks for safe_import, which returns a (module, success) pair."""

    def test_successful_import(self):
        """An importable module name yields the real module and a True flag."""
        imported, ok = safe_import("os")

        assert ok is True
        assert isinstance(imported, types.ModuleType)
        assert imported.__name__ == "os"

    def test_failed_import(self):
        """A missing module yields an UnavailableMeta placeholder and a False flag."""
        imported, ok = safe_import("nonexistent_module")

        assert ok is False
        assert is_unavailable(imported)
        assert type(imported) is UnavailableMeta

    def test_import_with_custom_message(self):
        """The msg argument becomes the text of the error raised by the placeholder."""
        message = "Custom error message"
        imported, ok = safe_import("nonexistent_module", msg=message)

        assert ok is False
        assert is_unavailable(imported)

        # Using the placeholder must report the caller-supplied message.
        with pytest.raises(UnavailableError, match=message):
            imported()

    def test_import_with_alternative(self):
        """When alt is given, that exact object is returned on failure."""
        substitute = object()
        imported, ok = safe_import("nonexistent_module", alt=substitute)

        assert ok is False
        assert imported is substitute

    def test_unavailable_module_raises_error_when_used(self):
        """The placeholder returned for a missing module raises UnavailableError on use."""
        imported, ok = safe_import("nonexistent_module")
        assert ok is False

        # Call, attribute access, arithmetic and comparison must each raise.
        with pytest.raises(UnavailableError):
            imported()

        with pytest.raises(UnavailableError):
            imported.attribute

        with pytest.raises(UnavailableError):
            imported + 1

        with pytest.raises(UnavailableError):
            imported == 1
class TestSafeImportFrom:
    """Checks for safe_import_from, which returns a (symbol, success) pair."""

    def test_successful_import_from(self):
        """An existing symbol of an existing module is returned as-is with a True flag."""
        import os

        found, ok = safe_import_from("os", "path")

        assert ok is True
        assert found is os.path

    def test_failed_import_from_nonexistent_module(self):
        """A missing module yields an unavailable placeholder and a False flag."""
        found, ok = safe_import_from("nonexistent_module", "nonexistent_symbol")

        assert ok is False
        assert is_unavailable(found)

    def test_failed_import_from_nonexistent_symbol(self):
        """A missing symbol in an existing module yields a placeholder and a False flag."""
        found, ok = safe_import_from("os", "nonexistent_symbol")

        assert ok is False
        assert is_unavailable(found)

    def test_import_from_with_custom_message(self):
        """The msg argument becomes the text of the error raised by the placeholder."""
        message = "Custom error message for symbol"
        found, ok = safe_import_from("os", "nonexistent_symbol", msg=message)

        assert ok is False

        # Using the placeholder must report the caller-supplied message.
        with pytest.raises(UnavailableError, match=message):
            found()

    def test_import_from_with_alternative(self):
        """When alt is given, that exact object is returned on failure."""
        substitute = object()
        found, ok = safe_import_from("os", "nonexistent_symbol", alt=substitute)

        assert ok is False
        assert found is substitute

    def test_fallback_module(self):
        """The symbol is taken from fallback_module when the primary lookup fails."""

        def fake_import(name):
            # The primary module fails with AttributeError; the fallback
            # resolves to a stub exposing the requested symbol.
            if name == "primary_module":
                raise AttributeError("Symbol not found")
            if name == "fallback_module":
                stub = MagicMock()
                stub.symbol = "fallback_symbol"
                return stub
            raise ImportError(f"Unexpected module: {name}")

        with patch('importlib.import_module') as mock_import:
            mock_import.side_effect = fake_import

            found, ok = safe_import_from("primary_module", "symbol", fallback_module="fallback_module")

            assert ok is True
            assert found == "fallback_symbol"

    def test_fallback_module_both_fail(self):
        """When primary and fallback modules are both missing, a placeholder is returned."""
        found, ok = safe_import_from("nonexistent_primary", "symbol", fallback_module="nonexistent_fallback")

        assert ok is False
        assert is_unavailable(found)