# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy

import fiddle as fdl
import lightning.pytorch as pl
from lightning.pytorch.loops import _TrainingEpochLoop
from lightning.pytorch.loops.fetchers import _DataFetcher
from typing_extensions import Self

from nemo.lightning.fabric.conversion import to_fabric
from nemo.lightning.fabric.fabric import Fabric
from nemo.lightning.io.mixin import IOMixin, serialization, track_io


class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
    """
    Extend the PTL epoch loop to skip validation when restarting.

    This matters when resuming from a checkpoint that was saved after validation ran,
    because loading restores the training state to a point before that validation pass.
    """

    # Default to False so a fresh (non-restored) run validates normally.
    skip_val_on_restart: bool = False

    def _should_check_val_fx(self, data_fetcher) -> bool:
        """Suppress the validation check while the restart flag is set."""
        if self.skip_val_on_restart:
            return False
        return super()._should_check_val_fx(data_fetcher)

    def load_state_dict(self, state_dict: dict, prefix: str = "") -> None:
        """Restore the loop state and flag that the next validation check should be skipped."""
        super().load_state_dict(state_dict, prefix)
        self.skip_val_on_restart = True

    def advance(self, data_fetcher: _DataFetcher) -> None:
        """Run one training step, then clear the restart flag."""
        super().advance(data_fetcher)
        self.skip_val_on_restart = False


def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
    """Swap the trainer's epoch loop for one that skips validation when restarting."""
    if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop):
        warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
        return

    # Pass the trainer object to avoid the trainer being overwritten as None.
    loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps)
    trainer.fit_loop.epoch_loop = loop
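

# Usage sketch (illustrative only, not part of the original module): install the
# restart-aware loop on a stock PTL trainer before calling ``fit``. The model
# argument and checkpoint path are hypothetical placeholders.
def _example_resume_without_revalidation(model: pl.LightningModule) -> None:
    trainer = pl.Trainer(max_steps=1000, val_check_interval=100)
    configure_no_restart_validation_training_loop(trainer)
    # Resuming from a checkpoint saved right after a validation pass no longer
    # re-runs validation before the first training step.
    trainer.fit(model, ckpt_path="last.ckpt")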


class Trainer(pl.Trainer, IOMixin):
    """PTL ``Trainer`` with NeMo IO support: its init arguments are captured so the
    trainer can be re-serialized and converted to other runtimes (e.g. Fabric)."""

    def add_io(self, obj):
        """Recurse to the leaves of a container and add io functionality to non-serializable leaves."""
        if isinstance(obj, (dict, list)):
            if isinstance(obj, dict):
                obj = obj.values()
            for item in obj:
                self.add_io(item)
        elif not serialization.find_node_traverser(type(obj)):
            track_io(type(obj))
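
    # Illustrative walk-through (``MyCallback`` is a hypothetical user class):
    # given ``{"callbacks": [MyCallback()]}``, ``add_io`` recurses through the
    # dict and the list, reaches the ``MyCallback`` instance, and registers its
    # type via ``track_io`` so fiddle can serialize it in the captured config.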

    def io_init(self, **kwargs) -> fdl.Config[Self]:
        """Capture the trainer's init arguments as a fiddle config."""
        # Each argument of the trainer can be stateful, so copy them first.
        cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()}
        self.add_io(cfg_kwargs)
        return fdl.Config(type(self), **cfg_kwargs)
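
    # Sketch of the captured config (illustrative, assuming IOMixin routes the
    # constructor kwargs here): ``Trainer(devices=2, accelerator="gpu")`` stores
    # ``fdl.Config(Trainer, devices=2, accelerator="gpu")`` on ``self.__io__``,
    # which ``fdl.build`` can later rebuild into an equivalent trainer.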

    def to_fabric(self, callbacks=None, loggers=None) -> Fabric:
        """Build a ``Fabric`` object from the trainer's captured init configuration."""
        accelerator, devices, strategy, plugins, num_nodes = None, None, None, None, None
        if hasattr(self.__io__, "devices"):
            devices = self.__io__.devices
        if hasattr(self.__io__, "accelerator"):
            accelerator = self.__io__.accelerator
        if hasattr(self.__io__, "strategy"):
            strategy = self.__io__.strategy
            if isinstance(strategy, fdl.Config):
                strategy = fdl.build(strategy)
            # Convert the PTL strategy to its Fabric equivalent.
            strategy = to_fabric(strategy)
        if hasattr(self.__io__, "plugins"):
            plugins = self.__io__.plugins
            if isinstance(plugins, fdl.Config):
                plugins = fdl.build(plugins)
            plugins = to_fabric(plugins)
        if hasattr(self.__io__, "num_nodes"):
            num_nodes = self.__io__.num_nodes

        return Fabric(
            devices=devices,
            accelerator=accelerator,
            strategy=strategy,
            plugins=plugins,
            callbacks=callbacks,
            loggers=loggers,
            num_nodes=num_nodes,
        )
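

# Usage sketch (illustrative only): convert a configured NeMo trainer into a
# Fabric object that reuses its devices/strategy/plugins. Assumes the trainer
# was constructed through IOMixin so ``self.__io__`` holds its init config.
def _example_trainer_to_fabric() -> Fabric:
    trainer = Trainer(devices=1, accelerator="cpu")
    fabric = trainer.to_fabric()
    fabric.launch()  # standard lightning Fabric entry point; a no-op launcher for single-process CPU
    return fabric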