Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import sys | |
| import types | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from nemo.utils.import_utils import UnavailableError, UnavailableMeta, is_unavailable, safe_import, safe_import_from | |
class TestUnavailableMeta:
    """Test suite for the UnavailableMeta metaclass."""

    def test_metaclass_creation(self):
        """Test that UnavailableMeta creates a class with the expected properties."""
        # Create a class using UnavailableMeta
        TestClass = UnavailableMeta("TestClass", (), {})
        # The class name should be prefixed with "MISSING"
        assert TestClass.__name__ == "MISSINGTestClass"
        # The default error message should be set
        assert TestClass._msg == "TestClass could not be imported"

    def test_custom_error_message(self):
        """Test that a custom error message can be provided."""
        custom_msg = "Custom error message"
        TestClass = UnavailableMeta("TestClass", (), {"_msg": custom_msg})
        assert TestClass._msg == custom_msg
        # Verify the message is used in exceptions.
        # NOTE: ``match`` is a regex search; custom_msg contains no metacharacters.
        with pytest.raises(UnavailableError, match=custom_msg):
            TestClass()

    def test_call_raises_error(self):
        """Test that attempting to instantiate the class raises UnavailableError."""
        TestClass = UnavailableMeta("TestClass", (), {})
        with pytest.raises(UnavailableError):
            TestClass()
        with pytest.raises(UnavailableError):
            TestClass(1, 2, 3, key="value")

    def test_attribute_access_raises_error(self):
        """Test that accessing attributes raises UnavailableError."""
        TestClass = UnavailableMeta("TestClass", (), {})
        with pytest.raises(UnavailableError):
            TestClass.some_attribute

    def test_arithmetic_operations_raise_error(self):
        """Test that arithmetic operations raise UnavailableError."""
        TestClass = UnavailableMeta("TestClass", (), {})
        operations = [
            lambda c: c + 1,
            lambda c: 1 + c,  # __radd__
            lambda c: c - 1,
            lambda c: 1 - c,  # __rsub__
            lambda c: c * 2,
            lambda c: 2 * c,  # __rmul__
            lambda c: c / 2,
            lambda c: 2 / c,  # __rtruediv__
            lambda c: c // 2,
            lambda c: 2 // c,  # __rfloordiv__
            lambda c: c**2,
            lambda c: 2**c,  # __rpow__
            lambda c: -c,  # __neg__
            lambda c: abs(c),  # __abs__
        ]
        for op in operations:
            with pytest.raises(UnavailableError):
                op(TestClass)

    def test_comparison_operations_raise_error(self):
        """Test that comparison operations raise UnavailableError."""
        TestClass = UnavailableMeta("TestClass", (), {})
        another_class = UnavailableMeta("AnotherClass", (), {})
        comparisons = [
            lambda c: c == another_class,
            lambda c: c != another_class,
            lambda c: c < another_class,
            lambda c: c <= another_class,
            lambda c: c > another_class,
            lambda c: c >= another_class,
        ]
        for comp in comparisons:
            with pytest.raises(UnavailableError):
                comp(TestClass)

    def test_container_operations_raise_error(self):
        """Test that container operations raise UnavailableError."""
        TestClass = UnavailableMeta("TestClass", (), {})
        with pytest.raises(UnavailableError):
            len(TestClass)
        with pytest.raises(UnavailableError):
            TestClass[0]
        with pytest.raises(UnavailableError):
            TestClass[0] = 1
        with pytest.raises(UnavailableError):
            del TestClass[0]
        with pytest.raises(UnavailableError):
            iter(TestClass)

    def test_descriptor_operations_raise_error(self):
        """Test that descriptor operations raise UnavailableError.

        Explicit ``TestClass.__get__`` / ``TestClass.__delete__`` lookups may
        raise through the generic attribute-access path alone. The placeholder
        is therefore also installed as a class attribute, and read through an
        instance, to exercise the real descriptor protocol.
        """
        TestClass = UnavailableMeta("TestClass", (), {})

        class DummyClass:
            prop = TestClass

        dummy = DummyClass()
        # Instance attribute lookup invokes type(TestClass).__get__ implicitly;
        # previously ``dummy`` was created but never used.
        with pytest.raises(UnavailableError):
            dummy.prop
        with pytest.raises(UnavailableError):
            TestClass.__get__(None, None)
        with pytest.raises(UnavailableError):
            TestClass.__delete__(None)
class TestSafeImport:
    """Tests for ``safe_import``: real modules, placeholders, messages and alternatives."""

    def test_successful_import(self):
        """Test safe_import with a module that exists."""
        imported, ok = safe_import("os")
        assert ok is True
        assert isinstance(imported, types.ModuleType)
        assert imported.__name__ == "os"

    def test_failed_import(self):
        """Test safe_import with a module that doesn't exist."""
        placeholder, ok = safe_import("nonexistent_module")
        assert ok is False
        assert is_unavailable(placeholder)
        # The placeholder is itself a class whose exact metaclass is UnavailableMeta.
        assert type(placeholder) is UnavailableMeta

    def test_import_with_custom_message(self):
        """Test safe_import with a custom error message."""
        message = "Custom error message"
        placeholder, ok = safe_import("nonexistent_module", msg=message)
        assert ok is False
        assert is_unavailable(placeholder)
        # Calling the placeholder must surface the caller-supplied message.
        with pytest.raises(UnavailableError, match=message):
            placeholder()

    def test_import_with_alternative(self):
        """Test safe_import with an alternative module."""
        sentinel = object()
        result, ok = safe_import("nonexistent_module", alt=sentinel)
        assert ok is False
        # On failure the ``alt`` object is handed back unchanged.
        assert result is sentinel

    def test_unavailable_module_raises_error_when_used(self):
        """Test that using a UnavailableMeta placeholder raises UnavailableError."""
        placeholder, ok = safe_import("nonexistent_module")
        assert ok is False
        # Call, attribute access, arithmetic and comparison must each be blocked.
        usages = (
            lambda m: m(),
            lambda m: m.attribute,
            lambda m: m + 1,
            lambda m: m == 1,
        )
        for use in usages:
            with pytest.raises(UnavailableError):
                use(placeholder)
class TestSafeImportFrom:
    """Tests for ``safe_import_from``, including the ``fallback_module`` path."""

    def test_successful_import_from(self):
        """Test safe_import_from with a symbol that exists."""
        import os

        symbol, success = safe_import_from("os", "path")
        assert success is True
        assert symbol is os.path

    def test_failed_import_from_nonexistent_module(self):
        """Test safe_import_from with a module that doesn't exist."""
        symbol, success = safe_import_from("nonexistent_module", "nonexistent_symbol")
        assert success is False
        assert is_unavailable(symbol)

    def test_failed_import_from_nonexistent_symbol(self):
        """Test safe_import_from with a symbol that doesn't exist in an existing module."""
        symbol, success = safe_import_from("os", "nonexistent_symbol")
        assert success is False
        assert is_unavailable(symbol)

    def test_import_from_with_custom_message(self):
        """Test safe_import_from with a custom error message."""
        custom_msg = "Custom error message for symbol"
        symbol, success = safe_import_from("os", "nonexistent_symbol", msg=custom_msg)
        assert success is False
        # Verify the custom message is used when trying to use the symbol
        with pytest.raises(UnavailableError, match=custom_msg):
            symbol()

    def test_import_from_with_alternative(self):
        """Test safe_import_from with an alternative symbol."""
        alt_symbol = object()
        symbol, success = safe_import_from("os", "nonexistent_symbol", alt=alt_symbol)
        assert success is False
        assert symbol is alt_symbol

    def test_fallback_module(self):
        """Test safe_import_from with a fallback module.

        The primary module imports successfully but lacks the requested
        symbol, so the lookup fails and the fallback module supplies it.
        """
        # A real, empty module object: getattr(primary, "symbol") raises
        # AttributeError exactly as a genuinely missing symbol would, instead
        # of faking that error from inside import_module itself.
        primary = types.ModuleType("primary_module")
        fallback = MagicMock()
        fallback.symbol = "fallback_symbol"
        fake_modules = {"primary_module": primary, "fallback_module": fallback}

        def fake_import(name):
            try:
                return fake_modules[name]
            except KeyError:
                raise ImportError(f"Unexpected module: {name}") from None

        with patch("importlib.import_module", side_effect=fake_import):
            symbol, success = safe_import_from("primary_module", "symbol", fallback_module="fallback_module")

        assert success is True
        # Only the fallback module defines ``symbol``, so this value proves it was used.
        assert symbol == "fallback_symbol"

    def test_fallback_module_both_fail(self):
        """Test safe_import_from when both primary and fallback modules fail."""
        # Case 1: the primary module itself cannot be imported.
        symbol, success = safe_import_from("nonexistent_primary", "symbol", fallback_module="nonexistent_fallback")
        assert success is False
        assert is_unavailable(symbol)

        # Case 2: the primary module imports but lacks the symbol, which is the
        # condition that routes to the fallback (see test_fallback_module). The
        # fallback module cannot be imported either, so the result is a placeholder.
        symbol, success = safe_import_from("os", "nonexistent_symbol", fallback_module="nonexistent_fallback")
        assert success is False
        assert is_unavailable(symbol)