# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from unittest.mock import MagicMock

import pytest
from megatron.core import parallel_state
from torch import nn

from nemo import lightning as nl
from nemo.lightning import megatron_parallel as mp


class TestMegatronParallel:
    """Unit tests for the MegatronParallel class."""

    @staticmethod
    def _patch_parallel_state(mocker):
        """Patch Megatron parallel-state queries so MegatronParallel can be built on CPU.

        Reports a pipeline-parallel world size of 1 and an uninitialized
        model-parallel state.
        """
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

    @pytest.fixture
    def mock_pipeline(self):
        """Fixture providing a trivial identity ``nn.Module`` to use as the pipeline."""

        class DummyModule(nn.Module):
            def __init__(self, dummy_arg=None):
                # PyTorch requires nn.Module.__init__() to run before any
                # attribute is assigned on a subclass instance.
                super().__init__()
                self.dummy_arg = dummy_arg

            def forward(self, x):
                return x

        return DummyModule()

    @pytest.fixture
    def mock_precision_plugin(self):
        """Fixture providing a real bf16-mixed MegatronMixedPrecision plugin."""
        return nl.MegatronMixedPrecision(precision="bf16-mixed")

    @pytest.fixture
    def mock_callbacks(self, mocker):
        """Fixture to create a mock callback connector."""
        return mocker.MagicMock(spec=mp.CallbackConnector)

    @pytest.fixture
    def mock_data_step(self, mocker):
        """Fixture to create a mock data step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_forward_step(self, mocker):
        """Fixture to create a mock forward step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_loss_reduction(self, mocker):
        """Fixture to create a mock loss reduction function."""
        return mocker.MagicMock()

    def test_init_with_defaults(self, mocker, mock_pipeline):
        """Test __init__ with default parameters."""
        self._patch_parallel_state(mocker)

        megatron_parallel = mp.MegatronParallel(pipeline=mock_pipeline, cpu=True)

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin is None
        assert isinstance(megatron_parallel.callbacks, mp.CallbackConnector)
        assert megatron_parallel.data_step == mp.default_data_step
        assert megatron_parallel.forward_step == mp.default_forward_step
        assert megatron_parallel.loss_reduction is None

    def test_init_with_custom_parameters(
        self,
        mocker,
        mock_pipeline,
        mock_precision_plugin,
        mock_callbacks,
        mock_data_step,
        mock_forward_step,
        mock_loss_reduction,
    ):
        """Test __init__ with custom parameters."""
        self._patch_parallel_state(mocker)

        megatron_parallel = mp.MegatronParallel(
            pipeline=mock_pipeline,
            precision_plugin=mock_precision_plugin,
            callbacks=mock_callbacks,
            data_step=mock_data_step,
            forward_step=mock_forward_step,
            loss_reduction=mock_loss_reduction,
            cpu=True,
        )

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin == mock_precision_plugin
        assert megatron_parallel.callbacks == mock_callbacks
        assert megatron_parallel.data_step == mock_data_step
        assert megatron_parallel.forward_step == mock_forward_step
        assert megatron_parallel.loss_reduction == mock_loss_reduction


class TestCallbackConnector:
    """Unit tests for the CallbackConnector class."""

    def test_add_callbacks(self) -> None:
        """A callback added to the connector is registered under each hook it defines."""
        callback_connector = mp.CallbackConnector()
        callback = TestCallback()
        callback_connector.add(callback)

        assert callback in callback_connector.callbacks["on_megatron_step_start"]
        assert callback in callback_connector.callbacks["on_megatron_microbatch_start"]

    def test_event(self) -> None:
        """Firing an event invokes the matching hook on each registered callback once."""
        callback_connector = mp.CallbackConnector()
        callback = TestCallback()
        callback_connector.add(callback)

        # Replace the hooks after registration; the connector stores the
        # callback object itself, so the replaced attributes are what get called.
        callback.on_megatron_step_start = MagicMock()
        callback.on_megatron_microbatch_start = MagicMock()

        callback_connector.event("on_megatron_step_start")
        callback_connector.event("on_megatron_microbatch_start")

        assert callback.on_megatron_step_start.call_count == 1
        assert callback.on_megatron_microbatch_start.call_count == 1

    def test_add_connector(self) -> None:
        """`+=` merges another connector's callbacks into this one."""
        callback_connector1 = mp.CallbackConnector()
        callback_connector2 = mp.CallbackConnector()
        callback1 = TestCallback()
        callback2 = TestCallback()

        callback_connector1.add(callback1)
        callback_connector2.add(callback2)

        callback_connector1 += callback_connector2

        assert callback1 in callback_connector1.callbacks["on_megatron_step_start"]
        assert callback2 in callback_connector1.callbacks["on_megatron_step_start"]

    def test_contains(self):
        """`in` reports whether a callback has been added to the connector."""
        callback_connector = mp.CallbackConnector()
        callback = TestCallback()
        callback_connector.add(callback)

        assert callback in callback_connector

    def test_add_count_callback(self):
        """Test adding a CountCallback to the CallbackConnector."""
        connector = mp.CallbackConnector()
        count_callback = CountCallback()

        connector.add(count_callback)

        assert count_callback in connector, "CountCallback should be in the CallbackConnector"

    def test_event_trigger_with_count_callback(self):
        """Test if the event triggers the method in CountCallback."""
        connector = mp.CallbackConnector()
        count_callback = CountCallback()
        connector.add(count_callback)

        # Simulate an event that CountCallback listens to
        connector.event('on_megatron_step_start')

        assert (
            count_callback.counts["on_megatron_step_start"] == 1
        ), "CountCallback's method should have been triggered once"


class TestCallback:
    """Minimal callback helper defining two no-op Megatron hooks."""

    # Helper, not a test class: stop pytest from collecting it despite the
    # ``Test`` name prefix.
    __test__ = False

    def on_megatron_step_start(self):
        pass

    def on_megatron_microbatch_start(self):
        pass


class CountCallback:
    """Callback helper that counts how many times each Megatron hook is invoked.

    Every hook accepts arbitrary positional/keyword arguments and increments
    ``self.counts[<hook method name>]``; unseen hooks read as 0.
    """

    def __init__(self) -> None:
        self.counts = defaultdict(int)

    def on_megatron_step_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_step_start"] += 1

    def on_megatron_microbatch_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_microbatch_start"] += 1

    def on_megatron_microbatch_callback(self, *args, **kwargs) -> None:
        # Key matches the method name, consistent with every other hook here.
        self.counts["on_megatron_microbatch_callback"] += 1

    def on_megatron_microbatch_end(self, *args, **kwargs) -> None:
        # Key matches the method name, consistent with every other hook here.
        self.counts["on_megatron_microbatch_end"] += 1

    def on_megatron_reduce_microbatches_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_reduce_microbatches_start"] += 1

    def on_megatron_reduce_microbatches_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_reduce_microbatches_end"] += 1

    def on_megatron_log_step_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_log_step_end"] += 1

    def on_megatron_step_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_step_end"] += 1