# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os.path
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Union

import numpy
# tensorstore is needed to register the 'bfloat16' dtype with numpy for zarr compatibility
import tensorstore  # noqa: F401 pylint: disable=unused-import
import torch
from torch.distributed.checkpoint import FileSystemReader, load
from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata

from nemo.export.tarutils import TarPath, ZarrPathStore
from nemo.export.utils._mock_import import _mock_import

LOGGER = logging.getLogger("NeMo")


def nemo_to_path(nemo_checkpoint: Union[Path, str]) -> Union[Path, TarPath]:
    """
    Creates a Path / TarPath object suitable for navigating inside the NeMo checkpoint.

    Args:
        nemo_checkpoint (Path, str): Path to the NeMo checkpoint.
    Returns:
        Path | TarPath: Suitable Path object for navigating through the checkpoint.
    """
    string_path = str(nemo_checkpoint)

    if os.path.isdir(string_path):
        return Path(string_path)

    return TarPath(string_path)
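
# A minimal usage sketch (both paths below are hypothetical):
#
#     nemo_to_path("/models/llama_unpacked")  # directory on disk -> pathlib.Path
#     nemo_to_path("/models/llama.nemo")      # packed .nemo tar  -> TarPath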


class TarFileSystemReader(FileSystemReader):
    """Reader that accepts both Path and TarPath checkpoint directory.

    The FileSystemReader works with TarPath, but expects a pure Path.
    It's enough to skip the Path check in __init__.
    """

    def __init__(self, path: Union[Path, TarPath]) -> None:
        """Makes sure that super().__init__ gets a pure path as expected."""
        super_path = str(path) if isinstance(path, TarPath) else path
        super().__init__(super_path)
        if isinstance(path, TarPath):
            self.path = path  # overwrites path set in the super().__init__ call


def load_sharded_metadata_torch_dist(
    checkpoint_dir: Union[Path, TarPath], load_extra_states: bool = False
) -> Dict[str, Any]:
    """
    Loads the model state dictionary from a torch_dist checkpoint.

    Args:
        checkpoint_dir (Path | TarPath): Path to the model weights directory.
        load_extra_states (bool): If set to True, also loads the BytesIO objects holding the extra states.
    Returns:
        dict: Loaded model state dictionary (weights are stored in torch tensors).
    """
    fs_reader = TarFileSystemReader(checkpoint_dir)
    metadata = fs_reader.read_metadata()

    # Pre-allocate empty tensors matching the checkpoint metadata; load() fills them in place.
    state_dict = {
        k: torch.empty(tp.size, dtype=tp.properties.dtype)
        for k, tp in metadata.state_dict_metadata.items()
        if isinstance(tp, TensorStorageMetadata)
    }
    if load_extra_states:
        state_dict.update(
            {k: [] for k, tp in metadata.state_dict_metadata.items() if isinstance(tp, BytesStorageMetadata)}
        )

    load(state_dict, storage_reader=fs_reader)
    return state_dict
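
# A minimal, hypothetical call (the .nemo path and "model_weights" subdirectory are
# assumptions; use nemo_weights_directory() to resolve the actual layout):
#
#     weights_dir = nemo_to_path("/models/llama.nemo") / "model_weights"
#     state_dict = load_sharded_metadata_torch_dist(weights_dir, load_extra_states=False)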


def load_sharded_pickle_extra_state_scale(dir_path: Union[Path, TarPath]) -> Dict[str, BytesIO]:
    """
    Loads model extra states from the .pt shards.

    Args:
        dir_path (Path | TarPath): Path to the directory with sharded extra states.
    Returns:
        dict: State dictionary corresponding to the loaded extra states.
    """
    pt_files = list(dir_path.glob('shard_*_*.pt'))
    extra_states = {}
    for file in pt_files:
        # 'shard_0_1.pt' -> 'shard_0_1'
        shard_name = file.name.split('.')[0]
        with file.open('rb') as opened_file:
            extra_states[dir_path.name + '/' + shard_name] = torch.load(opened_file, weights_only=True)

    return extra_states
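
# Keys in the returned dict combine the directory name and the shard stem, so a file
# <dir>/shard_0_1.pt yields the key '<dir>/shard_0_1' (the directory name is illustrative).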


def contains_extra_states(subdir: Union[Path, TarPath]) -> bool:
    """
    Checks if a zarr directory contains extra states.

    Args:
        subdir (Path | TarPath): Directory inside the zarr checkpoint.
    Returns:
        bool: True if the directory contains extra state shards.
    """
    return list(subdir.glob('shard_0_*.pt')) != []


def load_sharded_metadata_zarr(
    checkpoint_dir: Union[Path, TarPath], load_extra_states: bool = False
) -> Dict[str, Any]:
    """
    Loads the model state dictionary from the zarr format.

    Args:
        checkpoint_dir (Path | TarPath): Path to the NeMo checkpoint.
        load_extra_states (bool): If set to True, the function will load BytesIO objects with the extra states.
    Returns:
        dict: Model state dictionary.
    """
    # Imported lazily so that the zarr dependency is only required for zarr-backed checkpoints.
    import zarr

    if load_extra_states:
        # Extra states are pickled BytesIO objects; allow them through torch's safe loader.
        torch.serialization.add_safe_globals([BytesIO])

    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        if not subdir.is_dir():
            continue

        if load_extra_states and contains_extra_states(subdir):
            sharded_state_dict.update(load_sharded_pickle_extra_state_scale(subdir))
        elif (subdir / '.zarray').exists():
            key = subdir.name
            zstore = ZarrPathStore(subdir)
            arr = zarr.open(zstore, 'r')

            if arr.dtype.name == "bfloat16":
                # torch.from_numpy cannot consume bfloat16 arrays directly, so reinterpret
                # the raw bits as int16 and view the resulting tensor as bfloat16.
                sharded_state_dict[key] = torch.from_numpy(arr[:].view(numpy.int16)).view(torch.bfloat16)
            else:
                sharded_state_dict[key] = torch.from_numpy(arr[:])

    return sharded_state_dict
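
# The bfloat16 bit-reinterpretation trick above, as a standalone sketch (assumes the
# bfloat16 numpy dtype registered by the tensorstore import at the top of this module):
#
#     raw = arr[:]                                  # numpy array whose dtype name is "bfloat16"
#     bits = raw.view(numpy.int16)                  # reinterpret the same bytes, no conversion
#     tensor = torch.from_numpy(bits).view(torch.bfloat16)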


def nemo_weights_directory(nemo_path: Union[Path, TarPath]) -> Union[Path, TarPath]:
    """
    Returns a Path pointing to the weights directory inside the NeMo checkpoint.

    Args:
        nemo_path (Path | TarPath): Path to the NeMo checkpoint.
    Returns:
        Path | TarPath: Path to the weights directory inside the model checkpoint.
    """
    if (nemo_path / "model_weights").exists():
        return nemo_path / "model_weights"

    if (nemo_path / "weights").exists():
        return nemo_path / "weights"

    # Fall back to the checkpoint root for layouts without a dedicated weights directory.
    return nemo_path


def load_model_weights(checkpoint_path: Union[str, Path], load_extra_states: bool = False) -> Dict[str, Any]:
    """
    Loads the NeMo state dictionary. Weights are stored as torch.Tensor objects.

    Args:
        checkpoint_path (str | Path): Path to the NeMo checkpoint.
        load_extra_states (bool): If True, also loads the BytesIO objects corresponding to the extra states.
    Returns:
        dict: Model state dictionary.
    """
    nemo_path = nemo_to_path(checkpoint_path)
    nemo_weights = nemo_weights_directory(nemo_path)

    with (nemo_weights / 'metadata.json').open(mode='r') as f:
        config_dict = json.load(f)

    if config_dict['sharded_backend'] == 'zarr':
        return load_sharded_metadata_zarr(nemo_weights, load_extra_states=load_extra_states)
    elif config_dict['sharded_backend'] == 'torch_dist':
        # TODO: Remove mocking imports once MCore is available in NIM containers
        with _mock_import("megatron.core.dist_checkpointing.strategies.torch"):
            return load_sharded_metadata_torch_dist(nemo_weights, load_extra_states=load_extra_states)

    raise NotImplementedError(f'Distributed checkpoint backend {config_dict["sharded_backend"]} not supported')
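
# Example usage of the module's main entry point (the checkpoint path and state dict
# key below are hypothetical):
#
#     state_dict = load_model_weights("/models/llama-8b.nemo", load_extra_states=True)
#     embeddings = state_dict["model.embedding.word_embeddings.weight"]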