# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional

import lightning.pytorch as pl

try:
    from megatron.core.distributed import finalize_model_grads
    from megatron.core.optimizer import OptimizerConfig
    from megatron.core.utils import get_model_config

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    OptimizerConfig = object
    HAVE_MEGATRON_CORE = False

from torch.optim import Optimizer

from nemo.lightning._strategy_lib import setup_megatron_optimizer
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule
class MegatronOptimizerModule(OptimizerModule):
    """An OptimizerModule for Megatron Core optimizers.

    Attributes:
        config (OptimizerConfig): Configuration for the optimizer.
        no_weight_decay_cond (Optional[Callable]): Condition for disabling weight decay on a parameter.
        scale_lr_cond (Optional[Callable]): Condition for scaling the learning rate of a parameter.
        lr_mult (float): Learning rate multiplier.

    Example::

        config = OptimizerConfig(...)
        lr_scheduler = MyLRSchedulerModule(...)
        optimizer_module = MegatronOptimizerModule(config, lr_scheduler)

    Methods:
        setup(model): Sets up the optimizer.
        optimizers(model): Defines the optimizers.
    """
    def __init__(
        self,
        config: OptimizerConfig,
        lr_scheduler: Optional[LRSchedulerModule] = None,
        no_weight_decay_cond: Optional[Callable] = None,
        scale_lr_cond: Optional[Callable] = None,
        lr_mult: float = 1.0,
    ):
        """Initializes the MegatronOptimizerModule.

        Args:
            config (OptimizerConfig): Configuration for the optimizer.
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
            no_weight_decay_cond (Optional[Callable]): Condition for disabling weight decay on a parameter.
            scale_lr_cond (Optional[Callable]): Condition for scaling the learning rate of a parameter.
            lr_mult (float): Learning rate multiplier.
        """
        super().__init__(lr_scheduler=lr_scheduler)
        self.config = config
        self.no_weight_decay_cond = no_weight_decay_cond
        self.scale_lr_cond = scale_lr_cond
        self.lr_mult = lr_mult
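
    # Illustrative sketch (not part of the class): the condition callables are forwarded
    # unchanged to ``setup_megatron_optimizer``. In Megatron Core they are typically
    # predicates over a parameter name and tensor; the signature assumed below
    # (``(param_name, param) -> bool``) and the helper names are assumptions for
    # illustration, not a documented contract.
    #
    #     def no_wd_for_bias_and_norm(param_name, param):
    #         # Skip weight decay for biases and 1-D (norm-style) parameters.
    #         return param_name.endswith(".bias") or param.ndim == 1
    #
    #     def scale_lr_for_head(param_name, param):
    #         # Apply ``lr_mult`` only to output-head parameters.
    #         return "output_layer" in param_name
    #
    #     optimizer_module = MegatronOptimizerModule(
    #         config=OptimizerConfig(optimizer="adam", lr=3e-4, weight_decay=0.1),
    #         no_weight_decay_cond=no_wd_for_bias_and_norm,
    #         scale_lr_cond=scale_lr_for_head,
    #         lr_mult=0.1,
    #     )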
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Adds a finalize_model_grads function to the model config when fitting starts.

        Args:
            trainer (pl.Trainer): The trainer that is starting the fit loop.
            pl_module (pl.LightningModule): The module whose model config receives the
                finalize_model_grads function.
        """

        def finalize_model_grads_func(*args, **kwargs):
            return self.finalize_model_grads(*args, **kwargs)

        get_model_config(pl_module).finalize_model_grads_func = finalize_model_grads_func
    def optimizers(self, model: MegatronParallel) -> List[Optimizer]:
        """Defines the optimizers.

        Args:
            model (MegatronParallel): The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The list of optimizers.

        Raises:
            ValueError: If the model is not an instance of MegatronParallel.
        """
        if not isinstance(model, MegatronParallel):
            raise ValueError("Model must be an instance of MegatronParallel")

        optimizer = setup_megatron_optimizer(
            model,
            self.config,
            no_weight_decay_cond=self.no_weight_decay_cond,
            scale_lr_cond=self.scale_lr_cond,
            lr_mult=self.lr_mult,
        )

        return [optimizer]
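
    # Illustrative sketch: NeMo normally calls ``optimizers`` internally once this module
    # is connected to a model, so user code rarely invokes it directly. Assuming a
    # configured ``MegatronParallel`` instance named ``megatron_model`` (an assumption
    # for illustration), a direct call would look like:
    #
    #     opts = optimizer_module.optimizers(megatron_model)
    #     assert len(opts) == 1  # a single Megatron optimizer covering all param groups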
    def finalize_model_grads(self, *args, **kwargs):
        """Finalizes the model gradients by calling Megatron Core's ``finalize_model_grads``."""
        return finalize_model_grads(*args, **kwargs)
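

# Illustrative end-to-end sketch (comments only): how this module is typically wired into
# a NeMo 2.0 setup. The scheduler class, its arguments, and the model's ``optim`` argument
# are assumptions based on common NeMo usage, not guarantees made by this file.
#
#     from megatron.core.optimizer import OptimizerConfig
#     from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule
#
#     opt_config = OptimizerConfig(
#         optimizer="adam",
#         lr=3e-4,
#         weight_decay=0.1,
#         bf16=True,
#         use_distributed_optimizer=True,
#     )
#     lr_scheduler = CosineAnnealingScheduler(max_steps=100_000, warmup_steps=500, min_lr=3e-5)
#     opt_module = MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler)
#
#     # The optimizer module is attached to the model; the Megatron strategy then calls
#     # ``optimizers(...)`` and ``on_fit_start(...)`` at the appropriate points of the fit loop.
#     model = llm.GPTModel(gpt_config, optim=opt_module)  # ``llm`` = nemo.collections.llm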