# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy

import fiddle as fdl
import lightning.pytorch as pl
from lightning.pytorch.loops import _TrainingEpochLoop
from lightning.pytorch.loops.fetchers import _DataFetcher
from typing_extensions import Self

from nemo.lightning.fabric.conversion import to_fabric
from nemo.lightning.fabric.fabric import Fabric
from nemo.lightning.io.mixin import IOMixin, serialization, track_io


class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
    """
    Extend the PTL epoch loop to skip validation when restarting.

    This happens when resuming a checkpoint that has already run validation, but loading restores
    the training state to the point before validation ran.
    """

    # Default for fresh runs; set to True only after restoring loop state from a checkpoint.
    skip_val_on_restart = False

    def _should_check_val_fx(self, data_fetcher) -> bool:
        """Suppress the validation check on the first step after a restart."""
        if self.skip_val_on_restart:
            return False
        return super()._should_check_val_fx(data_fetcher)

    def load_state_dict(self, state_dict: dict, prefix: str = "") -> None:
        """Restore loop state and flag that the next validation check should be skipped."""
        super().load_state_dict(state_dict, prefix)
        self.skip_val_on_restart = True

    def advance(self, data_fetcher: _DataFetcher) -> None:
        """Run one training step, then clear the restart flag so later validation runs normally."""
        super().advance(data_fetcher)
        self.skip_val_on_restart = False


def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
    """Swap the trainer's default epoch loop for ``NoValOnRestartTrainingLoop``.

    If a custom epoch loop is already installed, warn and leave the trainer untouched.
    """
    if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop):
        warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
        return

    # Pass the trainer object so the loop's trainer reference is not left as None.
    loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps)
    trainer.fit_loop.epoch_loop = loop
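
# Illustrative usage (sketch only; ``model`` and ``datamodule`` are placeholders for user
# objects, and the trainer arguments are arbitrary): install the restart-aware loop on a
# freshly built trainer before calling ``fit``.
#
#     trainer = pl.Trainer(max_steps=1000, val_check_interval=100)
#     configure_no_restart_validation_training_loop(trainer)
#     trainer.fit(model, datamodule=datamodule, ckpt_path="last")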


class Trainer(pl.Trainer, IOMixin):
    """A PyTorch Lightning ``Trainer`` with NeMo IO support.

    Constructor arguments are captured as a ``fdl.Config`` (see ``io_init``), which allows the
    trainer configuration to be serialized and later converted into a ``Fabric`` instance.
    """

    def add_io(self, obj):
        """Recurse to the leaves of a container and add io functionality to non-serializable leaves."""
        if isinstance(obj, (dict, list)):
            if isinstance(obj, dict):
                obj = obj.values()
            for item in obj:
                self.add_io(item)
        else:
            if not serialization.find_node_traverser(type(obj)):
                track_io(type(obj))

    def io_init(self, **kwargs) -> fdl.Config[Self]:
        """Capture the trainer's constructor arguments as a ``fdl.Config``."""
        # Each argument of the trainer can be stateful, so copy them before capturing.
        cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()}
        self.add_io(cfg_kwargs)
        return fdl.Config(type(self), **cfg_kwargs)
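
    # Note: the captured config is what ``to_fabric`` below reads back via ``self.__io__``;
    # ``IOMixin`` is responsible for populating that attribute at construction time.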

    def to_fabric(self, callbacks=None, loggers=None) -> Fabric:
        """Build a ``Fabric`` object mirroring this trainer's captured configuration.

        ``devices``, ``accelerator``, ``strategy``, ``plugins`` and ``num_nodes`` are read from
        the config captured at construction time (``self.__io__``); strategy and plugins are
        built from their ``fdl.Config`` form if needed and converted to Fabric equivalents.
        """
        accelerator, devices, strategy, plugins, num_nodes = None, None, None, None, None
        if hasattr(self.__io__, "devices"):
            devices = self.__io__.devices
        if hasattr(self.__io__, "accelerator"):
            accelerator = self.__io__.accelerator
        if hasattr(self.__io__, "strategy"):
            strategy = self.__io__.strategy
            if isinstance(strategy, fdl.Config):
                strategy = fdl.build(strategy)
            strategy = to_fabric(strategy)
        if hasattr(self.__io__, "plugins"):
            plugins = self.__io__.plugins
            if isinstance(plugins, fdl.Config):
                plugins = fdl.build(plugins)
            plugins = to_fabric(plugins)
        if hasattr(self.__io__, "num_nodes"):
            num_nodes = self.__io__.num_nodes

        out = Fabric(
            devices=devices,
            accelerator=accelerator,
            strategy=strategy,
            plugins=plugins,
            callbacks=callbacks,
            loggers=loggers,
            num_nodes=num_nodes,
        )
        return out
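
# Illustrative usage (sketch only; the constructor arguments and the follow-up Fabric call
# are placeholders, not prescribed by this module): derive a ``Fabric`` that mirrors the
# trainer's captured device configuration.
#
#     trainer = Trainer(accelerator="gpu", devices=2, max_steps=100)
#     fabric = trainer.to_fabric()
#     fabric.launch()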