# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack, contextmanager, nullcontext
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Dict,
Generator,
Iterator,
List,
Literal,
Optional,
Union,
)
import torch
from lightning.fabric.accelerators import CPUAccelerator
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.strategy import _validate_keys_for_strict_loading
from lightning.fabric.utilities.types import _PATH, _Stateful
from lightning.pytorch import LightningDataModule
from lightning.pytorch.loops.fetchers import _DataFetcher
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.utilities.combined_loader import CombinedLoader
try:
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
DistributedDataParallelConfig = object
OptimizerConfig = object
HAVE_MEGATRON_CORE = False
from torch import Tensor, nn
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import override
from nemo.lightning import _strategy_lib
from nemo.lightning.fabric.conversion import to_fabric
from nemo.lightning.io.pl import MegatronCheckpointIO, ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import CallbackConnector, MegatronParallel
from nemo.lightning.pytorch.strategies import MegatronStrategy
from nemo.utils.import_utils import safe_import
from nemo.utils.model_utils import unwrap_model
mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
if TYPE_CHECKING:
from nemo.lightning.pytorch.plugins.data_sampler import DataSampler
DDPLiteral = Literal["megatron", "pytorch"]
class FabricMegatronStrategy(DDPStrategy):
"""
    Lightning Fabric strategy for training Megatron-Core models with tensor,
    pipeline, context, and expert model parallelism.
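
    Example (a minimal sketch; the device count and parallel sizes are illustrative):

        >>> from lightning.fabric import Fabric
        >>> strategy = FabricMegatronStrategy(tensor_model_parallel_size=2)
        >>> fabric = Fabric(devices=2, strategy=strategy)
        >>> fabric.launch()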
"""
def __init__(
self,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
        pipeline_model_parallel_comm_backend: Optional[str] = None,
microbatch_group_size_per_vp_stage: Optional[int] = None,
context_parallel_size: int = 1,
sequence_parallel: bool = False,
expert_model_parallel_size: int = 1,
moe_extended_tp: bool = False,
        expert_tensor_parallel_size: Optional[int] = None,
encoder_tensor_model_parallel_size: Optional[int] = 0,
encoder_pipeline_model_parallel_size: Optional[int] = 0,
data_sampler: Optional["DataSampler"] = None,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
megatron_callbacks: Optional[CallbackConnector] = None,
ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen",
no_ddp_communication_hook: bool = True,
output_data_idx: bool = False,
pipeline_dtype: Optional[torch.dtype] = None,
init_model_parallel: bool = True,
use_tp_pp_dp_mapping: bool = False,
num_distributed_optimizer_instances: int = 1,
nccl_communicator_config_path: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision=precision,
process_group_backend=process_group_backend,
timeout=timeout,
start_method=start_method,
**kwargs,
)
self.tensor_model_parallel_size = tensor_model_parallel_size
self.pipeline_model_parallel_size = pipeline_model_parallel_size
self.pipeline_model_parallel_comm_backend = pipeline_model_parallel_comm_backend
self.microbatch_group_size_per_vp_stage = (
microbatch_group_size_per_vp_stage
if microbatch_group_size_per_vp_stage is not None
else pipeline_model_parallel_size
)
self.context_parallel_size = context_parallel_size
self.expert_model_parallel_size = expert_model_parallel_size
self.expert_tensor_parallel_size = expert_tensor_parallel_size
self.moe_extended_tp = moe_extended_tp
self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
self.sequence_parallel = sequence_parallel
self.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
self.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
self.pipeline_dtype = pipeline_dtype
self._init_model_parallel = init_model_parallel
self.use_tp_pp_dp_mapping = use_tp_pp_dp_mapping
self.num_distributed_optimizer_instances = num_distributed_optimizer_instances
self.nccl_communicator_config_path = nccl_communicator_config_path
self.no_ddp_communication_hook = no_ddp_communication_hook
self.megatron_callbacks = CallbackConnector()
if megatron_callbacks:
self.megatron_callbacks.add(megatron_callbacks)
self.output_data_idx = output_data_idx
self.data_sampler: Optional["DataSampler"] = data_sampler
# used in NVIDIA NGC PyTorch containers
_strategy_lib.enable_nvidia_optimizations()
self._ddp = ddp
if ddp == "megatron":
self.ddp_config = DistributedDataParallelConfig()
elif isinstance(ddp, DistributedDataParallelConfig):
self.ddp_config = ddp
elif ddp == "pytorch":
self.ddp_config = None
self.no_ddp_communication_hook = False
else:
raise ValueError(f"Invalid DDP type: {ddp}")
@override
def _setup_distributed(self) -> None:
self._set_world_ranks()
assert self.cluster_environment is not None
_strategy_lib.init_parallel_ranks(
world_size=self.cluster_environment.world_size(),
global_rank=self.cluster_environment.global_rank(),
local_rank=self.cluster_environment.local_rank(),
parallel_config=self.parallelism,
)
super()._setup_distributed()
torch.cuda.set_device(self.cluster_environment.local_rank())
# TODO: Fix this:
# if self.data_config is not None:
# _strategy_lib.initialize_data(self.cluster_environment.global_rank(), self.data_config)
_strategy_lib.init_model_parallel()
def process_datamodule(self, datamodule: LightningDataModule) -> LightningDataModule:
"""
        Set up the datamodule and adopt its data sampler, if one is attached.
"""
datamodule.setup()
if not self.data_sampler and hasattr(datamodule, "data_sampler"):
self.data_sampler = datamodule.data_sampler
if self.data_sampler:
self.data_sampler.setup(self.cluster_environment.global_rank())
return datamodule
@override
def process_dataloader(self, dataloader: DataLoader) -> Iterator:
"""
        Apply the data sampler's transform (if any) and wrap the dataloader in a
        Megatron-style data fetcher. Returns an iterator.
"""
if self.data_sampler:
dataloader = self.data_sampler.transform_dataloader(dataloader)
# Code taken from:
# https://github.com/Lightning-AI/pytorch-lightning
# /blob/6cbe9ceb560d798892bdae9186291acf9bf5d2e3/src/lightning/pytorch/loops/fit_loop.py
# L258-L260
output = _MegatronDataLoaderIterDataFetcher(output_data_idx=self.output_data_idx)
output.setup(CombinedLoader(dataloader, "max_size_cycle"))
iter(output)
return output
def setup_megatron_optimizer(
self,
model: MegatronParallel,
optimizer_config: OptimizerConfig,
no_weight_decay_cond: Optional[Callable] = None,
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
) -> Optimizer:
"""
        Set up a Megatron-Core optimizer for the given model and optimizer config.
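
        Example (a minimal sketch; ``model`` is assumed to be the ``MegatronParallel``
        returned by ``setup_module``, and the config values are illustrative):

            >>> from megatron.core.optimizer import OptimizerConfig
            >>> config = OptimizerConfig(optimizer="adam", lr=3e-4, use_distributed_optimizer=True)
            >>> optimizer = strategy.setup_megatron_optimizer(model, config)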
"""
if hasattr(self.precision, "convert_config"):
optimizer_config = self.precision.convert_config(optimizer_config)
assert optimizer_config.lr is not None, "Learning rate must be set in optimizer config"
return _strategy_lib.setup_megatron_optimizer(
model,
optimizer_config,
no_weight_decay_cond=no_weight_decay_cond,
scale_lr_cond=scale_lr_cond,
lr_mult=lr_mult,
)
@override
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Pass the optimizer to the precision-plugin if needed & add it as callback."""
if hasattr(self._precision, "setup_optimizer"):
optimizer = self._precision.setup_optimizer(optimizer)
self.megatron_callbacks.add(optimizer)
return optimizer
@override
def setup_module(self, module: Module) -> MegatronParallel:
"""
        Set up the torch module: apply model-parallel attributes and precision
        conversion, then wrap it in a MegatronParallel object (plus a PyTorch
        DistributedDataParallel wrapper when ``ddp="pytorch"``).
"""
from megatron.core.utils import get_model_config
_strategy_lib.set_model_parallel_attributes(module, self.parallelism)
convert_module_fn = None
if hasattr(self.precision, "convert_module"):
convert_module_fn = self.precision.convert_module
if hasattr(self.precision, "convert_config"):
self.precision.convert_config(get_model_config(module))
if self.ddp_config:
self.precision.convert_config(self.ddp_config)
# Call configure_model if it's overridden (relevant for LightningModules with lazy initialization)
if hasattr(module, "configure_model"):
module.configure_model()
megatron_parallel = MegatronParallel(
module,
precision_plugin=self.precision,
vp_size=self.virtual_pipeline_model_parallel_size,
cpu=isinstance(self.accelerator, CPUAccelerator),
ddp_config=self.ddp_config,
convert_module_fn=convert_module_fn,
)
if self._init_model_parallel:
megatron_parallel.init_model_parallel()
if self.data_sampler:
megatron_parallel.callbacks.add(self.data_sampler)
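        # ddp="pytorch" path: fall back to native DistributedDataParallel, scoping
        # gradient all-reduce to the Megatron data-parallel process group.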
if not self.ddp_config:
from megatron.core import mpu
from nemo.utils import AppState
app_state = AppState()
if app_state.model_parallel_size is not None:
self._ddp_kwargs["process_group"] = mpu.get_data_parallel_group()
dist_data_parallel = super().setup_module(megatron_parallel)
if self.no_ddp_communication_hook:
# When using custom gradient accumulation and allreduce, disable
# DDP communication hook that works on the gradient bucket.
# Instead, use the custom gradient function and communication hook,
# which is defined in the master optimizer wrapper.
dist_data_parallel.require_backward_grad_sync = False
dist_data_parallel.register_comm_hook(None, noop_hook)
return dist_data_parallel
return megatron_parallel
    @override
    def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
"""
Get the context manager used for initializing the module.
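
        Example (a sketch; ``build_module`` is a hypothetical factory):

            >>> with strategy.module_init_context(empty_init=True):
            ...     module = build_module()  # parameters are created on the meta device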
"""
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.megatron_context()
stack = ExitStack()
if empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
stack.enter_context(torch.device("meta"))
stack.enter_context(precision_init_ctx)
stack.enter_context(module_sharded_ctx)
return stack
    @override
    def module_to_device(self, module: nn.Module) -> None:
        """
        No-op: device placement is handled by the Megatron setup path instead.
        """
@override
def save_checkpoint(
self,
path: _PATH,
state: Dict[str, Union[Module, Optimizer, Any]],
storage_options: Optional[Any] = None,
filter_dict: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
) -> None:
"""Save model, optimizer, and other state as a checkpoint file.
Args:
path: A path to where the file(s) should be saved
state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
state-dict will be retrieved and converted automatically.
storage_options: Additional options for the ``CheckpointIO`` plugin
            filter_dict: An optional dictionary containing filter callables that return a boolean indicating whether the
given item should be saved (``True``) or filtered out (``False``). Each filter key should match a
state key, where its filter will be applied to the ``state_dict`` generated.
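
        Example (a minimal sketch; ``model`` and ``optimizer`` are assumed to be set up
        through this strategy):

            >>> strategy.save_checkpoint("ckpt_dir", {"state_dict": model, "optimizer": optimizer})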
"""
if not storage_options:
storage_options = {}
storage_options['content_metadata'] = self.sharded_state_dict_metadata
state = self._convert_stateful_objects_in_state(state, filter=(filter_dict or {}))
self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options)
    @override
    def load_checkpoint(
self,
path: _PATH,
state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
"""
        Load a sharded checkpoint into ``state``, which may be a module or a dict of
        modules, optimizers, and other stateful objects. Returns any leftover entries.
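
        Example (a minimal sketch; ``model`` and ``optimizer`` are assumed to be set up
        through this strategy):

            >>> state = {"state_dict": model, "optimizer": optimizer}
            >>> extra = strategy.load_checkpoint("ckpt_dir", state=state)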
"""
if isinstance(state, Optimizer):
raise NotImplementedError("Optimizer loading is not supported, pass it as a dict including the model")
        from nemo.collections.vlm.llama4.model.base import Llama4OmniBaseModel

        # The ModelOpt restore path only applies when the model arrives inside a state dict.
        unwrapped_model = (
            unwrap_model(state["state_dict"]) if isinstance(state, dict) and "state_dict" in state else None
        )
        if HAVE_MODELOPT and isinstance(unwrapped_model, Llama4OmniBaseModel):
# If present, first restore and modify the model according to the ModelOpt state.
# Avoid quantizers being added to teacher model if model is a distillation model.
core_model = unwrapped_model.language_model
with core_model.hide_teacher_model() if hasattr(core_model, "hide_teacher_model") else nullcontext():
mto.plugins.restore_sharded_modelopt_state(
[core_model], ckpt_to_weights_subdir(path, is_saving=False), prefix="module.language_model."
)
if mto.ModeloptStateManager.is_converted(core_model):
print("Restored Model-Optimizer state from checkpoint.")
torch.cuda.empty_cache()
# After dist_checkpointing.load, sharded tensors will be replaced with tensors
sharded_sd_metadata = self.unwrapped_checkpoint_io.load_content_metadata(path)
sharded_state_dict = {}
if isinstance(state, Module):
sharded_state_dict["state_dict"] = state.sharded_state_dict(metadata=sharded_sd_metadata)
elif strict:
if isinstance(state['state_dict'], DistributedDataParallel):
state["state_dict"] = state['state_dict'].module
sharded_state_dict["state_dict"] = state["state_dict"].sharded_state_dict(metadata=sharded_sd_metadata)
if "optimizer" in state:
sharded_state_dict["optimizer"] = _strategy_lib.optimizer_sharded_state_dict(
state["state_dict"],
state["optimizer"],
is_loading=True,
metadata=sharded_sd_metadata,
)
else:
            for obj in state.values():
if isinstance(obj, Module):
sharded_state_dict["state_dict"] = obj.sharded_state_dict(metadata=sharded_sd_metadata)
elif isinstance(obj, Optimizer):
sharded_state_dict["optimizer"] = _strategy_lib.optimizer_sharded_state_dict(
obj, is_loading=True, metadata=sharded_sd_metadata
)
checkpoint = self.checkpoint_io.load_checkpoint(path, sharded_state_dict=sharded_state_dict)
if isinstance(state, Module):
self.load_module_state_dict(module=state, state_dict=checkpoint, strict=strict)
return {}
_validate_keys_for_strict_loading(state.keys(), checkpoint.keys(), strict=strict)
for name, obj in state.copy().items():
if name not in checkpoint:
continue
if isinstance(obj, _Stateful):
if isinstance(obj, Module):
self.load_module_state_dict(module=obj, state_dict=checkpoint.pop(name), strict=strict)
else:
obj.load_state_dict(checkpoint.pop(name))
else:
state[name] = checkpoint.pop(name)
return checkpoint
@override
def load_module_state_dict(
self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True
) -> None:
"""
Load the module state dict.
"""
_strategy_lib.load_model_state_dict(module, state_dict, strict=strict)
@property
def sharded_state_dict_metadata(self):
"""Metadata used for sharded_state_dict generation during checkpoint save."""
metadata = {}
metadata['singleton_local_shards'] = False
metadata['chained_optim_avoid_prefix'] = True
if isinstance(self.ddp_config, DistributedDataParallelConfig) and self.ddp_config.use_distributed_optimizer:
metadata['distrib_optim_sharding_type'] = 'dp_reshardable'
return metadata
@contextmanager
def megatron_context(self) -> Generator[None, None, None]:
"""
        Context manager that defers weight materialization: Transformer Engine layers
        are created on the meta device and Megatron weight initialization is skipped.
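
        Example (a sketch; ``build_module`` is a hypothetical factory):

            >>> with strategy.megatron_context():
            ...     module = build_module()  # TE modules place their weights on the meta device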
"""
from megatron.core.extensions import transformer_engine as _te
original = _te._get_extra_te_kwargs # noqa: SLF001
def _get_extra_te_kwargs_meta(c):
"""Forces device to meta"""
kwargs = original(c)
kwargs['device'] = 'meta'
return kwargs
_te._get_extra_te_kwargs = _get_extra_te_kwargs_meta # noqa: SLF001
_orig_perform_initialization = self.parallelism.perform_initialization
_orig_use_cpu_initialization = self.parallelism.use_cpu_initialization
self.parallelism.perform_initialization = False
self.parallelism.use_cpu_initialization = True
        try:
            yield
        finally:
            _te._get_extra_te_kwargs = original  # noqa: SLF001
            self.parallelism.perform_initialization = _orig_perform_initialization
            self.parallelism.use_cpu_initialization = _orig_use_cpu_initialization
@property
@override
def checkpoint_io(self) -> CheckpointIO:
"""
Get the checkpoint IO.
"""
if self._checkpoint_io is None:
self._checkpoint_io = MegatronCheckpointIO()
elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
self._checkpoint_io.checkpoint_io = MegatronCheckpointIO()
return self._checkpoint_io
@property
def unwrapped_checkpoint_io(self) -> CheckpointIO:
"""Unwraps `checkpoint_io` from all wrappers."""
checkpoint_io = self.checkpoint_io
while isinstance(checkpoint_io, _WrappingCheckpointIO):
checkpoint_io = checkpoint_io.checkpoint_io
return checkpoint_io
@property
def parallelism(self):
"""
Get the parallelism config.
"""
from nemo.lightning.pytorch.strategies.megatron_strategy import ParallelismConfig
return ParallelismConfig(
tensor_model_parallel_size=self.tensor_model_parallel_size,
pipeline_model_parallel_size=self.pipeline_model_parallel_size,
pipeline_model_parallel_comm_backend=self.pipeline_model_parallel_comm_backend,
virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
context_parallel_size=self.context_parallel_size,
sequence_parallel=self.sequence_parallel,
expert_model_parallel_size=self.expert_model_parallel_size,
expert_tensor_parallel_size=self.expert_tensor_parallel_size,
moe_extended_tp=self.moe_extended_tp,
encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
pipeline_dtype=self.pipeline_dtype,
use_tp_pp_dp_mapping=self.use_tp_pp_dp_mapping,
num_distributed_optimizer_instances=self.num_distributed_optimizer_instances,
nccl_communicator_config_path=self.nccl_communicator_config_path,
)
# TODO: Fix this
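# Returns itself from ``__next__`` (via ``_DataFetcherWrapper``) instead of a batch, so
# Megatron's forward-backward function can pull batches lazily inside the training step.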
class _MegatronDataLoaderIterDataFetcher(_DataFetcher):
def __init__(self, *args: Any, output_data_idx: bool = False, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.output_data_idx = output_data_idx
self._batch: Any = None
self._batch_idx: int = 0
self._dataloader_idx: int = 0
def __iter__(self) -> "_MegatronDataLoaderIterDataFetcher":
super().__iter__()
self.iterator_wrapper = iter(_DataFetcherWrapper(self, output_data_idx=self.output_data_idx))
return self
def __next__(self) -> Iterator["_DataFetcherWrapper"]: # type: ignore[override]
if self.done:
raise StopIteration
return self.iterator_wrapper
def reset(self) -> None:
"""
Reset the data fetcher.
"""
super().reset()
self._batch = None
self._batch_idx = 0
self._dataloader_idx = 0
class _DataFetcherWrapper(Iterator):
def __init__(
self,
data_fetcher: _MegatronDataLoaderIterDataFetcher,
output_data_idx: bool = False,
) -> None:
self.data_fetcher = data_fetcher
self.output_data_idx = output_data_idx
@property
def done(self) -> bool:
"""
Check if the data fetcher is done.
"""
return self.data_fetcher.done
@property
def fetched(self) -> int:
"""
        Number of batches fetched so far.
"""
return self.data_fetcher.fetched
@property
def length(self) -> Optional[int]:
"""
Get the length of the data fetcher.
"""
return self.data_fetcher.length
@property
def data_config(self):
"""
Get the data config.
"""
return self.data_fetcher.data_config
def __next__(self):
fetcher = self.data_fetcher
if fetcher.done:
raise StopIteration
batch, batch_idx, dataloader_idx = super(_MegatronDataLoaderIterDataFetcher, fetcher).__next__()
# save the state so the loops can access it
fetcher._batch = batch # noqa: SLF001
fetcher._batch_idx = batch_idx # noqa: SLF001
fetcher._dataloader_idx = dataloader_idx # noqa: SLF001
if not self.output_data_idx:
return batch
return batch, batch_idx, dataloader_idx
@to_fabric.register(MegatronStrategy)
def convert_megatron_strategy(strategy: MegatronStrategy) -> FabricMegatronStrategy:
"""
    Convert a trainer-side MegatronStrategy into an equivalent FabricMegatronStrategy.
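
    Example (a sketch; the parallel size is illustrative):

        >>> from nemo.lightning.pytorch.strategies import MegatronStrategy
        >>> fabric_strategy = to_fabric(MegatronStrategy(tensor_model_parallel_size=2))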
"""
return FabricMegatronStrategy(
tensor_model_parallel_size=strategy.tensor_model_parallel_size,
pipeline_model_parallel_size=strategy.pipeline_model_parallel_size,
pipeline_model_parallel_comm_backend=strategy.pipeline_model_parallel_comm_backend,
virtual_pipeline_model_parallel_size=strategy.virtual_pipeline_model_parallel_size,
microbatch_group_size_per_vp_stage=strategy.microbatch_group_size_per_vp_stage,
context_parallel_size=strategy.context_parallel_size,
sequence_parallel=strategy.sequence_parallel,
expert_model_parallel_size=strategy.expert_model_parallel_size,
expert_tensor_parallel_size=strategy.expert_tensor_parallel_size,
moe_extended_tp=strategy.moe_extended_tp,
encoder_tensor_model_parallel_size=strategy.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=strategy.encoder_pipeline_model_parallel_size,
pipeline_dtype=strategy.pipeline_dtype,
use_tp_pp_dp_mapping=strategy.use_tp_pp_dp_mapping,
ddp=strategy._ddp,
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
)