# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mock Data Example:
torchrun --nproc_per_node=2 scripts/vlm/llama4/llama4_finetune.py \
    --devices=2 --tp_size=2 --data_type=mock --mbs=1 --gbs=4 --use_toy_model
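
Energon Data Example (illustrative; "/path/to/energon_dataset" is a placeholder for an
Energon-prepared dataset directory):
torchrun --nproc_per_node=2 scripts/vlm/llama4/llama4_finetune.py \
    --devices=2 --tp_size=2 --data_type=energon --data_path=/path/to/energon_dataset \
    --mbs=1 --gbs=4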
"""
import argparse
import torch
from lightning.pytorch.loggers import WandbLogger
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.vlm.data.data_module import EnergonDataModule
from nemo.collections.vlm.llama4.data.task_encoder import TaskEncoder as Llama4TaskEncoder
from nemo.collections.vlm.llama4.data.task_encoder import TaskEncoderConfig as Llama4TaskEncoderConfig
from nemo.collections.vlm.llama4.model.base import Llama4OmniModel
from nemo.lightning.pytorch.callbacks import NsysCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback


def main(args):
# pylint: disable=C0115,C0116
# Global and micro batch sizes
gbs = args.gbs
mbs = args.mbs
max_steps = args.max_steps
num_workers = args.num_workers
val_check_interval = 500
decoder_seq_length = args.decoder_seq_length
# Submodules configurations
# switch to 128E with vlm.Llama4MaverickExperts128Config()
llama4_config = vlm.Llama4ScoutExperts16Config()
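    # --use_toy_model shrinks the network to a 2-layer, 2-expert configuration for quick smoke tests.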
if args.use_toy_model:
decoder_seq_length = 4096
val_check_interval = 50
llama4_config.vision_transformer_config.num_layers = 2
llama4_config.language_transformer_config.num_layers = 2
llama4_config.language_transformer_config.num_moe_experts = 2
num_workers = 0
if args.data_type == "llava":
raise NotImplementedError
elif args.data_type == "energon":
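        # Encode Energon samples with the task encoder; hf_path points at the HF checkpoint
        # whose processor/tokenizer is used to turn multimodal samples into model inputs.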
task_encoder = Llama4TaskEncoder(
config=Llama4TaskEncoderConfig(
hf_path='meta-llama/Llama-4-Scout-17B-16E-Instruct',
)
)
data = EnergonDataModule(
path=args.data_path,
train_encoder=task_encoder,
seq_length=decoder_seq_length,
global_batch_size=gbs,
micro_batch_size=mbs,
num_workers=num_workers,
)
elif args.data_type == "mock":
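        # Mock data produces synthetic batches of the right shapes, useful for throughput and debug runs.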
llama_tokenizer = AutoTokenizer('meta-llama/Llama-4-Scout-17B-16E-Instruct')
data = vlm.Llama4MockDataModule(
seq_length=decoder_seq_length,
global_batch_size=gbs,
micro_batch_size=mbs,
tokenizer=llama_tokenizer,
image_processor=None,
num_workers=num_workers,
packed_sequence=args.use_packed_sequence,
)
else:
raise ValueError(f"Data type {args.data_type} not supported")
# Training strategy setup
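    # Parallelism sizes (tensor, expert, pipeline, context) come from the CLI; sequence
    # parallelism is enabled, and gradients are reduced in fp32 with communication overlap.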
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=args.tp_size,
expert_tensor_parallel_size=args.tp_size,
expert_model_parallel_size=args.ep_size,
pipeline_model_parallel_size=args.pp_size,
encoder_pipeline_model_parallel_size=args.encoder_pp_size,
context_parallel_size=args.cp_size,
pipeline_dtype=torch.bfloat16,
sequence_parallel=True,
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
),
ckpt_load_strictness="log_all",
)
model = Llama4OmniModel(llama4_config, tokenizer=data.tokenizer)
# Checkpoint callback setup
checkpoint_callback = nl.ModelCheckpoint(
save_last=True,
monitor="reduced_train_loss",
save_top_k=2,
every_n_train_steps=1000,
dirpath=args.log_dir,
)
# Trainer setup
trainer = nl.Trainer(
num_nodes=args.num_nodes,
devices=args.devices,
max_steps=max_steps,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
callbacks=[
checkpoint_callback,
TimingCallback(),
MegatronCommOverlapCallback(tp_comm_overlap=False),
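            # Emits Nsight Systems profiling ranges for steps 10-12 on rank 0; only takes effect
            # when the job is launched under `nsys profile`.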
NsysCallback(start_step=10, end_step=12, ranks=[0], gen_shape=True),
],
val_check_interval=val_check_interval,
limit_val_batches=gbs,
log_every_n_steps=1,
num_sanity_val_steps=0,
)
# Logger setup
nemo_logger = nl.NeMoLogger(
log_dir=args.log_dir,
name=args.name,
wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None,
)
# Auto resume setup
resume = nl.AutoResume(
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
resume_from_directory=args.log_dir,
restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None,
)
# Optimizer and scheduler setup
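    # Distributed Adam in bf16 with a cosine learning-rate schedule decaying from args.lr to min_lr.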
opt_config = OptimizerConfig(
optimizer='adam',
lr=args.lr,
adam_beta1=0.9,
adam_beta2=0.95,
use_distributed_optimizer=True,
bf16=True,
)
sched = CosineAnnealingScheduler(
max_steps=trainer.max_steps,
warmup_steps=150,
constant_steps=0,
min_lr=1.0e-07,
)
opt = MegatronOptimizerModule(opt_config, sched)
# PEFT setup
if args.peft == 'lora':
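        # Apply LoRA adapters to the attention (linear_qkv, linear_proj) and MLP (linear_fc1,
        # linear_fc2) projections instead of updating all model weights.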
peft = vlm.peft.LoRA(
target_modules=[
"linear_qkv",
"linear_proj",
"linear_fc1",
"linear_fc2",
]
)
else:
peft = None
llm.finetune(
model=model,
data=data,
trainer=trainer,
peft=peft,
log=nemo_logger,
optim=opt,
resume=resume,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama4 Model Training Script")
# Argument parsing
parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon")
parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file")
parser.add_argument(
"--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints"
)
parser.add_argument(
"--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model"
)
parser.add_argument(
"--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
)
parser.add_argument("--devices", type=int, required=False, default=1)
parser.add_argument("--num_workers", type=int, required=False, default=4)
parser.add_argument("--num_nodes", type=int, required=False, default=1)
parser.add_argument("--max_steps", type=int, required=False, default=5190)
parser.add_argument("--tp_size", type=int, required=False, default=1)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--cp_size", type=int, required=False, default=1)
parser.add_argument("--ep_size", type=int, required=False, default=1)
parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
parser.add_argument("--projector_type", type=str, required=False, default="mcore_mlp")
parser.add_argument("--name", type=str, required=False, default="llama4_pretrain")
parser.add_argument("--peft", type=str, default='none', help="none | lora")
parser.add_argument("--wandb_project", type=str, required=False, default=None)
parser.add_argument("--gbs", type=int, required=False, default=128, help="Global batch size")
parser.add_argument("--mbs", type=int, required=False, default=1, help="Micro batch size")
parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate")
parser.add_argument("--decoder_seq_length", type=int, required=False, default=8192, help="decoder sequence length")
parser.add_argument(
"--use_packed_sequence",
action="store_true",
)
parser.add_argument(
"--use_toy_model",
action="store_true",
help="Toy size model used for testing",
)
args = parser.parse_args()
main(args)