# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from lightning.fabric import plugins as fl_plugins
from lightning.fabric import strategies as fl_strategies
from lightning.pytorch import plugins as pl_plugins
from lightning.pytorch import strategies as pl_strategies

from nemo import lightning as nl
from nemo.lightning.fabric.conversion import to_fabric


class TestConversion:
    """Tests for ``to_fabric``.

    ``to_fabric`` converts Lightning (``lightning.pytorch``) strategies and
    precision plugins into their Lightning Fabric counterparts, including the
    NeMo Megatron strategy and precision plugin.
    """

    def test_ddp_strategy_conversion(self):
        """A Lightning DDPStrategy converts to a Fabric DDPStrategy."""
        pl_strategy = pl_strategies.DDPStrategy()
        fabric_strategy = to_fabric(pl_strategy)

        assert isinstance(fabric_strategy, fl_strategies.DDPStrategy)

    def test_fsdp_strategy_conversion(self):
        """A Lightning FSDPStrategy converts to a Fabric FSDPStrategy.

        ``cpu_offload=True`` must surface as ``cpu_offload.offload_params``.
        """
        pl_strategy = pl_strategies.FSDPStrategy(
            cpu_offload=True,
        )
        fabric_strategy = to_fabric(pl_strategy)

        assert isinstance(fabric_strategy, fl_strategies.FSDPStrategy)
        assert fabric_strategy.cpu_offload.offload_params is True

    def test_mixed_precision_plugin_conversion(self):
        """A Lightning MixedPrecision plugin keeps its precision setting."""
        pl_plugin = pl_plugins.MixedPrecision(precision='16-mixed', device='cpu')
        fabric_plugin = to_fabric(pl_plugin)

        assert isinstance(fabric_plugin, fl_plugins.MixedPrecision)
        assert fabric_plugin.precision == '16-mixed'

    def test_fsdp_precision_plugin_conversion(self):
        """A Lightning FSDPPrecision plugin keeps its precision setting."""
        pl_plugin = pl_plugins.FSDPPrecision(precision='16-mixed')
        fabric_plugin = to_fabric(pl_plugin)

        assert isinstance(fabric_plugin, fl_plugins.FSDPPrecision)
        assert fabric_plugin.precision == '16-mixed'

    def test_unsupported_object_conversion(self):
        """Objects with no registered converter raise NotImplementedError.

        The error message must name the offending type.
        """

        class UnsupportedObject:
            pass

        # Checking the message via ``match`` ties it to the raises-context, so
        # the check cannot end up as dead code after the raising call.
        with pytest.raises(
            NotImplementedError,
            match="No Fabric converter registered for UnsupportedObject",
        ):
            to_fabric(UnsupportedObject())

    def test_megatron_strategy_conversion(self):
        """MegatronStrategy converts to FabricMegatronStrategy.

        Every parallelism setting must be carried over unchanged.
        """
        pl_strategy = nl.MegatronStrategy(
            tensor_model_parallel_size=2,
            pipeline_model_parallel_size=2,
            virtual_pipeline_model_parallel_size=2,
            context_parallel_size=2,
            sequence_parallel=True,
            expert_model_parallel_size=2,
            moe_extended_tp=True,
        )
        fabric_strategy = to_fabric(pl_strategy)

        assert isinstance(fabric_strategy, nl.FabricMegatronStrategy)
        assert fabric_strategy.tensor_model_parallel_size == 2
        assert fabric_strategy.pipeline_model_parallel_size == 2
        assert fabric_strategy.virtual_pipeline_model_parallel_size == 2
        assert fabric_strategy.context_parallel_size == 2
        assert fabric_strategy.sequence_parallel is True
        assert fabric_strategy.expert_model_parallel_size == 2
        assert fabric_strategy.moe_extended_tp is True

    def test_megatron_precision_conversion(self):
        """MegatronMixedPrecision converts to FabricMegatronMixedPrecision.

        The precision setting must be preserved.
        """
        pl_plugin = nl.MegatronMixedPrecision(precision='16-mixed')
        fabric_plugin = to_fabric(pl_plugin)

        assert isinstance(fabric_plugin, nl.FabricMegatronMixedPrecision)
        assert fabric_plugin.precision == '16-mixed'