# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import functools
import inspect
import json
import shutil
import threading
import types
import uuid
from copy import deepcopy
from pathlib import Path
from pydoc import locate
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

import fiddle as fdl
import fiddle._src.experimental.dataclasses as fdl_dc
import lightning.pytorch as pl
from cloudpickle import dump
from cloudpickle import load as pickle_load
from fiddle._src import config as config_lib
from fiddle._src import partial
from fiddle._src.experimental import serialization
from typing_extensions import Self

from nemo.lightning.io.artifact.base import Artifact
from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.connector import ModelConnector
from nemo.lightning.io.fdl_torch import enable as _enable_ext
from nemo.lightning.io.to_config import to_config
from nemo.utils import logging

ConnT = TypeVar("ConnT", bound=ModelConnector)
CkptType = TypeVar("CkptType")

_enable_ext()

# Thread-local storage for the artifacts directory
_thread_local = threading.local()


def _is_default_factory(arg: Any) -> bool:
    return arg == dataclasses._HAS_DEFAULT_FACTORY


def _ordered_arguments_with_default(data: config_lib.Config) -> Dict[Union[int, str], Any]:
    result = config_lib.ordered_arguments(data, include_defaults=True)
    for key, arg in result.items():
        if isinstance(arg, config_lib.Config):
            ordered_arg = _ordered_arguments_with_default(arg)
            result[key] = ordered_arg
        elif _is_default_factory(arg):
            if dataclasses.is_dataclass(data.__fn_or_cls__):
                fields = dataclasses.fields(data.__fn_or_cls__)
                for field in fields:
                    if field.name == key:
                        result[key] = field.default_factory()
                        break

    if "__fn_or_cls__" in result:
        raise ValueError(
            "It is not supported to dump objects of functions/classes "
            "that have a __fn_or_cls__ parameter."
        )

    result["_target_"] = (
        f"{inspect.getmodule(config_lib.get_callable(data)).__name__}"
        f".{config_lib.get_callable(data).__qualname__}"
    )
    if isinstance(data, partial.Partial):
        result["_partial_"] = True

    return result


def _config_representer_with_defaults(dumper, data, type_name="Config"):
    """Returns a YAML representation of `data`."""
    value = _ordered_arguments_with_default(data)
    return dumper.represent_data(value)


def _partial_representer_with_defaults(dumper, data):
    return _config_representer_with_defaults(dumper, data, type_name="Partial")


def _safe_object_representer(dumper, data):
    """
    Represent a given object as YAML using the specified dumper.

    This function is a fallback for objects that don't have specific representers.
    If the object has a `__qualname__` attribute, `_target_` is set to
    f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}". If the object does not
    have a `__qualname__` attribute, `_target_` is derived from its `__class__`.
    The `_call_` key indicates whether the target should be called to create an instance.
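
    For example (illustrative; `my_module.MyClass` is a hypothetical class), an
    instance of `MyClass` is represented as
    `{"_target_": "my_module.MyClass", "_call_": True}`, while the class object
    itself is represented with `"_call_": False`.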

    Args:
        dumper (yaml.Dumper): The YAML dumper to use for serialization.
        data (Any): The data to serialize. This can be any Python object, but if it's a class
            or a class instance, special handling will be applied.

    Returns:
        str: The YAML representation of the data.
    """
    try:
        obj = data
        target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}"
        call = False
    except AttributeError:
        obj = data.__class__
        target = f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}"
        call = True

    value = {
        "_target_": target,  # type: ignore
        "_call_": call,
    }
    return dumper.represent_data(value)


class IOMixin:
    """
    A mixin class designed to capture the arguments passed to the `__init__` method,
    facilitating the re-creation of the object through the `io.reinit` method using stored
    configurations.

    This class intercepts the initialization of an object to store the arguments in a
    configuration object, which can be serialized and later used to reinitialize the object
    to its original state. It utilizes `fdl.Config` from the Fiddle library to create a
    structured configuration object that holds the initialization parameters. This
    configuration object is crucial for enabling serialization and deserialization of the
    parameters, thus allowing the object to be reconstructed at a later time with the same
    initial state.

    Attributes
    ----------
    __io__ (fdl.Config[Self]): A configuration object that stores the captured initialization
        parameters in a structured format. This object is an instance of `fdl.Config`, which
        allows for the serialization and deserialization of the parameters, enabling the
        object to be reconstructed at a later time with the same initial state.

    Examples
    --------
        from nemo.lightning import io

        class ExampleClass(io.IOMixin):
            def __init__(self, param1, param2):
                super().__init__()
                self.param1 = param1
                self.param2 = param2

        # Creating an instance of ExampleClass
        example = ExampleClass('value1', 'value2')
        example_copy = io.reinit(example)

    Note:
        For more information on `fdl.Config`, refer to the Fiddle library documentation at
        [Fiddle Config Documentation](https://fiddle.readthedocs.io/en/latest/api_reference/core.html#config).
    """

    __io__: fdl.Config[Self]

    def __new__(cls, *args, **kwargs):
        """
        Overrides the default object creation process to wrap the `__init__` method, allowing
        initialization arguments to be captured and stored in the `__io__` attribute.

        Args:
            *args: Variable length argument list for the `__init__` method.
            **kwargs: Arbitrary keyword arguments for the `__init__` method.

        Returns
        -------
            The newly created object instance.
        """
        cls = _io_wrap_init(cls)
        output = object.__new__(cls)

        return output

    def __init_subclass__(cls):
        _io_register_serialization(cls)
        # Add OneLogger timing hooks for data modules to enable telemetry tracking.
        if issubclass(cls, pl.LightningDataModule):
            from nemo.lightning.callback_group import hook_class_init_with_callbacks

            hook_class_init_with_callbacks(cls, "on_dataloader_init_start", "on_dataloader_init_end")
        super().__init_subclass__()

    def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]:
        """
        Transforms and captures the arguments passed to the `__init__` method, filtering out
        any arguments that are instances of `IOProtocol` or are dataclass fields with default
        factories.

        Args:
            init_fn (Callable): The original `__init__` method of the class.
            *args: Variable length argument list for the `__init__` method.
            **kwargs: Arbitrary keyword arguments for the `__init__` method.

        Returns
        -------
            Dict[str, Any]: A dictionary of the captured and transformed arguments.
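
        Example (illustrative, using the `ExampleClass` from the class docstring above):
        `ExampleClass('value1', 'value2')` captures `{'param1': 'value1', 'param2': 'value2'}`;
        the bound `self` argument is dropped.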
""" return _io_transform_args(self, init_fn, *args, **kwargs) def io_init(self, **kwargs) -> fdl.Config[Self]: """ Initializes the configuration object (`__io__`) with the captured arguments. Args: **kwargs: A dictionary of arguments that were captured during object initialization. Returns ------- fdl.Config[Self]: The initialized configuration object. """ return _io_init(self, **kwargs) @classmethod def io_artifacts(cls) -> List[Artifact]: """Initialize io artifacts""" return [] def io_dump(self, output: Path, yaml_attrs: list[str]): """ Serializes the configuration object (`__io__`) to a file, allowing the object state to be saved and later restored. Also creates an artifacts directory and stores it in a thread-local global variable. If the artifacts directory is empty at the end, it is deleted. Args: output (Path): The path to the directory where the configuration object and artifacts will be stored. """ output_path = Path(output) local_artifacts_dir = "." artifacts_dir = output_path / local_artifacts_dir artifacts_dir.mkdir(parents=True, exist_ok=True) # Store artifacts directory in thread-local storage _thread_local.local_artifacts_dir = local_artifacts_dir _thread_local.output_path = output_path config_path = output_path / "io.json" with open(config_path, "w") as f: io = _artifact_transform_save(self, deepcopy(self.__io__), output_path, local_artifacts_dir) json = serialization.dump_json(io) f.write(json) yaml_configs = self._io_dump_yaml(io, attrs=yaml_attrs) for attr, serialized_str in yaml_configs.items(): _path = output_path / f"{attr}.yaml" _path.write_text(serialized_str) # Clear thread-local storage after io_dump is complete del _thread_local.local_artifacts_dir del _thread_local.output_path # Check if artifacts directory is empty and delete if so if not any(artifacts_dir.iterdir()): shutil.rmtree(artifacts_dir) def _io_dump_yaml(self, io: config_lib.Config, attrs: list[str]): import yaml original_representers = yaml.SafeDumper.yaml_representers.copy() from nemo_run.config import Config, Partial from nemo_run.core.serialization.yaml import YamlSerializer yaml.SafeDumper.add_representer(config_lib.Config, _config_representer_with_defaults) yaml.SafeDumper.add_representer(partial.Partial, _partial_representer_with_defaults) yaml.SafeDumper.add_representer(Config, _config_representer_with_defaults) yaml.SafeDumper.add_representer(Partial, _partial_representer_with_defaults) yaml.SafeDumper.add_multi_representer(object, _safe_object_representer) serializer = YamlSerializer() result = {} for attr in attrs: result[attr] = serializer.serialize(getattr(io, attr)) yaml.SafeDumper.yaml_representers = original_representers return result class ConnectorMixin: """ A mixin class that provides methods to register and retrieve model connectors for importing and exporting models. This class supports dynamic registration of connectors based on file extensions, which facilitates the customization and extension of model serialization and deserialization processes. Attributes ---------- _IMPORTERS (Dict[str, Type[ModelConnector]]): A dictionary mapping file extensions to model connector classes that handle the import process. _EXPORTERS (Dict[str, Type[ModelConnector]]): A dictionary mapping file extensions to model connector classes that handle the export process. 
""" _IMPORTERS: Dict[str, Type[ModelConnector]] = {} _EXPORTERS: Dict[str, Type[ModelConnector]] = {} @classmethod def import_from(cls, path: str) -> Self: """ Creates an instance of a model by using the appropriate importer based on the file extension of the provided path. Args: path (str): The path to the model file to be imported. Example: from nemo.collections import llm model = llm.Mistral7BModel.import_from("hf") Returns ------- Self: An instance of the model initialized from the imported data. """ output = cls._get_connector(path).init() output.ckpt_path = output.import_ckpt(path) return output @classmethod def register_importer(cls, ext: str, default_path: Optional[str] = None) -> Callable[[Type[ConnT]], Type[ConnT]]: """ A class method decorator to register a model connector as an importer for a specific file extension. Args: ext (str): The file extension to associate with the model connector. default_path (Optional[str]): The default path to use if no path is specified during import. Returns ------- Callable[[Type[ConnT]], Type[ConnT]]: The decorator that registers the model connector. """ def decorator(connector: Type[ConnT]) -> Type[ConnT]: cls._IMPORTERS[str(cls) + ext] = connector if default_path: connector.default_path = default_path return connector return decorator @classmethod def register_exporter(cls, ext: str, default_path: Optional[str] = None) -> Callable[[Type[ConnT]], Type[ConnT]]: """ A class method decorator to register a model connector as an exporter for a specific file extension. Args: ext (str): The file extension to associate with the model connector. default_path (Optional[str]): The default path to use if no path is specified during export. Returns ------- Callable[[Type[ConnT]], Type[ConnT]]: The decorator that registers the model connector. """ def decorator(connector: Type[ConnT]) -> Type[ConnT]: cls._EXPORTERS[str(cls) + ext] = connector if default_path: connector.default_path = default_path return connector return decorator @classmethod def importer(cls, path: str) -> ModelConnector: """ Retrieves the appropriate model connector for importing based on the extension of the provided path. Args: path (str): The path to the model file to be imported. Returns ------- ModelConnector: The model connector instance capable of handling the import. """ return cls._get_connector(path, importer=True) @classmethod def exporter(cls, ext: str, path: Union[str, Path]) -> ModelConnector: """ Retrieves the appropriate model connector for exporting based on the extension. Args: ext (str): The file extension associated with the model connector. path (Union[str, Path]): The path where the model will be exported. Returns ------- ModelConnector: The model connector instance capable of handling the export. """ return cls._get_connector(ext, path, importer=False) def import_ckpt(self, path: str, overwrite: bool = False, base_path: Optional[Path] = None, **kwargs) -> Path: """ Imports a checkpoint from a specified path, potentially overwriting existing files. Args: path (str): The path to the checkpoint file to be imported. overwrite (bool): Flag to determine if existing files should be overwritten (default is False). base_path (Optional[Path]): The base path where the checkpoint file is located; used to resolve relative paths. Returns ------- Path: The path to the imported checkpoint. Raises ------ FileNotFoundError: If the checkpoint file does not exist at the specified path. 
""" connector = self._get_connector(path, **kwargs) ckpt_path: Path = connector.local_path(base_path=base_path) ckpt_path = connector(ckpt_path, overwrite=overwrite) connector.on_import_ckpt(self) return ckpt_path @classmethod def _get_connector( cls, ext: Union[str, Path], path: Optional[Union[str, Path]] = None, importer: bool = True, **kwargs ) -> ModelConnector: """ Retrieves the appropriate model connector based on the file extension and path, distinguishing between importers and exporters. Args: ext (Union[str, Path]): The file extension or a URI that may include a protocol specifier. path (Optional[Union[str, Path]]): The path where the model file is located or will be saved. importer (bool): Flag to determine if the connector is for importing (True) or exporting (False). Returns ------- ModelConnector: The model connector instance capable of handling the import or export. Raises ------ ValueError: If no connector is found for the specified extension or if no default path is provided when required. """ _path = None ext = str(ext) if "://" in ext: ext, _path = ext.split("://") else: _path = str(path) connector = cls._IMPORTERS.get(str(cls) + ext) if importer else cls._EXPORTERS.get(str(cls) + ext) if not connector: raise ValueError(f"No connector found for extension '{ext}' for {cls}") if not _path: if not connector.default_path: raise ValueError(f"No default path specified for extension '{ext}'. ", "Please provide a path") return connector() return connector(_path, **kwargs) def track_io(target, artifacts: Optional[List[Artifact]] = None): """ Adds IO functionality to the target object or eligible classes in the target module by wrapping __init__ and registering serialization methods. Args: target (object or types.ModuleType): The target object or module to modify. Returns: object or types.ModuleType: The modified target with IO functionality added to eligible classes. Examples: >>> from nemo.collections.common import tokenizers >>> modified_tokenizers = track_io(tokenizers) >>> ModifiedWordTokenizer = track_io(tokenizers.WordTokenizer) """ def _add_io_to_class(cls): if inspect.isclass(cls) and hasattr(cls, "__init__") and not hasattr(cls, "__io__"): if cls in [str, int, float, tuple, list, dict, bool, type(None)]: return cls cls = _io_wrap_init(cls) _io_register_serialization(cls) cls.__io_artifacts__ = artifacts or [] return cls def _process_module(module): for name, obj in inspect.getmembers(module): if inspect.isclass(obj) and _is_defined_in_module_or_submodules(obj, module): setattr(module, name, _add_io_to_class(obj)) return module def _is_defined_in_module_or_submodules(obj, module): return obj.__module__ == module.__name__ or obj.__module__.startswith(f"{module.__name__}.") if isinstance(target, types.ModuleType): return _process_module(target) elif inspect.isclass(target): return _add_io_to_class(target) else: raise TypeError("Target must be a module or a class") def _io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: """ Transforms and captures the arguments passed to the `__init__` method, filtering out any arguments that are instances of `IOProtocol` or are dataclass fields with default factories. Args: init_fn (Callable): The original `__init__` method of the class. *args: Variable length argument list for the `__init__` method. **kwargs: Arbitrary keyword arguments for the `__init__` method. Returns ------- Dict[str, Any]: A dictionary of the captured and transformed arguments. 
""" sig = inspect.signature(init_fn) bound_args = sig.bind_partial(self, *args, **kwargs) config_kwargs = {k: v for k, v in bound_args.arguments.items() if k != "self"} to_del = [] for key in config_kwargs: if isinstance(config_kwargs[key], IOProtocol): config_kwargs[key] = config_kwargs[key].__io__ if dataclasses.is_dataclass(config_kwargs[key]): config_kwargs[key] = fdl_dc.convert_dataclasses_to_configs(config_kwargs[key], allow_post_init=True) # Check if the arg is a factory (dataclasses.field) if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": to_del.append(key) for key in to_del: del config_kwargs[key] return config_kwargs def _io_init(self, **kwargs) -> fdl.Config[Self]: """ Initializes the configuration object (`__io__`) with the captured arguments. Args: **kwargs: A dictionary of arguments that were captured during object initialization. Returns ------- fdl.Config[Self]: The initialized configuration object. """ try: return fdl.Config(type(self), **kwargs) except Exception as e: error_msg = ( f"Error creating fdl.Config for {type(self).__name__}: {str(e)}\n" f"Arguments that caused the error: {kwargs}\n" f"This may be due to unsupported argument types or nested configurations." ) raise RuntimeError(error_msg) from e def _io_wrap_init(cls): """Wraps the __init__ method of a class to add IO functionality.""" original_init = cls.__init__ if getattr(cls, "__wrapped_init__", False): return cls @functools.wraps(original_init) def wrapped_init(self, *args, **kwargs): if hasattr(self, "io_transform_args"): cfg_kwargs = self.io_transform_args(original_init, *args, **kwargs) else: cfg_kwargs = _io_transform_args(self, original_init, *args, **kwargs) if hasattr(self, "io_init"): self.__io__ = self.io_init(**cfg_kwargs) else: self.__io__ = _io_init(self, **cfg_kwargs) original_init(self, *args, **kwargs) cls.__init__ = wrapped_init cls.__wrapped_init__ = True return cls def _io_register_serialization(cls): serialization.register_node_traverser( cls, flatten_fn=_io_flatten_object, unflatten_fn=_io_unflatten_object, path_elements_fn=_io_path_elements_fn, ) def _io_flatten_object(instance): try: serialization.dump_json(instance.__io__) except (serialization.UnserializableValueError, AttributeError) as e: if not hasattr(_thread_local, "local_artifacts_dir") or not hasattr(_thread_local, "output_path"): raise e local_artifact_path = Path(_thread_local.local_artifacts_dir) / f"{uuid.uuid4()}" output_path = _thread_local.output_path artifact_path = output_path / local_artifact_path with open(artifact_path, "wb") as f: dump(getattr(instance, "__io__", instance), f) return (str(local_artifact_path),), None return instance.__io__.__flatten__() def _io_unflatten_object(values, metadata): if not hasattr(_thread_local, "output_dir"): return fdl.Config.__unflatten__(values, metadata) output_dir = _thread_local.output_dir if len(values) == 1: pickle_path = values[0] with open(Path(output_dir) / pickle_path, "rb") as f: return pickle_load(f) return fdl.Config.__unflatten__(values, metadata) def _io_path_elements_fn(x): try: serialization.dump_json(x.__io__) except (serialization.UnserializableValueError, AttributeError): return (serialization.IdentityElement(),) return x.__io__.__path_elements__() def _artifact_transform_save(instance, cfg: fdl.Config, output_path: Path, relative_dir: Path = Path(".")): artifacts = getattr(cfg.__fn_or_cls__, "__io_artifacts__", []) for artifact in artifacts: # Allow optional artifacts if artifact.skip or (not hasattr(cfg, artifact.attr) and not 
        if artifact.skip or (not hasattr(cfg, artifact.attr) and not artifact.required):
            continue

        if not hasattr(cfg, artifact.attr) and artifact.required:
            raise ValueError(f"Artifact '{artifact.attr}' is required but not provided")

        current_val = getattr(cfg, artifact.attr)
        if current_val is None:
            if artifact.required:
                raise ValueError(f"Artifact '{artifact.attr}' is required but not provided")
            continue

        # Dump the artifact and store the relative path on the config.
        new_val = artifact.dump(instance, current_val, output_path, relative_dir)
        setattr(cfg, artifact.attr, new_val)

    for attr in dir(cfg):
        try:
            child = getattr(cfg, attr)
        except AttributeError:
            continue
        child = to_config(child)

        try:
            if isinstance(child, (fdl.Config, fdl.Partial)):
                setattr(
                    cfg,
                    attr,
                    _artifact_transform_save(
                        getattr(instance, attr, None),
                        child,
                        output_path=output_path,
                        relative_dir=relative_dir,
                    ),
                )
        except ValueError:
            pass

    return cfg


def _artifact_transform_load(cfg: fdl.Config, path: Path):
    for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
        # We expect artifact.attr to be a string or an fdl.Config.
        # Some parameters can be either a plain string or a filepath. When such a parameter
        # is just a string, it is represented with an fdl.Config and the rest of the loop
        # (base-dir adjustment) is skipped.
        current_val = getattr(cfg, artifact.attr)
        if isinstance(current_val, fdl.Config):
            # artifact.attr is a string, not a path.
            setattr(cfg, artifact.attr, fdl.build(current_val).attr)
            continue

        if artifact.skip:
            continue

        current_val = getattr(cfg, artifact.attr)
        # __init__ arguments can be None
        if current_val is None:
            continue

        # Replace the local path with an absolute one.
        new_val = str(Path(path) / current_val)
        setattr(cfg, artifact.attr, new_val)

    for attr in dir(cfg):
        try:
            if isinstance(getattr(cfg, attr), fdl.Config):
                _artifact_transform_load(getattr(cfg, attr), path=path)
        except ValueError:
            pass


def drop_unexpected_params(config: fdl.Config) -> bool:
    """
    Analyzes the config to detect unexpected keyword arguments -- for example, deprecated
    parameters -- and updates the config by dropping them. Returns True if the config was
    updated and False otherwise.

    Args:
        config (fdl.Config): The configuration object to analyze.
    """
    updated = False

    def analyze(config: fdl.Config, prefix: str):
        if isinstance(config, fdl.Config):
            signature = inspect.signature(config.__fn_or_cls__)
            accept_kwargs = any(
                param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
            )

            if not accept_kwargs:
                to_drop = [param for param in config.__arguments__ if param not in signature.parameters]
                if to_drop:
                    nonlocal updated
                    updated = True
                    logging.warning(f"Deprecated parameters to drop from {prefix}: {to_drop}")
                    for param in to_drop:
                        del config.__arguments__[param]
            else:
                logging.debug(f"Skip analyzing {prefix} as it accepts arbitrary keyword arguments.")

            # Proceed recursively for all arguments.
            for key, value in config.__arguments__.items():
                analyze(value, prefix + "." + key)

    analyze(config, "")
    return updated


def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = None, build: bool = True) -> CkptType:
    """
    Loads a configuration from a serialized 'io.json' file and constructs an object of
    the specified type.

    Args:
        path (Path): The path to the 'io.json' file or to the checkpoint directory containing it.
        output_type (Type[CkptType]): The type of the object to be constructed from the loaded data.
        subpath (Optional[str]): Subpath to selectively load only specific objects inside the
            output_type. Defaults to None.
        build (bool): Whether to build the object from the loaded configuration (default is True).
            If False, the configuration itself is returned.

    Returns
    -------
        CkptType: An instance of the specified type constructed from the loaded configuration.

    Raises
    ------
        FileNotFoundError: If the specified file does not exist.

    Example:
        loaded_model = load("/path/to/model", output_type=MyModel)
    """
    _path = Path(path)
    _thread_local.output_dir = _path

    if hasattr(_path, "is_dir") and _path.is_dir():
        _path = Path(_path) / "io.json"
    elif hasattr(_path, "isdir") and _path.isdir:
        _path = Path(_path) / "io.json"

    if not _path.is_file():
        raise FileNotFoundError(f"No such file: '{_path}'")

    if subpath:
        subpath = "." + subpath

    # Add IO functionality to custom objects present in the JSON file.
    with open(_path) as f:
        j = json.load(f)
    for obj, val in j.get("objects", {}).items():
        clss = ".".join([val["type"]["module"], val["type"]["name"]])
        if subpath and "paths" in val:
            if all(map(lambda p: subpath not in p, val["paths"])):
                continue

        if not serialization.find_node_traverser(locate(clss)):
            track_io(locate(clss))

    with open(_path, "rb") as f:
        json_config = json.loads(f.read())

    root_key = None
    for obj, val in json_config.get("objects", {}).items():
        if "paths" in val and subpath in val["paths"]:
            root_key = obj
            break

    if subpath and not root_key:
        logging.warning(f"Could not find {subpath} for {output_type} in {_path}")

    if root_key:
        json_config["root"]["key"] = root_key

    config = serialization.Deserialization(json_config).result
    _artifact_transform_load(config, path)
    drop_unexpected_params(config)

    if not build:
        return config

    return fdl.build(config)
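

# Example round-trip (a sketch; `MyModule` and the output directory are hypothetical,
# and nemo_run must be installed for the YAML dump step):
#
#     class MyModule(IOMixin):
#         def __init__(self, hidden_size: int = 8):
#             self.hidden_size = hidden_size
#
#     module = MyModule(hidden_size=16)                  # __init__ args captured in module.__io__
#     module.io_dump(Path("/tmp/ckpt"), yaml_attrs=[])   # writes /tmp/ckpt/io.json
#     rebuilt = load("/tmp/ckpt", output_type=MyModule)  # reconstructs an equivalent instance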