# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional

import lightning.pytorch as pl

try:
    from megatron.core.distributed import finalize_model_grads
    from megatron.core.optimizer import OptimizerConfig
    from megatron.core.utils import get_model_config

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    # Fall back to a stub so annotations referencing OptimizerConfig still
    # resolve when Megatron-Core is not installed.
    OptimizerConfig = object

    HAVE_MEGATRON_CORE = False

from torch.optim import Optimizer

from nemo.lightning._strategy_lib import setup_megatron_optimizer
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule


class MegatronOptimizerModule(OptimizerModule):
    """An OptimizerModule for the Megatron optimizers.

    Attributes:
        config (OptimizerConfig): Configuration for the optimizer.
        no_weight_decay_cond (Optional[Callable]): Condition that returns True for parameters
            that should not receive weight decay.
        scale_lr_cond (Optional[Callable]): Condition that returns True for parameters whose
            learning rate should be scaled by ``lr_mult``.
        lr_mult (float): Learning rate multiplier applied where ``scale_lr_cond`` holds.

    Example::

        config = OptimizerConfig(...)
        lr_scheduler = MyLRSchedulerModule(...)
        optimizer_module = MegatronOptimizerModule(config, lr_scheduler)
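
    A minimal sketch with concrete values (illustrative only; the field names
    follow Megatron-Core's ``OptimizerConfig``)::

        config = OptimizerConfig(optimizer='adam', lr=3e-4, weight_decay=0.1)
        optimizer_module = MegatronOptimizerModule(config)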

    Methods:
        setup(model): Sets up the optimizer.
        optimizers(model): Defines the optimizers.
    """

    def __init__(
        self,
        config: OptimizerConfig,
        lr_scheduler: Optional[LRSchedulerModule] = None,
        no_weight_decay_cond: Optional[Callable] = None,
        scale_lr_cond: Optional[Callable] = None,
        lr_mult: float = 1.0,
    ):
        """Initializes the MegatronOptimizerModule.

        Args:
            config (OptimizerConfig): Configuration for the optimizer.
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
            no_weight_decay_cond (Optional[Callable]): Condition for disabling weight decay on a parameter.
            scale_lr_cond (Optional[Callable]): Condition for scaling the learning rate of a parameter.
            lr_mult (float): Learning rate multiplier.
        """
        super().__init__(lr_scheduler=lr_scheduler)
        self.config = config
        self.no_weight_decay_cond = no_weight_decay_cond
        self.scale_lr_cond = scale_lr_cond
        self.lr_mult = lr_mult

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Adds the finalize_model_grads function to the model config at the start of fit.

        Args:
            trainer (pl.Trainer): The trainer that is starting to fit.
            pl_module (pl.LightningModule): The model whose Megatron config receives the hook.
        """

        # Indirect through a local function so the config stores a plain callable
        # that dispatches to this module's ``finalize_model_grads``.
        def finalize_model_grads_func(*args, **kwargs):
            return self.finalize_model_grads(*args, **kwargs)

        get_model_config(pl_module).finalize_model_grads_func = finalize_model_grads_func

    def optimizers(self, model: MegatronParallel) -> List[Optimizer]:
        """Defines the optimizers.

        Args:
            model (MegatronParallel): The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The list of optimizers.

        Raises:
            ValueError: If the model is not an instance of MegatronParallel.
        """
        if not isinstance(model, MegatronParallel):
            raise ValueError("Model must be an instance of MegatronParallel")

        optimizer = setup_megatron_optimizer(
            model,
            self.config,
            no_weight_decay_cond=self.no_weight_decay_cond,
            scale_lr_cond=self.scale_lr_cond,
            lr_mult=self.lr_mult,
        )

        return [optimizer]

    def finalize_model_grads(self, *args, **kwargs):
        """Finalizes the model gradients by delegating to Megatron-Core's ``finalize_model_grads``."""
        return finalize_model_grads(*args, **kwargs)
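
# Illustrative usage (a sketch, not part of this module). The wiring below is an
# assumption based on typical NeMo 2.0 setups; adapt it to your own model/trainer:
#
#     from megatron.core.optimizer import OptimizerConfig
#     from nemo.lightning.pytorch.optim import MegatronOptimizerModule
#
#     config = OptimizerConfig(optimizer="adam", lr=3e-4, weight_decay=0.1)
#     optim = MegatronOptimizerModule(config=config)
#     # NeMo calls ``optim.optimizers(model)`` during setup to build the
#     # Megatron optimizer for the wrapped MegatronParallel model.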