# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os.path
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Union

import numpy

# tensorstore is needed to register 'bfloat16' dtype with numpy for zarr compatibility
import tensorstore  # noqa: F401 pylint: disable=unused-import
import torch
from torch.distributed.checkpoint import FileSystemReader, load
from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata

from nemo.export.tarutils import TarPath, ZarrPathStore
from nemo.export.utils._mock_import import _mock_import

LOGGER = logging.getLogger("NeMo")


def nemo_to_path(nemo_checkpoint: Union[Path, str]) -> Union[Path, TarPath]:
    """
    Creates a Path / TarPath object suitable for navigating inside the NeMo checkpoint.

    Args:
        nemo_checkpoint (Path, str): Path to the NeMo checkpoint.
    Returns:
        Path | TarPath: Suitable Path object for navigating through the checkpoint.
    """
    string_path = str(nemo_checkpoint)

    if os.path.isdir(string_path):
        return Path(string_path)
    return TarPath(string_path)


class TarFileSystemReader(FileSystemReader):
    """Reader that accepts both Path and TarPath checkpoint directory.

    The FileSystemReader works with TarPath, but expects a pure Path.
    It's enough to skip the Path check in __init__.
    """

    def __init__(self, path: Union[Path, TarPath]) -> None:
        """Makes sure that super().__init__ gets a pure path as expected."""
        super_path = str(path) if isinstance(path, TarPath) else path
        super().__init__(super_path)
        if isinstance(path, TarPath):
            self.path = path  # overwrites path set in super().__init__ call


def load_sharded_metadata_torch_dist(
    checkpoint_dir: Union[Path, TarPath], load_extra_states: bool = False
) -> Dict[str, Any]:
    """
    Loads the model state dictionary from a torch_dist checkpoint.

    Args:
        checkpoint_dir (Path | TarPath): Path to the model weights directory.
        load_extra_states (bool): If set to True, loads BytesIO objects related to the extra states.
    Returns:
        dict: Loaded model state dictionary (weights are stored in torch tensors).
    """
    fs_reader = TarFileSystemReader(checkpoint_dir)
    metadata = fs_reader.read_metadata()

    # Pre-allocate empty tensors matching the checkpoint metadata; load() fills them in-place
    state_dict = {
        k: torch.empty(tp.size, dtype=tp.properties.dtype)
        for k, tp in metadata.state_dict_metadata.items()
        if isinstance(tp, TensorStorageMetadata)
    }

    if load_extra_states:
        state_dict.update(
            {k: [] for k, tp in metadata.state_dict_metadata.items() if isinstance(tp, BytesStorageMetadata)}
        )

    load(state_dict, storage_reader=fs_reader)
    return state_dict


def load_sharded_pickle_extra_state_scale(dir: Union[Path, TarPath]) -> Dict[str, BytesIO]:
    """
    Loads model extra states from the .pt shards.

    Args:
        dir (Path | TarPath): Path to the directory with sharded extra states.
    Returns:
        dict: State dictionary corresponding to the loaded extra states.
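    Example:
        A hypothetical call, shown for illustration only; the directory name
        below is made up, and a real sharded extra-state directory is required,
        so the doctest is skipped.

        >>> extra = load_sharded_pickle_extra_state_scale(Path('weights/some_extra_state_dir'))  # doctest: +SKIP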
""" pt_files = list(dir.glob('shard_*_*.pt')) extra_states = {} for file in pt_files: shard_name = file.name.split('.')[0] with file.open('rb') as opened_file: extra_states[dir.name + '/' + shard_name] = torch.load(opened_file, weights_only=True) return extra_states def contains_extra_states(subdir: Union[Path, TarPath]) -> bool: """ Checks if zarr directory contains extra states. Args: subdir (Path | TarPath): Directory inside the zarr checkpoint. Returns: bool: Is a directory with extra states """ return list(subdir.glob('shard_0_*.pt')) != [] def load_sharded_metadata_zarr( checkpoint_dir: Union[Path, TarPath], load_extra_states: bool = False ) -> Dict[str, Any]: """ Loads model dictionary from the zarr format. Args: checkpoint_dir (Path | TarPath): Path to the NeMo checkpoint. load_extra_states (bool): If set to True, the function will load BufferIO objects with extra states. Returns: dict: Model state dictionary. """ if load_extra_states: torch.serialization.add_safe_globals([BytesIO]) sharded_state_dict = {} for subdir in checkpoint_dir.iterdir(): if not subdir.is_dir(): continue if load_extra_states and contains_extra_states(subdir): sharded_state_dict.update(load_sharded_pickle_extra_state_scale(subdir)) elif (subdir / '.zarray').exists(): key = subdir.name zstore = ZarrPathStore(subdir) import zarr arr = zarr.open(zstore, 'r') if arr.dtype.name == "bfloat16": sharded_state_dict[key] = torch.from_numpy(arr[:].view(numpy.int16)).view(torch.bfloat16) else: sharded_state_dict[key] = torch.from_numpy(arr[:]) return sharded_state_dict def nemo_weights_directory(nemo_path: Union[Path, TarPath]) -> Union[Path, TarPath]: """ Returns a Path pointing to the weights directory inside the NeMo checkpoint. Args: nemo_path (Path | TarPath): Path to the nemo checkpoint. Returns: Path | TarPath: Path to the weights directory inside the model checkpoint. """ if (nemo_path / "model_weights").exists(): return nemo_path / "model_weights" if (nemo_path / "weights").exists(): return nemo_path / "weights" return nemo_path def load_model_weights(checkpoint_path: Union[str, Path], load_extra_states: bool = False) -> Dict[str, Any]: """ Loads NeMo state dictionary. Weights are stored in torch.Tensor Args: checkpoint_path (str | Path): Path to the NeMo checkpoint. load_extra_states (bool): If True, loads BytesIO objects, corresponding to the extra states. Returns: dict: Model state dictionary. """ nemo_path = nemo_to_path(checkpoint_path) nemo_weights = nemo_weights_directory(nemo_path) with (nemo_weights / 'metadata.json').open(mode='r') as f: config_dict = json.load(f) if config_dict['sharded_backend'] == 'zarr': return load_sharded_metadata_zarr(nemo_weights, load_extra_states=load_extra_states) elif config_dict['sharded_backend'] == 'torch_dist': # TODO: Remove mocking imports once MCore is available in NIM containers with _mock_import("megatron.core.dist_checkpointing.strategies.torch"): return load_sharded_metadata_torch_dist(nemo_weights, load_extra_states=load_extra_states) raise NotImplementedError(f'Distributed checkpoint backend {config_dict["sharded_backend"]} not supported')