# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional, Union

import lightning.pytorch as pl
import torch
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import Callback
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO

from nemo.utils import logging

try:
    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
    from megatron.core.dist_checkpointing.mapping import ShardedBase
    from megatron.core.dist_checkpointing.serialization import (
        get_default_load_sharded_strategy,
        get_default_save_sharded_strategy,
    )
    from megatron.core.dist_checkpointing.strategies import tensorstore
    from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
    from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
        FullyParallelLoadStrategyWrapper,
        FullyParallelSaveStrategyWrapper,
    )
    from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
    from megatron.core.dist_checkpointing.validation import StrictHandling
    from megatron.core.parallel_state import get_data_parallel_group

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError) as e:
    HAVE_MEGATRON_CORE = False
    IMPORT_ERROR = (
        "megatron-core was not found. "
        "Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
        f" Exact error: {e}"
    )


@contextmanager
def _debug_time(name: str):
    """Simple context manager for timing functions/code blocks."""
    start = time()
    try:
        yield
    finally:
        logging.debug(f'{name} took {time() - start:.3f}s')


class AsyncCompatibleCheckpointIO(CheckpointIO, ABC):
    """CheckpointIO that can be used together with async saving.

    Differs from the regular CheckpointIO only by the `save_checkpoint` return type.
    The `save_checkpoint` method itself is synchronous, but returns callbacks
    that can be performed asynchronously.
    """

    @abstractmethod
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
    ) -> 'AsyncRequest':
        """Interface to implement save_checkpoint and return an AsyncRequest."""
        raise NotImplementedError


class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
    """CheckpointIO wrapper for async checkpoint saving and synchronous finalization.

    Runs the main part of the checkpoint save in a separate process (not a thread,
    as the PTL AsyncCheckpointIO does). Allows performing a (synchronous)
    finalization function after all ranks finish checkpoint saving.
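
    A minimal wiring sketch (illustrative; it simply pairs this wrapper with the
    `DistributedCheckpointIO` class defined below, which must use `async_save=True`):

        dist_io = DistributedCheckpointIO('torch_dist', async_save=True)
        async_io = AsyncFinalizableCheckpointIO(dist_io)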

    NOTE: for correctness, this plugin must be used together with the
    AsyncFinalizerCallback callback which performs the finalization checks.

    Args:
        checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be of type
            AsyncCompatibleCheckpointIO. Requires the underlying
            checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn.
    """

    def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
        if not HAVE_MEGATRON_CORE:
            raise ImportError(IMPORT_ERROR)
        if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
            raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

        super().__init__(checkpoint_io)
        self.async_calls_queue = AsyncCallsQueue()

    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        path: _PATH,
        storage_options: Optional[Any] = None,
    ) -> None:
        """Executes the async request returned from the underlying checkpoint_io asynchronously.

        Requires the underlying checkpoint_io.save_checkpoint to return an AsyncRequest.
        It is then applied with `self.async_calls_queue` asynchronously.

        Args:
            checkpoint (Dict[str, Any]): checkpoint to save. Passed to the underlying
                checkpoint_io without modifications.
            path (_PATH): path to save the checkpoint. Passed to the underlying
                checkpoint_io without modifications.
            storage_options (Any, optional): storage control modifiers. This class
                consumes the `finalize_fn` parameter (if any), which is expected to be
                a callback and is appended to the async finalization functions.

        Applies the underlying checkpoint_io finalize callback first, then the
        external one (postfix order).
        """
        external_finalize_fn = (storage_options or {}).pop('finalize_fn', None)
        assert isinstance(self.checkpoint_io, AsyncCompatibleCheckpointIO), type(self.checkpoint_io)
        async_request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options)
        if external_finalize_fn is not None:
            async_request.add_finalize_fn(external_finalize_fn)
        call_idx = self.async_calls_queue.schedule_async_request(async_request)
        logging.debug(f'Scheduled an async call #{call_idx}')

    @_debug_time('AsyncFinalizableCheckpointIO.maybe_finalize_save_checkpoint')
    def maybe_finalize_save_checkpoint(self, blocking: bool = False):
        """Performs checkpoint finalization (if possible).

        Args:
            blocking (bool, optional): if True, waits until all async saves are
                completed. Otherwise, finalizes only those async calls which are
                already done on all ranks. Defaults to False.
        """
        if self.async_calls_queue.get_num_unfinalized_calls() == 0:
            return False

        start_time = time()
        call_idx_finalized = self.async_calls_queue.maybe_finalize_async_calls(blocking)
        if call_idx_finalized:
            logging.debug(f'Finalized async calls: {[f"#{idx}" for idx in call_idx_finalized]}')
        end_time = time()
        logging.info(f"Async checkpoint finalization took {end_time - start_time:.3f} s")

        return len(call_idx_finalized) > 0

    def teardown(self) -> None:
        """Warns if there are any pending checkpoint saves."""
        super().teardown()
        if self.async_calls_queue.get_num_unfinalized_calls() > 0:
            # Can't do finalization now because some ranks might be lost
            logging.warning('Some async checkpoint saves might not be finalized properly.')


class AsyncFinalizerCallback(Callback):
    """Callback which finalizes async saves initiated by the AsyncFinalizableCheckpointIO.

    Tries to perform non-blocking finalization on train_batch_end and train_epoch_end.
    On train_end performs a blocking finalization of all pending checkpoints.
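
    A usage sketch (illustrative; Lightning accepts CheckpointIO instances through
    the Trainer `plugins` argument, but the exact wiring depends on the setup):

        async_io = AsyncFinalizableCheckpointIO(DistributedCheckpointIO('torch_dist', async_save=True))
        trainer = pl.Trainer(callbacks=[AsyncFinalizerCallback()], plugins=[async_io])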
""" def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: """Override hook to finalize pending checkpoint(s) if they exist.""" self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False) def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: """Override hook to finalize pending checkpoint(s) if they exist.""" self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False) def on_train_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None: """Override hook to finalize pending checkpoint(s) if they exist.""" checkpoint_io = self._get_checkpoint_io(trainer) if checkpoint_io.async_calls_queue.get_num_unfinalized_calls() > 0: logging.info('Pending async checkpoint saves. Finalizing them synchronously now') self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=True) def _get_checkpoint_io(self, trainer) -> AsyncFinalizableCheckpointIO: checkpoint_io = trainer.strategy.checkpoint_io if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO): raise ValueError( f'Async finalizer requires an async compatible CheckpointIO, got: {checkpoint_io.__class__}' ) return checkpoint_io class DistributedCheckpointIO(AsyncCompatibleCheckpointIO): """CheckpointIO for a distributed checkpoint format. Args: save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. load_directly_on_device (bool, optional): if True, loads the weights directly on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device). Defaults to True. load_strictness (StrictHandling, optional): defines loading strictness. If not None, overwrites the `strict` flag passed to `load_checkpoint`. Defaults to None. async_save (bool): whether to save asynchronously. Should be set to True if this class will be wrapped with AsyncFinalizableCheckpointIO. torch_dist_multiproc (int, optional): number of extra processes per rank used during ckpt save with PyTorch distributed format. Defaults, to None which means using an MCore default (2). parallel_save (bool): parallelizes the save across ranks. Defaults to True parallel_load (bool): parallelizes the load across ranks (followed by params all gather). Defaults to False due to some extra memory usage requirement. """ def __init__( self, save_ckpt_format: str, load_directly_on_device: bool = True, load_strictness: Optional['StrictHandling'] = None, async_save: bool = False, torch_dist_multiproc: Optional[int] = None, assume_constant_structure: bool = False, parallel_save: bool = False, parallel_save_within_dp: bool = False, parallel_load: bool = False, ): super().__init__() if not HAVE_MEGATRON_CORE: raise ImportError(IMPORT_ERROR) self.save_ckpt_format = save_ckpt_format self.load_directly_on_device = load_directly_on_device self.load_strictness = load_strictness self.async_save = async_save self.torch_dist_multiproc = torch_dist_multiproc self.assume_constant_structure = assume_constant_structure self.parallel_save = parallel_save self.parallel_save_within_dp = parallel_save_within_dp self.parallel_load = parallel_load self._save_sharded_strategy = None self.validated_consistency = False @classmethod def from_config(cls, model_cfg: dict, async_save: bool = False): """Instantiates a DistributedCheckpointIO from a config dict. Args: model_cfg (dict): model config dict. Most of the configuration is extracted from this config. async_save (bool, optional): async_save flag is not part of the model config, it should be provided separately. 
        """
        return cls(
            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'torch_dist'),
            load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
            load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
            async_save=async_save,
            torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
            parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
            parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
            parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
        )

    @_debug_time('DistributedCheckpointIO.save_checkpoint')
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
    ) -> Optional['AsyncRequest']:
        """Saves a distributed checkpoint.

        Creates the checkpoint root directory if it doesn't exist.

        Args:
            checkpoint (Dict[str, Any]): sharded state dict to save
            path (_PATH): checkpoint directory
            storage_options (Any, optional): Optional parameters when saving the checkpoint
        """
        fs = get_filesystem(path)
        fs.makedirs(path, exist_ok=True)

        validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
        self.validated_consistency = True

        rank = torch.distributed.get_rank()
        iteration = _get_iteration_from_checkpoint(checkpoint)
        start_time = time()
        async_save_request = dist_checkpointing.save(
            sharded_state_dict=checkpoint,
            checkpoint_dir=path,
            sharded_strategy=self.save_sharded_strategy,
            validate_access_integrity=validate_sharding_integrity,
            async_sharded_save=self.async_save,
        )
        end_time = time()
        log_parts = (
            "Global Checkpoint Save",
            f"Rank: {rank}",
            f"Iteration: {iteration}" if iteration is not None else None,
            f"Start time: {start_time:.3f}s",
            f"Save duration: {end_time - start_time:.3f}s",
        )
        log_message = " : ".join(part for part in log_parts if part is not None)
        logging.info(log_message)

        def iter_finalize_fn():
            # Guard against a missing iteration (e.g. checkpoints saved outside a fit loop)
            iter_part = f' from iteration {int(iteration):7d}' if iteration is not None else ''
            logging.info(f'Successfully saved checkpoint{iter_part} to {path}')

        if self.async_save:
            assert async_save_request is not None
            async_save_request.add_finalize_fn(iter_finalize_fn)
        return async_save_request

    @_debug_time('DistributedCheckpointIO.load_checkpoint')
    def load_checkpoint(
        self,
        path: _PATH,
        map_location: Optional[Any] = None,
        sharded_state_dict: Dict[str, Any] = None,
        strict: Union[None, bool, 'StrictHandling'] = None,
        validate_access_integrity: Optional[bool] = True,
    ) -> Dict[str, Any]:
        """Loads a distributed checkpoint.

        Args:
            path (_PATH): checkpoint directory
            map_location (Any, optional): required to be None in this implementation
            sharded_state_dict (Dict[str, Any], optional): state dict which defines the
                loading procedure for the distributed checkpoint. Defaults to None to
                comply with the CheckpointIO interface, but it's a required argument.
            strict (bool, StrictHandling, optional): adjusts load strictness. A bool
                value is translated to a StrictHandling instance. Gets overwritten by
                `self.load_strictness`. Defaults to None. If `self.load_strictness`
                is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.
            validate_access_integrity (bool, optional): passed directly to
                `dist_checkpointing.load`. Defaults to True.

        Returns:
            Dict[str, Any]: loaded checkpoint.
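
        Example (a sketch; `model.sharded_state_dict()` is assumed to provide a
        Megatron-style load plan and is not part of this module):

            state_dict = checkpoint_io.load_checkpoint(
                ckpt_dir, sharded_state_dict=model.sharded_state_dict(), strict=False
            )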
""" if sharded_state_dict is None: raise ValueError('DistributedCheckpointIO requires passing sharded_state_dict argument to load_checkpoint') if map_location is not None: raise ValueError('DistributedCheckpointIO doesnt handle map_location argument') if self.save_ckpt_format == 'zarr' and self.load_directly_on_device: sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True) else: sharded_strategy = None if self.parallel_load: if sharded_strategy is None: sharded_strategy = get_default_load_sharded_strategy(path) sharded_strategy = FullyParallelLoadStrategyWrapper( sharded_strategy, get_data_parallel_group(with_context_parallel=True) ) if sharded_strategy is not None: logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.') if isinstance(strict, bool): # For backward-compatibility reasons and a bug in MCore (strict check not applied to factories) # we must apply a simple strict check here. if not strict: sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict) strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL if self.load_strictness is not None: # Overwrites function argument strict = self.load_strictness if strict is None: # Default behavior strict = StrictHandling.ASSUME_OK_UNEXPECTED logging.debug(f'Dist ckpt load strictness: {strict}') start_time = time() ret = dist_checkpointing.load( sharded_state_dict=sharded_state_dict, checkpoint_dir=path, sharded_strategy=sharded_strategy, validate_access_integrity=validate_access_integrity, strict=strict, ) end_time = time() duration = end_time - start_time logging.info( "Global Checkpoint Load : " f"Rank : {torch.distributed.get_rank()} : " f"Start time : {start_time:.3f}s : " f"Time spent in load_checkpoint: {duration:.3f}s" ) return ret def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]): """Remove unexpected keys from being loaded into the state dict.""" ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path) loaded_keys = [] unexpected_keys = [] def should_remove_missing_sharded_base(x: Any): if isinstance(x, ShardedBase): if x.key in ckpt_sharded_metadata: loaded_keys.append(x.key) return False else: unexpected_keys.append(x.key) return True return False _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base) logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}') # TODO: compute missing_keys by: # 1. all_gather_object of loaded_keys # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys return sharded_state_dict @_debug_time('DistributedCheckpointIO.remove_checkpoint') def remove_checkpoint(self, path: _PATH) -> None: """Remove a distributed checkpoint. Due to potentially large number of files, the implementation remove the whole directory at once. """ shutil.rmtree(path, ignore_errors=True) @property def save_sharded_strategy(self) -> 'SaveShardedStrategy': """Conditionally initialize and get the sharded strategy to use for saving.""" if self._save_sharded_strategy is None: self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy() return self._save_sharded_strategy def _determine_dist_ckpt_save_strategy(self): """Determine the saving strategy based on constructor args. Relies on the default MCore strategy unless extra PyT Distributed format arguments are passed in config or in case of a fully parallel save in which case a parallelization wrapper is applied. 
""" if self.save_ckpt_format == 'zarr': logging.warning( '`zarr` distributed checkpoint backend is deprecated.' ' Distributed optimizer checkpoint saving might be extremely slow.' ' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).' ) if self.async_save and self.save_ckpt_format != 'torch_dist': raise ValueError('Async dist-ckpt save supported only for torch_dist format') torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc) if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs: save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs) else: save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1) # MCore v0.8 introduces `use_cached_ckpt_structure` attribute if hasattr(save_strategy, 'use_cached_ckpt_structure'): save_strategy.use_cached_ckpt_structure = self.assume_constant_structure if self.parallel_save: parallelization_group = ( get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None ) save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, parallelization_group, self.assume_constant_structure ) logging.info(f'Using {save_strategy} dist-ckpt save strategy.') return save_strategy def _get_iteration_from_checkpoint(checkpoint: Dict[str, Any]) -> Optional[int]: return ( checkpoint.get("loops", {}) .get("fit_loop", {}) .get("epoch_loop.batch_progress", {}) .get("total", {}) .get("completed", None) )