# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from unittest.mock import MagicMock

import pytest
from megatron.core import parallel_state
from torch import nn

from nemo import lightning as nl
from nemo.lightning import megatron_parallel as mp


class TestMegatronParallel:
    """Unit tests for the MegatronParallel class."""

    @pytest.fixture
    def mock_pipeline(self, mocker):
        """Fixture to create a mock pipeline."""
        class DummyModule(nn.Module):
            def __init__(self, dummy_arg=None):
                super().__init__()
                self.dummy_arg = dummy_arg

            def forward(self, x):
                return x

        return DummyModule()

    @pytest.fixture
    def mock_precision_plugin(self, mocker):
        """Fixture to create a mock precision plugin."""
        return nl.MegatronMixedPrecision(precision="bf16-mixed")

    @pytest.fixture
    def mock_callbacks(self, mocker):
        """Fixture to create a mock callback connector."""
        return mocker.MagicMock(spec=mp.CallbackConnector)

    @pytest.fixture
    def mock_data_step(self, mocker):
        """Fixture to create a mock data step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_forward_step(self, mocker):
        """Fixture to create a mock forward step function."""
        return mocker.MagicMock()

    @pytest.fixture
    def mock_loss_reduction(self, mocker):
        """Fixture to create a mock loss reduction function."""
        return mocker.MagicMock()

    def test_init_with_defaults(self, mocker, mock_pipeline):
        """Test __init__ with default parameters."""
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

        megatron_parallel = mp.MegatronParallel(pipeline=mock_pipeline, cpu=True)

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin is None
        assert isinstance(megatron_parallel.callbacks, mp.CallbackConnector)
        assert megatron_parallel.data_step == mp.default_data_step
        assert megatron_parallel.forward_step == mp.default_forward_step
        assert megatron_parallel.loss_reduction is None

    def test_init_with_custom_parameters(
        self,
        mocker,
        mock_pipeline,
        mock_precision_plugin,
        mock_callbacks,
        mock_data_step,
        mock_forward_step,
        mock_loss_reduction,
    ):
        """Test __init__ with custom parameters."""
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

        megatron_parallel = mp.MegatronParallel(
            pipeline=mock_pipeline,
            precision_plugin=mock_precision_plugin,
            callbacks=mock_callbacks,
            data_step=mock_data_step,
            forward_step=mock_forward_step,
            loss_reduction=mock_loss_reduction,
            cpu=True,
        )

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin == mock_precision_plugin
        assert megatron_parallel.callbacks == mock_callbacks
        assert megatron_parallel.data_step == mock_data_step
        assert megatron_parallel.forward_step == mock_forward_step
        assert megatron_parallel.loss_reduction == mock_loss_reduction


class TestCallbackConnector:
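    """Unit tests for the CallbackConnector class."""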

    def test_add_callbacks(self) -> None:
        callback_connector = mp.CallbackConnector()
        callback = TestCallback()
        callback_connector.add(callback)
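        # add() should register the callback under each Megatron hook it implements.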
        assert callback in callback_connector.callbacks["on_megatron_step_start"]
        assert callback in callback_connector.callbacks["on_megatron_microbatch_start"]

    def test_event(self) -> None:
        callback_connector = mp.CallbackConnector()
        callback = TestCallback()
        callback_connector.add(callback)

        # Wrap the hook methods in MagicMock so their invocations can be counted.
        callback.on_megatron_step_start = MagicMock()
        callback.on_megatron_microbatch_start = MagicMock()

        callback_connector.event("on_megatron_step_start")
        callback_connector.event("on_megatron_microbatch_start")

        assert callback.on_megatron_step_start.call_count == 1
        assert callback.on_megatron_microbatch_start.call_count == 1

    def test_add_connector(self) -> None:
        callback_connector1 = mp.CallbackConnector()
        callback_connector2 = mp.CallbackConnector()
        callback1 = TestCallback()
        callback2 = TestCallback()
        callback_connector1.add(callback1)
        callback_connector2.add(callback2)
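        # In-place addition should merge the second connector's callbacks into the first.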
        callback_connector1 += callback_connector2

        assert callback1 in callback_connector1.callbacks["on_megatron_step_start"]
        assert callback2 in callback_connector1.callbacks["on_megatron_step_start"]

    def test_contains(self):
        callback_connector = mp.CallbackConnector()
        callback = TestCallback()
        callback_connector.add(callback)
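
        # Connector membership (`in`) should reflect callbacks registered via add().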
        assert callback in callback_connector

    def test_add_count_callback(self):
        """Test adding a CountCallback to the CallbackConnector."""
        connector = mp.CallbackConnector()
        count_callback = CountCallback()
        connector.add(count_callback)

        # Check if the CountCallback has been added correctly
        assert count_callback in connector, "CountCallback should be in the CallbackConnector"

    def test_event_trigger_with_count_callback(self):
        """Test if the event triggers the method in CountCallback."""
        connector = mp.CallbackConnector()
        count_callback = CountCallback()
        connector.add(count_callback)

        # Simulate an event that CountCallback listens to
        connector.event('on_megatron_step_start')

        # Check if the CountCallback's method was called
        assert (
            count_callback.counts["on_megatron_step_start"] == 1
        ), "CountCallback's method should have been triggered once"


class TestCallback:
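    """Minimal callback implementing a subset of the Megatron hooks, used in the registration tests."""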

    def on_megatron_step_start(self):
        pass

    def on_megatron_microbatch_start(self):
        pass


class CountCallback:
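    """Callback that counts how many times each Megatron hook is invoked."""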

    def __init__(self) -> None:
        self.counts = defaultdict(int)

    def on_megatron_step_start(self, *args, **kwargs) -> None:
        # assert len(kwargs) == 12
        self.counts["on_megatron_step_start"] += 1

    def on_megatron_microbatch_start(self, *args, **kwargs) -> None:
        # assert len(kwargs) == 14
        self.counts["on_megatron_microbatch_start"] += 1

    def on_megatron_microbatch_callback(self, *args, **kwargs) -> None:
        self.counts["on_megatron_microbatch_callback"] += 1

    def on_megatron_microbatch_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_microbatch_end"] += 1

    def on_megatron_reduce_microbatches_start(self, *args, **kwargs) -> None:
        self.counts["on_megatron_reduce_microbatches_start"] += 1

    def on_megatron_reduce_microbatches_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_reduce_microbatches_end"] += 1

    def on_megatron_log_step_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_log_step_end"] += 1

    def on_megatron_step_end(self, *args, **kwargs) -> None:
        self.counts["on_megatron_step_end"] += 1