# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import types from unittest.mock import MagicMock, patch import pytest from nemo.utils.import_utils import UnavailableError, UnavailableMeta, is_unavailable, safe_import, safe_import_from class TestUnavailableMeta: """Test suite for the UnavailableMeta metaclass.""" def test_metaclass_creation(self): """Test that UnavailableMeta creates a class with the expected properties.""" # Create a class using UnavailableMeta TestClass = UnavailableMeta("TestClass", (), {}) # The class name should be prefixed with "MISSING" assert TestClass.__name__ == "MISSINGTestClass" # The default error message should be set assert TestClass._msg == "TestClass could not be imported" def test_custom_error_message(self): """Test that a custom error message can be provided.""" custom_msg = "Custom error message" TestClass = UnavailableMeta("TestClass", (), {"_msg": custom_msg}) assert TestClass._msg == custom_msg # Verify the message is used in exceptions with pytest.raises(UnavailableError, match=custom_msg): TestClass() def test_call_raises_error(self): """Test that attempting to instantiate the class raises UnavailableError.""" TestClass = UnavailableMeta("TestClass", (), {}) with pytest.raises(UnavailableError): TestClass() with pytest.raises(UnavailableError): TestClass(1, 2, 3, key="value") def test_attribute_access_raises_error(self): """Test that accessing attributes raises UnavailableError.""" TestClass = UnavailableMeta("TestClass", (), {}) with pytest.raises(UnavailableError): TestClass.some_attribute def test_arithmetic_operations_raise_error(self): """Test that arithmetic operations raise UnavailableError.""" TestClass = UnavailableMeta("TestClass", (), {}) operations = [ lambda c: c + 1, lambda c: 1 + c, # __radd__ lambda c: c - 1, lambda c: 1 - c, # __rsub__ lambda c: c * 2, lambda c: 2 * c, # __rmul__ lambda c: c / 2, lambda c: 2 / c, # __rtruediv__ lambda c: c // 2, lambda c: 2 // c, # __rfloordiv__ lambda c: c**2, lambda c: 2**c, # __rpow__ lambda c: -c, # __neg__ lambda c: abs(c), # __abs__ ] for op in operations: with pytest.raises(UnavailableError): op(TestClass) def test_comparison_operations_raise_error(self): """Test that comparison operations raise UnavailableError.""" TestClass = UnavailableMeta("TestClass", (), {}) another_class = UnavailableMeta("AnotherClass", (), {}) comparisons = [ lambda c: c == another_class, lambda c: c != another_class, lambda c: c < another_class, lambda c: c <= another_class, lambda c: c > another_class, lambda c: c >= another_class, ] for comp in comparisons: with pytest.raises(UnavailableError): comp(TestClass) def test_container_operations_raise_error(self): """Test that container operations raise UnavailableError.""" TestClass = UnavailableMeta("TestClass", (), {}) with pytest.raises(UnavailableError): len(TestClass) with pytest.raises(UnavailableError): TestClass[0] with pytest.raises(UnavailableError): TestClass[0] = 1 with pytest.raises(UnavailableError): del TestClass[0] with pytest.raises(UnavailableError): iter(TestClass) def test_descriptor_operations_raise_error(self): """Test that descriptor operations raise UnavailableError.""" TestClass = UnavailableMeta("TestClass", (), {}) class DummyClass: prop = TestClass dummy = DummyClass() with pytest.raises(UnavailableError): TestClass.__get__(None, None) with pytest.raises(UnavailableError): TestClass.__delete__(None) class TestSafeImport: def test_successful_import(self): """Test safe_import with a module that exists.""" module, success = safe_import("os") assert success is True assert isinstance(module, types.ModuleType) assert module.__name__ == "os" def test_failed_import(self): """Test safe_import with a module that doesn't exist.""" module, success = safe_import("nonexistent_module") assert success is False assert is_unavailable(module) assert type(module) is UnavailableMeta def test_import_with_custom_message(self): """Test safe_import with a custom error message.""" custom_msg = "Custom error message" module, success = safe_import("nonexistent_module", msg=custom_msg) assert success is False assert is_unavailable(module) # Verify the custom message is used when trying to use the module with pytest.raises(UnavailableError, match=custom_msg): module() def test_import_with_alternative(self): """Test safe_import with an alternative module.""" alt_module = object() module, success = safe_import("nonexistent_module", alt=alt_module) assert success is False assert module is alt_module def test_unavailable_module_raises_error_when_used(self): """Test that using a UnavailableMeta placeholder raises UnavailableError.""" module, success = safe_import("nonexistent_module") assert success is False # Test various operations that should raise UnavailableError with pytest.raises(UnavailableError): module() with pytest.raises(UnavailableError): module.attribute with pytest.raises(UnavailableError): module + 1 with pytest.raises(UnavailableError): module == 1 class TestSafeImportFrom: def test_successful_import_from(self): """Test safe_import_from with a symbol that exists.""" symbol, success = safe_import_from("os", "path") assert success is True import os assert symbol is os.path def test_failed_import_from_nonexistent_module(self): """Test safe_import_from with a module that doesn't exist.""" symbol, success = safe_import_from("nonexistent_module", "nonexistent_symbol") assert success is False assert is_unavailable(symbol) def test_failed_import_from_nonexistent_symbol(self): """Test safe_import_from with a symbol that doesn't exist in an existing module.""" symbol, success = safe_import_from("os", "nonexistent_symbol") assert success is False assert is_unavailable(symbol) def test_import_from_with_custom_message(self): """Test safe_import_from with a custom error message.""" custom_msg = "Custom error message for symbol" symbol, success = safe_import_from("os", "nonexistent_symbol", msg=custom_msg) assert success is False # Verify the custom message is used when trying to use the symbol with pytest.raises(UnavailableError, match=custom_msg): symbol() def test_import_from_with_alternative(self): """Test safe_import_from with an alternative symbol.""" alt_symbol = object() symbol, success = safe_import_from("os", "nonexistent_symbol", alt=alt_symbol) assert success is False assert symbol is alt_symbol def test_fallback_module(self): """Test safe_import_from with a fallback module.""" # First import fails, but fallback succeeds with patch('importlib.import_module') as mock_import: # Mock the first import to fail as AttributeError def side_effect(name): if name == "primary_module": raise AttributeError("Symbol not found") elif name == "fallback_module": mock_module = MagicMock() mock_module.symbol = "fallback_symbol" return mock_module else: raise ImportError(f"Unexpected module: {name}") mock_import.side_effect = side_effect symbol, success = safe_import_from("primary_module", "symbol", fallback_module="fallback_module") assert success is True assert symbol == "fallback_symbol" def test_fallback_module_both_fail(self): """Test safe_import_from when both primary and fallback modules fail.""" symbol, success = safe_import_from("nonexistent_primary", "symbol", fallback_module="nonexistent_fallback") assert success is False assert is_unavailable(symbol)