# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any, Callable, Literal, Optional, Type, overload

import fiddle as fdl
import lightning.pytorch as pl

from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, load
from nemo.lightning.io.pl import TrainerContext


# The overloads use Literal[True] / Literal[False] so a type checker can infer the
# return type from the `build` argument. Annotating both with plain `bool` (as
# before) makes the overloads overlap and therefore useless for inference.
@overload
def load_context(path: Path, subpath: Optional[str] = None, build: Literal[True] = True) -> TrainerContext: ...


@overload
def load_context(path: Path, subpath: Optional[str] = None, build: Literal[False] = ...) -> fdl.Config[TrainerContext]: ...


def load_context(path: Path, subpath: Optional[str] = None, build: bool = True):
    """
    Loads a TrainerContext from a json-file or directory.

    Args:
        path (Path): The path to the json-file or directory containing 'io.json'.
        subpath (Optional[str]): Subpath to selectively load only specific objects inside
            the TrainerContext. Defaults to None.
        build (bool): Whether to build the TrainerContext. Defaults to True.
            Otherwise, the TrainerContext is returned as a Config[TrainerContext] object.

    Returns
    -------
        TrainerContext: The loaded TrainerContext instance.

    Example:
        # Load the entire context
        checkpoint: TrainerContext = load_context("/path/to/checkpoint")

        # Load a subpath of the context, for eg: model.config
        checkpoint: TrainerContext = load_context("/path/to/checkpoint", subpath="model.config")
    """
    # Accept plain strings as well as Path objects.
    if not isinstance(path, Path):
        path = Path(path)
    try:
        return load(path, output_type=TrainerContext, subpath=subpath, build=build)
    except FileNotFoundError:
        # Maintain backwards compatibility with checkpoints that don't have a '/context' dir:
        # retry once with the path adjusted in the opposite direction.
        if path.parts[-1] == 'context':
            path = path.parent
        else:
            path = path / 'context'
        return load(path, output_type=TrainerContext, subpath=subpath, build=build)


def model_importer(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]:
    """
    Registers an importer for a model with a specified file extension.

    Args:
        target (Type[ConnectorMixin]): The model class to which the importer will be attached.
        ext (str): The file extension associated with the model files to be imported.

    Returns
    -------
        Callable[[Type[ConnT]], Type[ConnT]]: A decorator function that registers the importer
            to the model class.

    Example:
        @model_importer(MyModel, "hf")
        class MyModelHfImporter(io.ModelConnector):
            ...
    """
    return target.register_importer(ext)


def model_exporter(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]:
    """
    Registers an exporter for a model with a specified file extension.

    Args:
        target (Type[ConnectorMixin]): The model class to which the exporter will be attached.
        ext (str): The file extension associated with the model files to be exported.

    Returns
    -------
        Callable[[Type[ConnT]], Type[ConnT]]: A decorator function that registers the exporter
            to the model class.

    Example:
        @model_exporter(MyModel, "hf")
        class MyModelHFExporter(io.ModelConnector):
            ...
    """
    return target.register_exporter(ext)


def import_ckpt(
    model: pl.LightningModule,
    source: str,
    output_path: Optional[Path] = None,
    overwrite: bool = False,
    **kwargs,
) -> Path:
    """
    Imports a checkpoint into a model using the model's associated importer, typically for
    the purpose of fine-tuning a community model trained in an external framework, such as
    Hugging Face.

    This function leverages the ConnectorMixin interface to integrate external checkpoint
    data seamlessly into the specified model instance. The importer component of the model
    reads the checkpoint data from the specified source and transforms it into the right
    format. This is particularly useful for adapting models that have been pre-trained in
    different environments or frameworks to be fine-tuned or further developed within the
    current system.

    The function allows for specifying an output path for the imported checkpoint; if not
    provided, the importer's default path will be used. The 'overwrite' parameter enables
    the replacement of existing data at the output path, which is useful when updating models
    with new data and discarding old checkpoint files.

    For instance, using `import_ckpt(Mistral7BModel(), "hf")` initiates the import process
    by searching for a registered model importer tagged with "hf". In NeMo,
    `HFMistral7BImporter` is registered under this tag via:
    `@io.model_importer(Mistral7BModel, "hf", default_path="mistralai/Mistral-7B-v0.1")`.
    This links `Mistral7BModel` to `HFMistral7BImporter`, designed for HuggingFace checkpoints.
    The importer then processes and integrates these checkpoints into `Mistral7BModel` for
    further fine-tuning.

    Args:
        model (pl.LightningModule): The model into which the checkpoint will be imported.
            This model must implement the ConnectorMixin, which includes the necessary
            importer method for checkpoint integration.
        source (str): The source from which the checkpoint will be imported. This can be
            a file path, URL, or any other string identifier that the model's importer
            can recognize.
        output_path (Optional[Path]): The path where the imported checkpoint will be stored.
            If not specified, the importer's default path is used.
        overwrite (bool): If set to True, existing files at the output path will be overwritten.
            This is useful for model updates where retaining old checkpoint files is not required.

    Returns
    -------
        Path: The path where the checkpoint has been saved after import. This path is determined
            by the importer, based on the provided output_path and its internal logic.

    Raises
    ------
        ValueError: If the model does not implement ConnectorMixin, indicating a lack of
            necessary importer functionality.

    Example:
        model = Mistral7BModel()
        imported_path = import_ckpt(model, "hf://mistralai/Mistral-7B-v0.1")
    """
    if not isinstance(model, ConnectorMixin):
        raise ValueError("Model must be an instance of ConnectorMixin")

    importer: ModelConnector = model.importer(source)
    ckpt_path = importer(overwrite=overwrite, output_path=output_path, **kwargs)
    # Give the importer a chance to run model-specific post-import hooks.
    importer.on_import_ckpt(model)
    return ckpt_path


def load_connector_from_trainer_ckpt(path: Path, target: str) -> ModelConnector:
    """
    Loads a ModelConnector from a trainer checkpoint for exporting the model to a different format.

    This function first loads the model from the trainer checkpoint using the TrainerContext,
    then retrieves the appropriate exporter based on the target format.

    Args:
        path (Path): Path to the trainer checkpoint directory or file.
        target (str): The target format identifier for which to load the connector
            (e.g., "hf" for HuggingFace format).

    Returns:
        ModelConnector: The loaded connector instance configured for the specified target format.

    Raises:
        ValueError: If the loaded model does not implement ConnectorMixin.

    Example:
        connector = load_connector_from_trainer_ckpt(
            Path("/path/to/checkpoint"),
            "hf"
        )
    """
    model: pl.LightningModule = load_context(path, subpath="model")

    if not isinstance(model, ConnectorMixin):
        raise ValueError("Model must be an instance of ConnectorMixin")

    return model.exporter(target, path)


def _verify_peft_export(path: Path, target: str):
    """
    Rejects exporting a PEFT (adapter-only) checkpoint through the full-model 'hf' target.

    A checkpoint is considered PEFT if it contains 'weights/adapter_metadata.json'.

    Raises:
        ValueError: If `target` is "hf" and the checkpoint at `path` holds PEFT weights.
    """
    if target == "hf" and (path / "weights" / "adapter_metadata.json").exists():
        raise ValueError(
            f"Your checkpoint \n`{path}`\ncontains PEFT weights, but your specified export target `hf` should be "
            f"used for full model checkpoints. "
            f"\nIf you want to convert NeMo 2 PEFT to Hugging Face PEFT checkpoint, set `target='hf-peft'`. "
            f"If you want to merge LoRA weights back to the base model and export the merged full model, "
            f"run `llm.peft.merge_lora` first before exporting. See "
            f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/peft_nemo2.html for more details."
        )


def export_ckpt(
    path: Path,
    target: str,
    output_path: Optional[Path] = None,
    overwrite: bool = False,
    load_connector: Callable[[Path, str], ModelConnector] = load_connector_from_trainer_ckpt,
    modelopt_export_kwargs: Optional[dict[str, Any]] = None,
    **kwargs,
) -> Path:
    """
    Exports a checkpoint from a model using the model's associated exporter, typically for
    the purpose of sharing a model that has been fine-tuned or customized within NeMo.

    This function leverages the ConnectorMixin interface to seamlessly integrate
    the model's state into an external checkpoint format.

    The exporter component of the model reads the model's state from the specified path and
    exports it into the format specified by the 'target' identifier. This is particularly
    useful for adapting models that have been developed or fine-tuned within the current system
    to be compatible with other environments or frameworks.

    The function allows for specifying an output path for the exported checkpoint; if not
    provided, the exporter's default path will be used. The 'overwrite' parameter enables the
    replacement of existing data at the output path, which is useful when updating models with
    new data and discarding old checkpoint files.

    Args:
        path (Path): The path to the model's checkpoint file from which data will be exported.
        target (str): The identifier for the exporter that defines the format of the export.
        output_path (Optional[Path]): The path where the exported checkpoint will be saved.
            If not specified, the exporter's default path is used.
        overwrite (bool): If set to True, existing files at the output path will be overwritten.
            This is useful for model updates where retaining old checkpoint files is not required.
        load_connector (Callable[[Path, str], ModelConnector]): A function to load the appropriate
            exporter based on the model and target format. Defaults to
            `load_connector_from_trainer_ckpt`.
        modelopt_export_kwargs (Optional[dict[str, Any]]): Additional keyword arguments for
            ModelOpt export to HuggingFace.

    Returns
    -------
        Path: The path where the checkpoint has been saved after export. This path is
            determined by the exporter, based on the provided output_path and its internal logic.

    Raises
    ------
        ValueError: If the model does not implement ConnectorMixin, indicating a lack of
            necessary exporter functionality.

    Example:
        nemo_ckpt_path = Path("/path/to/model.ckpt")
        export_path = export_ckpt(nemo_ckpt_path, "hf")
    """
    # Imported lazily to avoid pulling in ModelOpt dependencies at module import time.
    from nemo.collections.llm.modelopt.quantization.quantizer import export_hf_checkpoint

    _output_path = output_path or Path(path) / target

    if target == "hf":
        try:
            modelopt_export_kwargs = modelopt_export_kwargs or {}
            # First try to export via the ModelOpt route. If rejected, fall back to the
            # default connector-based route below.
            output = export_hf_checkpoint(path, _output_path, **modelopt_export_kwargs)
        except RuntimeError:
            output = None

        if output is not None:
            return output

    _verify_peft_export(path, target)

    exporter: ModelConnector = load_connector(path, target)
    return exporter(overwrite=overwrite, output_path=_output_path, **kwargs)