# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
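"""Fabric subclass with NeMo IO support and helpers for distributed model setup.

Defines ``Fabric``, which extends ``lightning.fabric.Fabric`` with configuration
serialization via fiddle and with helpers to load, import, and set up models and
datamodules for distributed execution, and the ``DistributedModel`` protocol that
describes what those helpers return.
"""
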
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Protocol, Sequence, Type, TypeVar, Union, runtime_checkable

import fiddle as fdl
import lightning.fabric as lb
import lightning.pytorch as pl
from torch import nn
from typing_extensions import Self, override

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io.mixin import IOMixin, serialization, track_io

if TYPE_CHECKING:
    from megatron.core.optimizer import OptimizerConfig

ModelT = TypeVar("ModelT", bound=nn.Module)


class Fabric(lb.Fabric, IOMixin):
def io_init(self, **kwargs) -> fdl.Config[Self]:
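        """Capture the constructor arguments as a ``fdl.Config`` for IO serialization.

        Arguments can be stateful, so they are deep-copied first; any argument type
        without a registered serialization traverser is registered via ``track_io``
        so the resulting config can be serialized.
        """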
        # Each constructor argument can be stateful, so deep-copy them before building the config
cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()}
for val in cfg_kwargs.values():
if not serialization.find_node_traverser(type(val)):
track_io(type(val))
return fdl.Config(type(self), **cfg_kwargs)

    def load_model(
self,
path: Union[str, Path],
model: Optional[ModelT] = None,
) -> "DistributedModel[ModelT]":
"""Load and set up a model for distributed training.
This method loads a model from the given path, sets it up for distributed training
using the current Fabric instance, and returns a DistributedModel.
Args:
path (Union[str, Path]): The path to the saved model checkpoint.
model (Optional[ModelT], optional): An optional pre-instantiated model. If not
provided, the model will be loaded from the checkpoint. Defaults to None.
Returns:
DistributedModel[ModelT]: The loaded and distributed model.
Example:
>>> from nemo import lightning as nl
>>>
>>> trainer = nl.Trainer(
... devices=2,
... strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),
... plugins=nl.MegatronMixedPrecision(precision='16-mixed')
... )
>>> fabric = trainer.to_fabric()
>>> distributed_model = fabric.load_model("path/to/checkpoint/dir")
>>>
>>> # You can now interact with the parallel model
"""
self.launch()
from nemo.lightning.io import load_context
path = Path(path)
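        # No model supplied: rebuild it from the IO context saved alongside the checkpoint.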
if model is None:
context = load_context(ckpt_to_context_subdir(path))
model = context.model
dist_model = self.setup_module(model)
self.load(path, {"state_dict": dist_model})
return dist_model

    def import_model(
self,
path: Union[str, Path],
model_type: Type[ModelT],
) -> "DistributedModel[ModelT]":
"""
Import a model from a given path and set it up for distributed training.
This method imports a model of the specified type from the given path, loads it,
and sets it up for distributed training using the current Fabric instance.
Args:
path (Union[str, Path]): The path to the model. Can be a local path or a
Hugging Face model identifier.
model_type (Type[ModelT]): The type of the model to import. Must be a subclass
of ConnectorMixin.
Returns:
DistributedModel[ModelT]: The imported and distributed model.
Raises:
TypeError: If the provided model_type is not a subclass of ConnectorMixin.
Example:
>>> from nemo import lightning as nl
>>> from nemo.collections.llm import MistralModel
>>>
>>> trainer = nl.Trainer(
... devices=2,
... strategy=nl.MegatronStrategy(tensor_model_parallel_size=2),
... plugins=nl.MegatronMixedPrecision(precision='16-mixed')
... )
>>> fabric = trainer.to_fabric()
>>> model = fabric.import_model("hf://mistralai/Mistral-7B-v0.1", MistralModel)
>>>
>>> # You can now interact with the parallel model
"""
from nemo.lightning.io import ConnectorMixin
if not issubclass(model_type, ConnectorMixin):
raise TypeError("The provided model class must be a subclass of ConnectorMixin")
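        # Import (convert) the model from the source path; the imported model records its
        # checkpoint path in ckpt_path, which load_model then loads and distributes.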
model: ModelT = model_type.import_from(path)
return self.load_model(model.ckpt_path, model)

    @override
def setup_module(self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True):
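        """Set up a module for distributed execution via the parent Fabric implementation.

        When the active strategy is a ``FabricMegatronStrategy``, the unwrapped forward
        module is returned instead of the ``_FabricModule`` wrapper, so precision
        conversion only happens at the beginning and end of the pipeline.
        """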
from nemo.lightning.fabric.strategies import FabricMegatronStrategy
out = super().setup_module(module, move_to_device=move_to_device, _reapply_compile=_reapply_compile)
# We don't want to return a _FabricModule for megatron since we only want to precision convert
# at the beginning and end of the pipeline
if isinstance(self.strategy, FabricMegatronStrategy):
return out._forward_module
return out

    def setup_datamodule(self, datamodule: pl.LightningDataModule, stage: str = "") -> pl.LightningDataModule:
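        """Call ``datamodule.setup(stage)`` and apply any strategy-specific processing.

        If the active strategy exposes a ``process_datamodule`` hook, the datamodule is
        passed through it before being returned.
        """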
datamodule.setup(stage)
if hasattr(self.strategy, "process_datamodule"):
datamodule = self.strategy.process_datamodule(datamodule)
return datamodule


@runtime_checkable
class DistributedModel(Protocol[ModelT]):
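    """Protocol for a model that has been set up for distributed execution.

    Any object exposing a ``module`` attribute holding the underlying model satisfies
    this protocol; it describes what ``Fabric.load_model`` and ``Fabric.import_model``
    return.
    """
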
module: ModelT