# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional, Union

import lightning.pytorch as pl
import torch
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import Callback
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO

from nemo.utils import logging

try:
    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
    from megatron.core.dist_checkpointing.mapping import ShardedBase
    from megatron.core.dist_checkpointing.serialization import (
        get_default_load_sharded_strategy,
        get_default_save_sharded_strategy,
    )
    from megatron.core.dist_checkpointing.strategies import tensorstore
    from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
    from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
        FullyParallelLoadStrategyWrapper,
        FullyParallelSaveStrategyWrapper,
    )
    from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
    from megatron.core.dist_checkpointing.validation import StrictHandling
    from megatron.core.parallel_state import get_data_parallel_group

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError) as e:
    HAVE_MEGATRON_CORE = False
    IMPORT_ERROR = (
        "megatron-core was not found. "
        "Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
        f" Exact error: {e}"
    )


@contextmanager
def _debug_time(name: str):
    """Simple context manager for timing functions/code blocks."""
    start = time()
    try:
        yield
    finally:
        logging.debug(f'{name} took {time() - start:.3f}s')
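
# Illustrative usage sketch (not executed): timing an arbitrary code block; the
# label is free-form and only appears in the debug log.
#
#     with _debug_time('save strategy init'):
#         strategy = get_default_save_sharded_strategy('torch_dist', 1)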


class AsyncCompatibleCheckpointIO(CheckpointIO, ABC):
    """CheckpointIO that can be used together with async saving.

    Differs from the regular CheckpointIO only by the `save_checkpoint`
    return type. The `save_checkpoint` method itself is synchronous, but returns
    callbacks that can be performed asynchronously.
    """

    @abstractmethod
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
    ) -> 'AsyncRequest':
        """Interface for save_checkpoint implementations that return an AsyncRequest."""
        raise NotImplementedError


class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
    """CheckpointIO wrapper for async checkpoint saving and synchronous finalization.

    Runs the main part of the checkpoint save in a separate process (not a thread,
    as the PTL AsyncCheckpointIO does). Allows performing a (synchronous) finalization
    function after all ranks finish checkpoint saving.

    NOTE: for correctness, this plugin must be used together with the
    AsyncFinalizerCallback callback which performs the finalization checks.

    Args:
        checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be
            of type AsyncCompatibleCheckpointIO.

    Requires the underlying checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn.
    """

    def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
        if not HAVE_MEGATRON_CORE:
            raise ImportError(IMPORT_ERROR)
        if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
            raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

        super().__init__(checkpoint_io)
        self.async_calls_queue = AsyncCallsQueue()

    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        path: _PATH,
        storage_options: Optional[Any] = None,
    ) -> None:
        """Executes the async request returned from the underlying checkpoint_io asynchronously.

        Requires the underlying checkpoint_io.save_checkpoint to return an AsyncRequest.
        It is then applied with `self.async_calls_queue` asynchronously.

        Args:
            checkpoint (Dict[str, Any]): checkpoint to save. Passed to the underlying
                checkpoint_io without modifications.
            path (_PATH): path to save the checkpoint. Passed to the underlying
                checkpoint_io without modifications.
            storage_options (Any, optional): storage control modifiers. This class
                consumes the `finalize_fn` parameter (if any), which is expected to be
                a callback and is appended to the async finalization functions.
                Applies the underlying checkpoint_io finalize callback first, then the
                external one (postfix order).
        """
        external_finalize_fn = (storage_options or {}).pop('finalize_fn', None)
        assert isinstance(self.checkpoint_io, AsyncCompatibleCheckpointIO), type(self.checkpoint_io)

        async_request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options)
        if external_finalize_fn is not None:
            async_request.add_finalize_fn(external_finalize_fn)

        call_idx = self.async_calls_queue.schedule_async_request(async_request)
        logging.debug(f'Scheduled an async call #{call_idx}')
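
    # Illustrative sketch (not executed; `async_io` and `ckpt` are hypothetical names):
    # an extra finalization callback can be passed through `storage_options`. It runs
    # after the wrapped checkpoint_io's own finalize functions, once all ranks have
    # finished saving and the AsyncFinalizerCallback triggers finalization.
    #
    #     async_io.save_checkpoint(
    #         ckpt,
    #         '/path/to/ckpt_dir',
    #         storage_options={'finalize_fn': lambda: logging.info('save finalized')},
    #     )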

    def maybe_finalize_save_checkpoint(self, blocking: bool = False):
        """Performs checkpoint finalization (if possible).

        Args:
            blocking (bool, optional): if True, waits until all async saves are
                completed. Otherwise, finalizes only those async calls which are
                already done on all ranks. Defaults to False.
        """
        if self.async_calls_queue.get_num_unfinalized_calls() == 0:
            return False

        start_time = time()
        call_idx_finalized = self.async_calls_queue.maybe_finalize_async_calls(blocking)
        if call_idx_finalized:
            logging.debug(f'Finalized async calls: {[f"#{idx}" for idx in call_idx_finalized]}')
        end_time = time()
        logging.info(f"Async finalization took {end_time - start_time:.3f}s")
        return len(call_idx_finalized) > 0

    def teardown(self) -> None:
        """Warns if there are any pending checkpoint saves."""
        super().teardown()
        if self.async_calls_queue.get_num_unfinalized_calls() > 0:
            # Can't do the finalization now because some ranks might be lost
            logging.warning('Some async checkpoint saves might not be finalized properly.')


class AsyncFinalizerCallback(Callback):
    """Callback which finalizes async saves initiated by the AsyncFinalizableCheckpointIO.

    Tries to perform a non-blocking finalization on train_batch_end and train_epoch_end.
    On train_end performs a blocking finalization of all pending checkpoints.
    """

    def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
        """Override hook to finalize pending checkpoint(s) if they exist."""
        self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False)

    def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
        """Override hook to finalize pending checkpoint(s) if they exist."""
        self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False)

    def on_train_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
        """Override hook to finalize pending checkpoint(s) if they exist."""
        checkpoint_io = self._get_checkpoint_io(trainer)
        if checkpoint_io.async_calls_queue.get_num_unfinalized_calls() > 0:
            logging.info('Pending async checkpoint saves. Finalizing them synchronously now.')
        checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)

    def _get_checkpoint_io(self, trainer) -> AsyncFinalizableCheckpointIO:
        checkpoint_io = trainer.strategy.checkpoint_io
        if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
            raise ValueError(
                f'Async finalizer requires an async compatible CheckpointIO, got: {checkpoint_io.__class__}'
            )
        return checkpoint_io
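
# Illustrative wiring sketch (not executed): the async-finalizable plugin wraps an
# async-compatible CheckpointIO and must be paired with AsyncFinalizerCallback.
# The Trainer arguments shown are a minimal assumption, not a complete recipe.
#
#     dist_io = DistributedCheckpointIO(save_ckpt_format='torch_dist', async_save=True)
#     trainer = pl.Trainer(
#         plugins=[AsyncFinalizableCheckpointIO(dist_io)],
#         callbacks=[AsyncFinalizerCallback()],
#     )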


class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
    """CheckpointIO for a distributed checkpoint format.

    Args:
        save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving.
        load_directly_on_device (bool, optional): if True, loads the weights directly
            on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
            always loads on device). Defaults to True.
        load_strictness (StrictHandling, optional): defines loading strictness.
            If not None, overwrites the `strict` flag passed to `load_checkpoint`.
            Defaults to None.
        async_save (bool): whether to save asynchronously. Should be set to True if
            this class will be wrapped with AsyncFinalizableCheckpointIO.
        torch_dist_multiproc (int, optional): number of extra processes per rank
            used during ckpt save with the PyTorch distributed format. Defaults to None,
            which means using an MCore default (2).
        assume_constant_structure (bool): if True, assumes the checkpoint structure does
            not change between saves, which allows skipping repeated sharding-integrity
            validation and caching the checkpoint structure. Defaults to False.
        parallel_save (bool): parallelizes the save across ranks. Defaults to False.
        parallel_save_within_dp (bool): restricts the fully parallel save to the
            data-parallel group. Defaults to False.
        parallel_load (bool): parallelizes the load across ranks (followed by a params
            all-gather). Defaults to False due to some extra memory usage requirement.
    """

    def __init__(
        self,
        save_ckpt_format: str,
        load_directly_on_device: bool = True,
        load_strictness: Optional['StrictHandling'] = None,
        async_save: bool = False,
        torch_dist_multiproc: Optional[int] = None,
        assume_constant_structure: bool = False,
        parallel_save: bool = False,
        parallel_save_within_dp: bool = False,
        parallel_load: bool = False,
    ):
        super().__init__()
        if not HAVE_MEGATRON_CORE:
            raise ImportError(IMPORT_ERROR)

        self.save_ckpt_format = save_ckpt_format
        self.load_directly_on_device = load_directly_on_device
        self.load_strictness = load_strictness
        self.async_save = async_save
        self.torch_dist_multiproc = torch_dist_multiproc
        self.assume_constant_structure = assume_constant_structure
        self.parallel_save = parallel_save
        self.parallel_save_within_dp = parallel_save_within_dp
        self.parallel_load = parallel_load

        self._save_sharded_strategy = None
        self.validated_consistency = False

    @classmethod
    def from_config(cls, model_cfg: dict, async_save: bool = False):
        """Instantiates a DistributedCheckpointIO from a config dict.

        Args:
            model_cfg (dict): model config dict. Most of the configuration
                is extracted from this config.
            async_save (bool, optional): the async_save flag is not part of the model
                config, so it should be provided separately. Defaults to False.
        """
        return cls(
            save_ckpt_format=model_cfg.get('dist_ckpt_format', 'torch_dist'),
            load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
            load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
            async_save=async_save,
            torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
            parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
            parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
            parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
        )
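
    # Illustrative sketch (not executed): building the plugin from a NeMo-style model
    # config dict. Only the `dist_ckpt_*` keys read in `from_config` above matter, each
    # with a default; the `model_cfg` below is a hypothetical minimal example.
    #
    #     model_cfg = {'dist_ckpt_format': 'torch_dist', 'dist_ckpt_parallel_save': True}
    #     checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save=True)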

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
    ) -> Optional['AsyncRequest']:
        """Saves a distributed checkpoint. Creates the checkpoint root directory if it doesn't exist.

        Args:
            checkpoint (Dict[str, Any]): sharded state dict to save
            path (_PATH): checkpoint directory
            storage_options (Any, optional): Optional parameters when saving the checkpoint
        """
        fs = get_filesystem(path)
        fs.makedirs(path, exist_ok=True)

        validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
        self.validated_consistency = True

        rank = torch.distributed.get_rank()
        iteration = _get_iteration_from_checkpoint(checkpoint)
        start_time = time()
        async_save_request = dist_checkpointing.save(
            sharded_state_dict=checkpoint,
            checkpoint_dir=path,
            sharded_strategy=self.save_sharded_strategy,
            validate_access_integrity=validate_sharding_integrity,
            async_sharded_save=self.async_save,
        )
        end_time = time()
        log_parts = (
            "Global Checkpoint Save",
            f"Rank: {rank}",
            f"Iteration: {iteration}" if iteration is not None else None,
            f"Start time: {start_time:.3f}s",
            f"Save duration: {end_time - start_time:.3f}s",
        )
        log_message = " : ".join(part for part in log_parts if part is not None)
        logging.info(log_message)

        def iter_finalize_fn():
            # Guard against checkpoints without iteration info (e.g., saved outside a fit loop)
            if iteration is not None:
                logging.info(f'Successfully saved checkpoint from iteration {int(iteration):7d} to {path}')
            else:
                logging.info(f'Successfully saved checkpoint to {path}')

        if self.async_save:
            assert async_save_request is not None
            async_save_request.add_finalize_fn(iter_finalize_fn)
        return async_save_request

    def load_checkpoint(
        self,
        path: _PATH,
        map_location: Optional[Any] = None,
        sharded_state_dict: Dict[str, Any] = None,
        strict: Union[None, bool, 'StrictHandling'] = None,
        validate_access_integrity: Optional[bool] = True,
    ) -> Dict[str, Any]:
        """Loads a distributed checkpoint.

        Args:
            path (_PATH): checkpoint directory
            map_location (Any, optional): required to be None in this implementation
            sharded_state_dict (Dict[str, Any], optional): state dict which
                defines the loading procedure for the distributed checkpoint.
                Defaults to None to comply with the CheckpointIO interface,
                but it's a required argument.
            strict (bool, StrictHandling, optional): adjusts load strictness. A bool value
                is translated to a StrictHandling instance. Gets overwritten by
                `self.load_strictness`. Defaults to None. If `self.load_strictness`
                is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.
            validate_access_integrity (bool, optional): if True, validates the integrity
                of the sharded access pattern during load. Defaults to True.

        Returns:
            Dict[str, Any]: loaded checkpoint.
        """
        if sharded_state_dict is None:
            raise ValueError('DistributedCheckpointIO requires passing sharded_state_dict argument to load_checkpoint')
        if map_location is not None:
            raise ValueError('DistributedCheckpointIO does not handle map_location argument')

        if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
            sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
        else:
            sharded_strategy = None

        if self.parallel_load:
            if sharded_strategy is None:
                sharded_strategy = get_default_load_sharded_strategy(path)
            sharded_strategy = FullyParallelLoadStrategyWrapper(
                sharded_strategy, get_data_parallel_group(with_context_parallel=True)
            )

        if sharded_strategy is not None:
            logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')

        if isinstance(strict, bool):
            # For backward-compatibility reasons and a bug in MCore (strict check not applied to factories)
            # we must apply a simple strict check here.
            if not strict:
                sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
            strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
        if self.load_strictness is not None:
            # Overwrites the function argument
            strict = self.load_strictness
        if strict is None:
            # Default behavior
            strict = StrictHandling.ASSUME_OK_UNEXPECTED

        logging.debug(f'Dist ckpt load strictness: {strict}')

        start_time = time()
        ret = dist_checkpointing.load(
            sharded_state_dict=sharded_state_dict,
            checkpoint_dir=path,
            sharded_strategy=sharded_strategy,
            validate_access_integrity=validate_access_integrity,
            strict=strict,
        )
        end_time = time()
        duration = end_time - start_time
        logging.info(
            "Global Checkpoint Load : "
            f"Rank : {torch.distributed.get_rank()} : "
            f"Start time : {start_time:.3f}s : "
            f"Time spent in load_checkpoint: {duration:.3f}s"
        )
        return ret
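
    # Illustrative sketch (not executed): loading is driven by a sharded state dict
    # describing what to load and how it is sharded. With MCore-based models this
    # typically comes from the model's `sharded_state_dict()` method (an assumption
    # about the caller; `model` and the path below are hypothetical).
    #
    #     sharded_sd = model.sharded_state_dict()
    #     loaded = checkpoint_io.load_checkpoint('/path/to/ckpt_dir', sharded_state_dict=sharded_sd, strict=False)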

    def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
        """Removes unexpected keys from being loaded into the state dict."""
        ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
        loaded_keys = []
        unexpected_keys = []

        def should_remove_missing_sharded_base(x: Any):
            if isinstance(x, ShardedBase):
                if x.key in ckpt_sharded_metadata:
                    loaded_keys.append(x.key)
                    return False
                else:
                    unexpected_keys.append(x.key)
                    return True
            return False

        _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
        logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')

        # TODO: compute missing_keys by:
        #   1. all_gather_object of loaded_keys
        #   2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
        return sharded_state_dict
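
    # Illustrative sketch (not executed): `extract_matching_values` splits a nested
    # state dict by a predicate, returning (matching, non-matching); above, the
    # non-matching part (entries present in the checkpoint) is kept. A toy example
    # with plain values (assumed semantics of the MCore helper):
    #
    #     matching, rest = extract_matching_values({'a': 1, 'b': 2}, lambda v: v > 1)
    #     # matching == {'b': 2}, rest == {'a': 1}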

    def remove_checkpoint(self, path: _PATH) -> None:
        """Removes a distributed checkpoint.

        Due to the potentially large number of files, the implementation removes the whole directory at once.
        """
        shutil.rmtree(path, ignore_errors=True)

    @property
    def save_sharded_strategy(self) -> 'SaveShardedStrategy':
        """Conditionally initializes and returns the sharded strategy to use for saving."""
        if self._save_sharded_strategy is None:
            self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
        return self._save_sharded_strategy

    def _determine_dist_ckpt_save_strategy(self):
        """Determines the saving strategy based on constructor args.

        Relies on the default MCore strategy, unless extra PyT Distributed format
        arguments are passed in the config or a fully parallel save is requested,
        in which case a parallelization wrapper is applied.
        """
        if self.save_ckpt_format == 'zarr':
            logging.warning(
                '`zarr` distributed checkpoint backend is deprecated.'
                ' Distributed optimizer checkpoint saving might be extremely slow.'
                ' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
            )

        if self.async_save and self.save_ckpt_format != 'torch_dist':
            raise ValueError('Async dist-ckpt save supported only for torch_dist format')

        torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc)
        if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs:
            save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs)
        else:
            save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1)

        # MCore v0.8 introduces the `use_cached_ckpt_structure` attribute
        if hasattr(save_strategy, 'use_cached_ckpt_structure'):
            save_strategy.use_cached_ckpt_structure = self.assume_constant_structure

        if self.parallel_save:
            parallelization_group = (
                get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
            )
            save_strategy = FullyParallelSaveStrategyWrapper(
                save_strategy, parallelization_group, self.assume_constant_structure
            )

        logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
        return save_strategy


def _get_iteration_from_checkpoint(checkpoint: Dict[str, Any]) -> Optional[int]:
    """Extracts the training iteration from a Lightning checkpoint, if present."""
    return (
        checkpoint.get("loops", {})
        .get("fit_loop", {})
        .get("epoch_loop.batch_progress", {})
        .get("total", {})
        .get("completed", None)
    )
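
# Illustrative sketch (not executed): the helper above walks the nested progress
# structure Lightning stores in its checkpoints. A minimal dict shape it accepts
# (values hypothetical):
#
#     ckpt = {'loops': {'fit_loop': {'epoch_loop.batch_progress': {'total': {'completed': 1000}}}}}
#     assert _get_iteration_from_checkpoint(ckpt) == 1000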