# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser

import torch
from lightning.pytorch.loggers import TensorBoardLogger
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.data import ChatDataModule, MockDataModule
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.utils import logging

# Suppress lengthy HF warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def get_args():
    """Parse the command-line arguments."""
    parser = ArgumentParser(
        description="""
        Script for training GPT models. Supports 4 modes, with different arguments needed
        in addition to the required arguments:
            1. Pretrain: no additional arguments required
            2. SFT: --use-chat-data required
            3. Distillation: --teacher_path required
            4. SFT Distillation: --use-chat-data and --teacher_path required
        """
    )
    parser.add_argument("--name", type=str, required=True, help="Experiment name")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help=(
            "Path to NeMo 2 checkpoint. If only model_path is provided, the model will be trained"
            " (pretrain or SFT). If teacher_path is also provided, the model will be distilled."
        ),
    )
    parser.add_argument(
        "--teacher_path",
        type=str,
        required=False,
        help=(
            "Path to NeMo 2 checkpoint to use as a distillation teacher."
            " Will trigger distillation mode if provided."
        ),
    )
    parser.add_argument("--kd_config", type=str, help="Path to Knowledge-Distillation config file")
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--cp_size", type=int, default=1, help="Context parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("--ep_size", type=int, default=1, help="Expert parallel size")
    parser.add_argument("--precision", type=str, default="bf16-mixed", help="Datatype for models and optimizer")
    parser.add_argument("--devices", type=int, default=1, help="Number of GPUs to use per node")
    parser.add_argument("--num_nodes", type=int, default=1, help="Number of nodes to use")
    parser.add_argument("--log_dir", type=str, required=True, help="Folder for logging and checkpoint saving")
    parser.add_argument("--max_steps", type=int, required=True, help="Number of global batches to process")
    parser.add_argument("--gbs", type=int, required=True, help="Global Batch Size")
    parser.add_argument("--mbs", type=int, required=True, help="Micro-batch Size")
    parser.add_argument(
        "--data_paths",
        nargs="+",
        help="List of tokenized data paths to load from. If using chat data, provide a single path.",
    )
    parser.add_argument("--split", type=str, default="99,1,0", help="Train,Val,Test ratios to split data")
    parser.add_argument("--index_mapping_dir", type=str, help="Folder to write cached data indices")
    parser.add_argument("--use-chat-data", action="store_true", help="Use chat data for fine-tuning.")
    parser.add_argument(
        "--chat-template-path",
        type=str,
        help=(
            "Path to chat template .txt file to use for chat data."
            " Only provide if overriding default chat template in HuggingFace tokenizer."
        ),
    )
    parser.add_argument(
        "--use_mock_data", action="store_true", help="Use mock data instead of custom data in --data_paths"
    )
    parser.add_argument("--seq_length", type=int, required=True, help="Number of tokens per input sample")
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name of tokenizer model to override default. Required if using chat data (--use-chat-data).",
    )
    parser.add_argument("--lr", type=float, default=1e-4, help="Base LR for Cosine-Annealing scheduler")
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum LR for Cosine-Annealing scheduler")
    parser.add_argument("--warmup_steps", type=int, default=50, help="Number of scheduler warmup steps")
    parser.add_argument("--val_check_interval", type=int, default=100, help="Validate + checkpoint every _ steps")
    parser.add_argument("--limit_val_batches", type=int, default=32, help="Number of batches per validation stage")
    parser.add_argument("--log_interval", type=int, default=10, help="Write to log every _ steps")
    parser.add_argument("--legacy_ckpt", action="store_true", help="Load ckpt saved with TE < 1.14")
    return parser.parse_args()


def _read_chat_template(template_path: str):
    # pylint: disable=C0116
    if not template_path:
        return None
    with open(template_path, 'r') as f:
        return f.read().strip()


if __name__ == "__main__":
    args = get_args()

    ## Initialize the strategy and trainer
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp_size,
        pipeline_model_parallel_size=args.pp_size,
        context_parallel_size=args.cp_size,
        expert_model_parallel_size=args.ep_size,
        sequence_parallel=(args.tp_size > 1),
        ddp=DistributedDataParallelConfig(
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            check_for_nan_in_grad=True,
            average_in_collective=True,
        ),
        ckpt_load_strictness=StrictHandling.LOG_ALL if args.legacy_ckpt else None,
    )
    trainer = nl.Trainer(
        devices=args.devices,
        num_nodes=args.num_nodes,
        max_steps=args.max_steps,
        log_every_n_steps=args.log_interval,
        val_check_interval=args.val_check_interval,
        limit_val_batches=args.limit_val_batches,
        strategy=strategy,
        accelerator="gpu",
        plugins=nl.MegatronMixedPrecision(
            precision=args.precision,
            params_dtype=torch.bfloat16 if "bf16" in args.precision else torch.float32,
            autocast_enabled=False,
            grad_reduce_in_fp32=True,
        ),
    )

    ## Set up dataset
    if not args.use_mock_data and not args.data_paths:
        raise ValueError("Must provide either custom dataset(s) in --data_paths or set --use_mock_data.")
    if args.use_mock_data:
        logging.warning("Using Mock Data for training!")
        data = MockDataModule(seq_length=args.seq_length, global_batch_size=args.gbs, micro_batch_size=args.mbs)
    elif args.use_chat_data:
        assert len(args.data_paths) == 1, "If using chat data, provide a single path."
        assert args.tokenizer is not None, "Tokenizer is required if using chat data."
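        # --chat-template-path is optional: when it is omitted, _read_chat_template() returns None
        # and the tokenizer's built-in HuggingFace chat template is used unchanged.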
        chat_template = _read_chat_template(args.chat_template_path)
        tokenizer = get_tokenizer(args.tokenizer, chat_template=chat_template)
        if '{% generation %}' not in tokenizer.tokenizer.chat_template:
            logging.warning(
                "The chat template does not contain a '{% generation %}' keyword, so a proper assistant mask"
                " cannot be produced during training. No tokens will be masked (all tokens will contribute to"
                " the loss). See https://github.com/huggingface/transformers/pull/30650"
            )
        data = ChatDataModule(
            dataset_root=args.data_paths[0],
            seq_length=args.seq_length,
            tokenizer=tokenizer,
            global_batch_size=args.gbs,
            micro_batch_size=args.mbs,
            use_hf_tokenizer_chat_template=True,
        )
    else:
        data = llm.PreTrainingDataModule(
            paths=args.data_paths,
            seq_length=args.seq_length,
            global_batch_size=args.gbs,
            micro_batch_size=args.mbs,
            split=args.split,
            index_mapping_dir=args.index_mapping_dir,
        )

    ## Set up optimizer
    optim_config = OptimizerConfig(
        optimizer="adam",
        lr=args.lr,
        bf16=("bf16" in args.precision),
        use_distributed_optimizer=True,
    )
    sched = CosineAnnealingScheduler(
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        constant_steps=0,
        min_lr=args.min_lr,
    )
    optim = nl.MegatronOptimizerModule(optim_config, sched)

    ## Set up checkpointing and logging
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        every_n_train_steps=args.val_check_interval,
    )
    logger = nl.NeMoLogger(
        name=args.name,
        log_dir=args.log_dir,
        ckpt=checkpoint_callback,
        tensorboard=TensorBoardLogger(os.path.join(args.log_dir, args.name)),
        update_logger_directory=False,
    )

    ## Set up resume and/or restore functionality
    resume = nl.AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
        restore_config=nl.RestoreConfig(path=args.model_path),
    )

    ## Train the model: distillation if a teacher checkpoint was provided, standard training otherwise
    if args.teacher_path:
        llm.distill(
            student_model_path=args.model_path,
            teacher_model_path=args.teacher_path,
            distillation_config_path=args.kd_config,
            data=data,
            trainer=trainer,
            log=logger,
            resume=resume,
            optim=optim,
            tokenizer=get_tokenizer(args.tokenizer) if args.tokenizer else None,
        )
    else:
        llm.train(
            model=args.model_path,
            data=data,
            trainer=trainer,
            optim=optim,
            log=logger,
            resume=resume,
            tokenizer="data" if args.use_chat_data else "model",
        )
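
# Example invocation (an illustrative sketch only: the launcher, script name, paths, and
# hyperparameter values below are hypothetical placeholders, not values prescribed by this script):
#
#   torchrun --nproc_per_node=8 train_gpt.py \
#       --name my_experiment \
#       --model_path /checkpoints/student_nemo2_ckpt \
#       --log_dir /results \
#       --devices 8 \
#       --max_steps 200 \
#       --gbs 128 --mbs 2 \
#       --seq_length 4096 \
#       --data_paths /data/tokenized/my_dataset_text_document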