# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from abc import ABC, abstractmethod
from contextlib import contextmanager
from time import time
from typing import Any, Dict, Optional, Union
import lightning.pytorch as pl
import torch
from lightning.fabric.plugins import CheckpointIO
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import Callback
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from nemo.utils import logging
try:
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase
from megatron.core.dist_checkpointing.serialization import (
get_default_load_sharded_strategy,
get_default_save_sharded_strategy,
)
from megatron.core.dist_checkpointing.strategies import tensorstore
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
FullyParallelLoadStrategyWrapper,
FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.parallel_state import get_data_parallel_group
HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError) as e:
HAVE_MEGATRON_CORE = False
IMPORT_ERROR = (
"megatron-core was not found. "
"Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
f" Exact error: {e}"
)
@contextmanager
def _debug_time(name: str):
"""Simple context manager for timing functions/code blocks."""
start = time()
try:
yield
finally:
logging.debug(f'{name} took {time() - start:.3f}s')
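# Usage sketch (illustrative): wrap any code block to get a debug-level timing log.
#
#     with _debug_time('load_checkpoint'):
#         ...  # timed code; duration is reported via logging.debug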
class AsyncCompatibleCheckpointIO(CheckpointIO, ABC):
"""CheckpointIO that can be used together with async saving.
    Differs from the regular CheckpointIO only by the `save_checkpoint`
    return type: the `save_checkpoint` method itself is synchronous, but returns
    an AsyncRequest describing the work (and finalize callbacks) to be performed asynchronously.
"""
@abstractmethod
def save_checkpoint(
self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
) -> 'AsyncRequest':
"""Interface to implement save_checkpoint and return an AsyncRequest"""
raise NotImplementedError
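# Minimal sketch of the contract (illustrative, not part of this module): a subclass performs any
# synchronous preparation inside `save_checkpoint` and returns an AsyncRequest, which (in recent
# megatron-core versions) bundles the async function, its args, and a list of finalize callbacks.
# `prepare_my_save` below is a hypothetical helper.
#
#     class MyAsyncCheckpointIO(AsyncCompatibleCheckpointIO):
#         def save_checkpoint(self, checkpoint, path, storage_options=None):
#             write_fn, write_args, finalize_fn = prepare_my_save(checkpoint, path)
#             return AsyncRequest(write_fn, write_args, [finalize_fn])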
class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
"""CheckpointIO wrapper for async checkpoint saving and synchronous finalization.
    Runs the main part of the checkpoint save in a separate process (not a thread, as the PTL
    AsyncCheckpointIO does). Allows a (synchronous) finalization
    function to be run after all ranks finish checkpoint saving.
NOTE: for correctness, this plugin must be used together with the
AsyncFinalizerCallback callback which performs the finalization checks.
Args:
checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be
of type AsyncCompatibleCheckpointIO.
            Requires the underlying checkpoint_io.save_checkpoint to return an AsyncRequest.
"""
def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
if not HAVE_MEGATRON_CORE:
raise ImportError(IMPORT_ERROR)
if not isinstance(checkpoint_io, AsyncCompatibleCheckpointIO):
raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')
super().__init__(checkpoint_io)
self.async_calls_queue = AsyncCallsQueue()
def save_checkpoint(
self,
checkpoint: Dict[str, Any],
path: _PATH,
storage_options: Optional[Any] = None,
) -> None:
"""Executes async request returned from the underlying checkpoint_io asynchronously.
Requires the underlying checkpoint_io.save_checkpoint to return an AsyncRequest.
It is then applied with `self.async_calls_queue` asynchronously.
Args:
checkpoint (Dict[str, Any]): checkpoint to save. Passed to underlying
checkpoint_io without modifications.
path (_PATH): path to save the checkpoint. Passed to underlying
checkpoint_io without modifications.
storage_options (Any, optional): storage control modifiers. This class
consumed the `finalize_fn` parameter (if any), which is expected to be
a callback and is appended to async finalization functions.
Applies underlying checkpoint_io finalize callback first, then the external one (postfix order).
"""
external_finalize_fn = (storage_options or {}).pop('finalize_fn', None)
assert isinstance(self.checkpoint_io, AsyncCompatibleCheckpointIO), type(self.checkpoint_io)
async_request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options)
if external_finalize_fn is not None:
async_request.add_finalize_fn(external_finalize_fn)
call_idx = self.async_calls_queue.schedule_async_request(async_request)
logging.debug(f'Scheduled an async call #{call_idx}')
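    # Example (sketch): an external finalize callback can be threaded through `storage_options`;
    # `on_ckpt_saved` is a placeholder for any no-argument callable to run after all ranks finish the save.
    #
    #     checkpoint_io.save_checkpoint(ckpt, path, storage_options={'finalize_fn': on_ckpt_saved})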
@_debug_time('AsyncFinalizableCheckpointIO.maybe_finalize_save_checkpoint')
def maybe_finalize_save_checkpoint(self, blocking: bool = False):
"""Performs checkpoint finalization (if possible).
Args:
blocking (bool, optional): if True, waits until all async saves are
completed. Otherwise, finalizes only those async calls which are
                already done on all ranks. Defaults to False.
        Returns:
            bool: whether any async calls were finalized by this invocation.
"""
if self.async_calls_queue.get_num_unfinalized_calls() == 0:
return False
start_time = time()
call_idx_finalized = self.async_calls_queue.maybe_finalize_async_calls(blocking)
if call_idx_finalized:
logging.debug(f'Finalized async calls: {[f"#{idx}" for idx in call_idx_finalized]}')
end_time = time()
logging.info(f"Async finalization time took {end_time - start_time:.3f} s")
return len(call_idx_finalized) > 0
def teardown(self) -> None:
"""Warns if there are any pending checkpoint saves."""
super().teardown()
if self.async_calls_queue.get_num_unfinalized_calls() > 0:
# Can't do finalization now because some ranks might be lost
            logging.warning('Some async checkpoint saves might not be finalized properly.')
class AsyncFinalizerCallback(Callback):
"""Callback which finalizes async saves initiated by the AsyncFinalizableCheckpointIO.
Tries to perform non-blocking finalization on train_batch_end and train_epoch_end.
On train_end performs a blocking finalization of all pending checkpoints.
"""
def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
"""Override hook to finalize pending checkpoint(s) if they exist."""
self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False)
def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
"""Override hook to finalize pending checkpoint(s) if they exist."""
self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=False)
def on_train_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
"""Override hook to finalize pending checkpoint(s) if they exist."""
checkpoint_io = self._get_checkpoint_io(trainer)
if checkpoint_io.async_calls_queue.get_num_unfinalized_calls() > 0:
logging.info('Pending async checkpoint saves. Finalizing them synchronously now')
self._get_checkpoint_io(trainer).maybe_finalize_save_checkpoint(blocking=True)
def _get_checkpoint_io(self, trainer) -> AsyncFinalizableCheckpointIO:
checkpoint_io = trainer.strategy.checkpoint_io
if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
raise ValueError(
f'Async finalizer requires an async compatible CheckpointIO, got: {checkpoint_io.__class__}'
)
return checkpoint_io
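# Wiring sketch (illustrative; trainer setup details omitted): the async-capable CheckpointIO plugin
# and the finalizer callback are meant to be used together, e.g.
#
#     dist_ckpt_io = DistributedCheckpointIO(save_ckpt_format='torch_dist', async_save=True)
#     trainer = pl.Trainer(
#         plugins=[AsyncFinalizableCheckpointIO(dist_ckpt_io)],
#         callbacks=[AsyncFinalizerCallback()],
#     )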
class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
"""CheckpointIO for a distributed checkpoint format.
Args:
save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving.
load_directly_on_device (bool, optional): if True, loads the weights directly
on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
always loads on device). Defaults to True.
load_strictness (StrictHandling, optional): defines loading strictness.
If not None, overwrites the `strict` flag passed to `load_checkpoint`.
Defaults to None.
async_save (bool): whether to save asynchronously. Should be set to True if
this class will be wrapped with AsyncFinalizableCheckpointIO.
        torch_dist_multiproc (int, optional): number of extra processes per rank
            used during ckpt save with the PyTorch distributed format. Defaults to None,
            which means using the MCore default (2).
        assume_constant_structure (bool, optional): if True, assumes the checkpoint structure
            stays the same across saves, which allows skipping repeated sharding-consistency
            validation and reusing cached checkpoint structures. Defaults to False.
        parallel_save (bool): parallelizes the save across ranks. Defaults to False.
        parallel_save_within_dp (bool): if True, restricts the fully parallel save to the
            data-parallel (with context-parallel) group instead of all ranks. Has effect
            only when `parallel_save` is True. Defaults to False.
        parallel_load (bool): parallelizes the load across ranks (followed by params all gather).
Defaults to False due to some extra memory usage requirement.
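
    Example (illustrative; argument values are placeholders)::

        checkpoint_io = DistributedCheckpointIO(
            save_ckpt_format='torch_dist',
            async_save=True,
            parallel_save=True,
        )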
"""
def __init__(
self,
save_ckpt_format: str,
load_directly_on_device: bool = True,
load_strictness: Optional['StrictHandling'] = None,
async_save: bool = False,
torch_dist_multiproc: Optional[int] = None,
assume_constant_structure: bool = False,
parallel_save: bool = False,
parallel_save_within_dp: bool = False,
parallel_load: bool = False,
):
super().__init__()
if not HAVE_MEGATRON_CORE:
raise ImportError(IMPORT_ERROR)
self.save_ckpt_format = save_ckpt_format
self.load_directly_on_device = load_directly_on_device
self.load_strictness = load_strictness
self.async_save = async_save
self.torch_dist_multiproc = torch_dist_multiproc
self.assume_constant_structure = assume_constant_structure
self.parallel_save = parallel_save
self.parallel_save_within_dp = parallel_save_within_dp
self.parallel_load = parallel_load
self._save_sharded_strategy = None
self.validated_consistency = False
@classmethod
def from_config(cls, model_cfg: dict, async_save: bool = False):
"""Instantiates a DistributedCheckpointIO from a config dict.
Args:
model_cfg (dict): model config dict. Most of the configuration
is extracted from this config.
            async_save (bool, optional): the async_save flag is not part of the model config
                and should be provided separately. Defaults to False.
"""
return cls(
save_ckpt_format=model_cfg.get('dist_ckpt_format', 'torch_dist'),
load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
async_save=async_save,
torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
)
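    # Example (sketch; keys and values are illustrative): `from_config` reads the `dist_ckpt_*`
    # entries of the model config, falling back to the defaults above when a key is absent.
    #
    #     model_cfg = {'dist_ckpt_format': 'torch_dist', 'dist_ckpt_parallel_save': True}
    #     dist_ckpt_io = DistributedCheckpointIO.from_config(model_cfg, async_save=True)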
@_debug_time('DistributedCheckpointIO.save_checkpoint')
def save_checkpoint(
self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None
) -> Optional['AsyncRequest']:
"""Saves a distributed checkpoint. Creates the checkpoint root directory if doesn't exist.
Args:
checkpoint (Dict[str, Any]): sharded state dict to save
path (_PATH): checkpoint directory
storage_options (Any, optional): Optional parameters when saving the checkpoint
"""
fs = get_filesystem(path)
fs.makedirs(path, exist_ok=True)
validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure)
self.validated_consistency = True
rank = torch.distributed.get_rank()
iteration = _get_iteration_from_checkpoint(checkpoint)
start_time = time()
async_save_request = dist_checkpointing.save(
sharded_state_dict=checkpoint,
checkpoint_dir=path,
sharded_strategy=self.save_sharded_strategy,
validate_access_integrity=validate_sharding_integrity,
async_sharded_save=self.async_save,
)
end_time = time()
log_parts = (
"Global Checkpoint Save",
f"Rank: {rank}",
f"Iteration: {iteration}" if iteration is not None else None,
f"Start time: {start_time:.3f}s",
f"Save duration: {end_time - start_time:.3f}s",
)
log_message = " : ".join(part for part in log_parts if part is not None)
logging.info(log_message)
        def iter_finalize_fn():
            iter_info = f' from iteration {int(iteration):7d}' if iteration is not None else ''
            logging.info(f'Successfully saved checkpoint{iter_info} to {path}')
if self.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(iter_finalize_fn)
return async_save_request
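    # Example (sketch): `sharded_state_dict` stands for an MCore-style sharded state dict (e.g. produced
    # by a model's `sharded_state_dict()` method, which is not defined in this file).
    #
    #     req = dist_ckpt_io.save_checkpoint(sharded_state_dict, ckpt_dir)
    #     # with async_save=True, `req` is an AsyncRequest to be scheduled by AsyncFinalizableCheckpointIO;
    #     # with async_save=False, the save completes synchronously.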
@_debug_time('DistributedCheckpointIO.load_checkpoint')
def load_checkpoint(
self,
path: _PATH,
map_location: Optional[Any] = None,
sharded_state_dict: Dict[str, Any] = None,
strict: Union[None, bool, 'StrictHandling'] = None,
validate_access_integrity: Optional[bool] = True,
) -> Dict[str, Any]:
"""Loads a distributed checkpoint.
Args:
path (_PATH): checkpoint directory
map_location (Any, optional): required to be None in this implementation
sharded_state_dict (Dict[str, Any], optional): state dict which
defines the loading procedure for the distributed checkpoint.
Defaults to None to comply with the CheckpointIO interface,
but it's a required argument.
            strict (bool, StrictHandling, optional): adjusts load strictness. A bool value
                is translated to a StrictHandling instance. Overridden by
                `self.load_strictness` if that is set. Defaults to None; if `self.load_strictness`
                is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.
            validate_access_integrity (bool, optional): whether to run the dist-ckpt access
                integrity validation during load. Defaults to True.
        Returns:
            Dict[str, Any]: loaded checkpoint.
"""
if sharded_state_dict is None:
raise ValueError('DistributedCheckpointIO requires passing sharded_state_dict argument to load_checkpoint')
if map_location is not None:
            raise ValueError('DistributedCheckpointIO does not handle the map_location argument')
if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
else:
sharded_strategy = None
if self.parallel_load:
if sharded_strategy is None:
sharded_strategy = get_default_load_sharded_strategy(path)
sharded_strategy = FullyParallelLoadStrategyWrapper(
sharded_strategy, get_data_parallel_group(with_context_parallel=True)
)
if sharded_strategy is not None:
logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')
if isinstance(strict, bool):
# For backward-compatibility reasons and a bug in MCore (strict check not applied to factories)
# we must apply a simple strict check here.
if not strict:
sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
if self.load_strictness is not None:
# Overwrites function argument
strict = self.load_strictness
if strict is None:
# Default behavior
strict = StrictHandling.ASSUME_OK_UNEXPECTED
logging.debug(f'Dist ckpt load strictness: {strict}')
start_time = time()
ret = dist_checkpointing.load(
sharded_state_dict=sharded_state_dict,
checkpoint_dir=path,
sharded_strategy=sharded_strategy,
validate_access_integrity=validate_access_integrity,
strict=strict,
)
end_time = time()
duration = end_time - start_time
logging.info(
"Global Checkpoint Load : "
f"Rank : {torch.distributed.get_rank()} : "
f"Start time : {start_time:.3f}s : "
f"Time spent in load_checkpoint: {duration:.3f}s"
)
return ret
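    # Example (sketch): loading into a sharded state dict assumed to come from an MCore-style model;
    # with strict=False, entries absent from the checkpoint are dropped first (see adjust_non_strict_load).
    #
    #     state_dict = dist_ckpt_io.load_checkpoint(ckpt_dir, sharded_state_dict=sharded_sd, strict=False)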
def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
"""Remove unexpected keys from being loaded into the state dict."""
ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path)
loaded_keys = []
unexpected_keys = []
def should_remove_missing_sharded_base(x: Any):
if isinstance(x, ShardedBase):
if x.key in ckpt_sharded_metadata:
loaded_keys.append(x.key)
return False
else:
unexpected_keys.append(x.key)
return True
return False
_, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base)
logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}')
# TODO: compute missing_keys by:
# 1. all_gather_object of loaded_keys
# 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys
return sharded_state_dict
@_debug_time('DistributedCheckpointIO.remove_checkpoint')
def remove_checkpoint(self, path: _PATH) -> None:
"""Remove a distributed checkpoint.
Due to potentially large number of files, the implementation remove the whole directory at once.
"""
shutil.rmtree(path, ignore_errors=True)
@property
def save_sharded_strategy(self) -> 'SaveShardedStrategy':
"""Conditionally initialize and get the sharded strategy to use for saving."""
if self._save_sharded_strategy is None:
self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy()
return self._save_sharded_strategy
def _determine_dist_ckpt_save_strategy(self):
"""Determine the saving strategy based on constructor args.
        Relies on the default MCore strategy unless extra PyTorch Distributed format arguments
        are passed in the config, or a fully parallel save is requested, in which case
        a parallelization wrapper is applied.
"""
if self.save_ckpt_format == 'zarr':
logging.warning(
'`zarr` distributed checkpoint backend is deprecated.'
' Distributed optimizer checkpoint saving might be extremely slow.'
' Please switch to PyTorch Distributed format (model.dist_ckpt_format=torch_dist).'
)
if self.async_save and self.save_ckpt_format != 'torch_dist':
raise ValueError('Async dist-ckpt save supported only for torch_dist format')
torch_dist_kwargs = {} if self.torch_dist_multiproc is None else dict(thread_count=self.torch_dist_multiproc)
if self.save_ckpt_format == 'torch_dist' and torch_dist_kwargs:
save_strategy = TorchDistSaveShardedStrategy(self.save_ckpt_format, 1, **torch_dist_kwargs)
else:
save_strategy = get_default_save_sharded_strategy(self.save_ckpt_format, 1)
# MCore v0.8 introduces `use_cached_ckpt_structure` attribute
if hasattr(save_strategy, 'use_cached_ckpt_structure'):
save_strategy.use_cached_ckpt_structure = self.assume_constant_structure
if self.parallel_save:
parallelization_group = (
get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
)
save_strategy = FullyParallelSaveStrategyWrapper(
save_strategy, parallelization_group, self.assume_constant_structure
)
logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
return save_strategy
def _get_iteration_from_checkpoint(checkpoint: Dict[str, Any]) -> Optional[int]:
    """Extracts the training iteration (number of completed batches) from a Lightning checkpoint, if available."""
return (
checkpoint.get("loops", {})
.get("fit_loop", {})
.get("epoch_loop.batch_progress", {})
.get("total", {})
.get("completed", None)
)