# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional

import lightning.pytorch as pl

try:
    from megatron.core.distributed import finalize_model_grads
    from megatron.core.optimizer import OptimizerConfig
    from megatron.core.utils import get_model_config

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    OptimizerConfig = object

    HAVE_MEGATRON_CORE = False

from torch.optim import Optimizer

from nemo.lightning._strategy_lib import setup_megatron_optimizer
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule


class MegatronOptimizerModule(OptimizerModule):
    """An OptimizerModule for the Megatron optimizers.

    Attributes:
        config (OptimizerConfig): Configuration for the optimizer.
        no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
        scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
        lr_mult (float): Learning rate multiplier.

    Example::

        config = OptimizerConfig(...)
        lr_scheduler = MyLRSchedulerModule(...)
        optimizer_module = MegatronOptimizerModule(config, lr_scheduler)

    Methods:
        setup(model): Sets up the optimizer.
        optimizers(model): Defines the optimizers.
    """

    def __init__(
        self,
        config: OptimizerConfig,
        lr_scheduler: Optional[LRSchedulerModule] = None,
        no_weight_decay_cond: Optional[Callable] = None,
        scale_lr_cond: Optional[Callable] = None,
        lr_mult: float = 1.0,
    ):
        """Initializes the MegatronOptimizerModule.

        Args:
            config (OptimizerConfig): Configuration for the optimizer.
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
            no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
            scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
            lr_mult (float): Learning rate multiplier.
        """

        super().__init__(lr_scheduler=lr_scheduler)
        self.config = config
        self.no_weight_decay_cond = no_weight_decay_cond
        self.scale_lr_cond = scale_lr_cond
        self.lr_mult = lr_mult

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Attaches the ``finalize_model_grads`` function to the model config at the start of fit.

        Args:
            trainer (pl.Trainer): The trainer running the fit loop.
            pl_module (pl.LightningModule): The model for which the optimizer is being set up.
        """

        def finalize_model_grads_func(*args, **kwargs):
            return self.finalize_model_grads(*args, **kwargs)

        get_model_config(pl_module).finalize_model_grads_func = finalize_model_grads_func

    def optimizers(self, model: MegatronParallel) -> List[Optimizer]:
        """Defines the optimizers.

        Args:
            model (MegatronParallel): The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The list of optimizers.

        Raises:
            ValueError: If the model is not an instance of MegatronParallel.
        """

        if not isinstance(model, MegatronParallel):
            raise ValueError("Model must be an instance of MegatronParallel")

        optimizer = setup_megatron_optimizer(
            model,
            self.config,
            no_weight_decay_cond=self.no_weight_decay_cond,
            scale_lr_cond=self.scale_lr_cond,
            lr_mult=self.lr_mult,
        )

        return [optimizer]

    def finalize_model_grads(self, *args, **kwargs):
        """Finalizes the model gradients by delegating to Megatron Core's ``finalize_model_grads``."""
        return finalize_model_grads(*args, **kwargs)