MagpieTTS_Internal_Demo / scripts /vlm /llava_next_pretrain.py
subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \
--devices=8 --tp=4 --data_type=mock
torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \
--devices=8 --tp=4 --data_type=energon --data_path='' \
--language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5
torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \
--devices=8 --tp=4 --data_type=energon --data_path='' \
--num_workers=8 --max_samples_per_sequence=100 --shuffle_buffer_size=100 \
--language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5
"""
import argparse
import torch
from lightning.pytorch.loggers import WandbLogger
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.multimodal.data.energon import ImageToken
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback
def _build_data_module(args, decoder_seq_length):
    """Create the training data module selected by ``args.data_type``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
        decoder_seq_length: Sequence length given to the data module; when
            sequence packing is enabled it is also the packed-sequence size.

    Returns:
        An ``EnergonMultiModalDataModule`` for ``"energon"`` or a
        ``vlm.LlavaNextMockDataModule`` for ``"mock"``.

    Raises:
        ValueError: If ``args.data_type`` is unknown, or if ``"energon"`` is
            selected without a non-empty ``--data_path``.
    """
    gbs = args.gbs
    mbs = args.mbs
    num_workers = args.num_workers

    if args.data_type == "energon":
        # Fail fast: without this check the HF processor/tokenizer below are
        # loaded first and the missing path only surfaces inside the data module.
        if not args.data_path:
            raise ValueError("--data_path is required when --data_type=energon")

        # Energon-only dependencies are imported inside this branch
        # (presumably so mock runs do not need them).
        from transformers import AutoProcessor

        from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
        from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule
        from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
        from nemo.collections.vlm import LlavaNextTaskEncoder

        # Image processor and tokenizer come from the HF LLaVA-Next checkpoint.
        model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
        processor = AutoProcessor.from_pretrained(model_id)
        tokenizer = AutoTokenizer(model_id)

        multimodal_sample_config = MultiModalSampleConfig(
            # "<image>" placeholders in the text map to token id -200.
            image_token=ImageToken(token_str="<image>", token_id=-200),
            # presumably the label ignore index (-100 is the PyTorch loss convention) — confirm
            ignore_place_holder=-100,
        )
        # Setting system prompt to empty string
        multimodal_sample_config.conversation_template_config.system = ''

        task_encoder = LlavaNextTaskEncoder(
            tokenizer=tokenizer.tokenizer,
            image_processor=processor.image_processor,
            multimodal_sample_config=multimodal_sample_config,
            packed_sequence=args.use_packed_sequence,
            packed_sequence_size=decoder_seq_length,
        )
        return EnergonMultiModalDataModule(
            path=args.data_path,
            tokenizer=tokenizer,
            image_processor=processor.image_processor,
            num_workers=num_workers,
            micro_batch_size=mbs,
            global_batch_size=gbs,
            max_samples_per_sequence=args.max_samples_per_sequence,
            shuffle_buffer_size=args.shuffle_buffer_size,
            seq_length=decoder_seq_length,
            multimodal_sample_config=multimodal_sample_config,
            task_encoder=task_encoder,
            # A packing buffer is only needed when packing is enabled.
            packing_buffer_size=200 if args.use_packed_sequence else None,
            virtual_epoch_length=1000,
        )

    if args.data_type == "mock":
        # The mock module is given no tokenizer/image processor;
        # main() reads the tokenizer back from ``data.tokenizer``.
        return vlm.LlavaNextMockDataModule(
            seq_length=decoder_seq_length,
            global_batch_size=gbs,
            micro_batch_size=mbs,
            tokenizer=None,
            image_processor=None,
            num_workers=num_workers,
        )

    raise ValueError(f"Data type {args.data_type} not supported")


def main(args):
    """Pretrain a LLaVA-Next model from parsed command-line arguments.

    Builds the data module (mock or Energon) and a LLaVA-Next model. The model
    combines a Llama-2-7B-configured language model, an HF CLIP vision encoder
    and a multimodal projector; the language and vision towers are frozen.
    Training then runs through ``llm.pretrain`` with Megatron parallelism,
    bf16 mixed precision, Adam and a cosine learning-rate schedule.

    Args:
        args: ``argparse.Namespace`` produced by the parser in the
            ``__main__`` block.

    Raises:
        ValueError: If ``args.data_type`` is not ``"mock"`` or ``"energon"``,
            or if Energon data is requested without ``--data_path``.
    """
    # pylint: disable=C0115,C0116
    # Global batch size; also reused below as the validation-batch cap.
    gbs = args.gbs

    # For interleaved data the decoder sequence length needs to be higher than
    # for VQA samples (e.g. 40960 instead of 4096).
    decoder_seq_length = 4096

    data = _build_data_module(args, decoder_seq_length)

    # Submodules configurations
    language_transformer_config = llm.Llama2Config7B(seq_length=decoder_seq_length)
    vision_transformer_config = vlm.HFCLIPVisionConfig(pretrained_model_name_or_path=args.vision_encoder_model_path)
    # The projector maps vision hidden states into the language model's hidden size.
    vision_projection_config = vlm.MultimodalProjectorConfig(
        projector_type=args.projector_type,
        input_size=vision_transformer_config.hidden_size,
        hidden_size=language_transformer_config.hidden_size,
        ffn_hidden_size=language_transformer_config.hidden_size,
    )

    # Llava Next model configuration: both towers frozen.
    llava_next_config = vlm.LlavaNextConfig(
        language_transformer_config=language_transformer_config,
        vision_transformer_config=vision_transformer_config,
        vision_projection_config=vision_projection_config,
        language_model_from_pretrained=args.language_model_path,
        pipeline_dtype=torch.bfloat16,
        freeze_language_model=True,
        freeze_vision_model=True,
    )
    model = vlm.LlavaNextModel(llava_next_config, tokenizer=data.tokenizer)

    # Training strategy setup
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp_size,
        pipeline_model_parallel_size=args.pp_size,
        context_parallel_size=args.cp_size,
        encoder_pipeline_model_parallel_size=args.encoder_pp_size,
        pipeline_dtype=torch.bfloat16,
        # Sequence parallelism is only meaningful with tensor parallelism.
        sequence_parallel=args.tp_size > 1,
    )

    # Keep the last checkpoint plus the 2 best by reduced_train_loss,
    # checkpointing every 1000 training steps.
    checkpoint_callback = nl.ModelCheckpoint(
        save_last=True,
        monitor="reduced_train_loss",
        save_top_k=2,
        every_n_train_steps=1000,
        dirpath=args.log_dir,
    )

    # Trainer setup
    trainer = nl.Trainer(
        num_nodes=args.num_nodes,
        devices=args.devices,
        max_steps=args.max_steps,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        callbacks=[
            checkpoint_callback,
            TimingCallback(),
        ],
        val_check_interval=1000,
        # NOTE(review): caps validation at `gbs` batches, tying validation size
        # to the global batch size — confirm this is intended.
        limit_val_batches=gbs,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
    )

    # Logger setup (W&B only when a project name is given)
    nemo_logger = nl.NeMoLogger(
        log_dir=args.log_dir,
        name=args.name,
        wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None,
    )

    # Auto resume: resume from log_dir when a checkpoint exists, otherwise
    # start fresh; optionally restore weights from --restore_path.
    resume = nl.AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
        resume_from_directory=args.log_dir,
        restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None,
    )

    # Optimizer and scheduler setup
    opt_config = OptimizerConfig(
        optimizer='adam',
        lr=args.lr,
        adam_beta1=0.9,
        adam_beta2=0.95,
        use_distributed_optimizer=True,
        bf16=True,
    )
    # 150 warmup steps, then cosine decay to 2e-5 over the run.
    sched = CosineAnnealingScheduler(
        max_steps=trainer.max_steps,
        warmup_steps=150,
        constant_steps=0,
        min_lr=2.0e-05,
    )
    opt = MegatronOptimizerModule(opt_config, sched)

    llm.pretrain(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        optim=opt,
        resume=resume,
    )
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llava Next Pretraining Script")
# Argument parsing
parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon")
parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file")
parser.add_argument(
"--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints"
)
parser.add_argument(
"--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model"
)
parser.add_argument(
"--vision_encoder_model_path",
type=str,
required=False,
default="openai/clip-vit-large-patch14-336",
help="Path to the pretrained vision encoder model",
)
parser.add_argument(
"--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
)
parser.add_argument("--devices", type=int, required=False, default=1)
parser.add_argument("--num_nodes", type=int, required=False, default=1)
parser.add_argument("--max_steps", type=int, required=False, default=2100)
parser.add_argument("--tp_size", type=int, required=False, default=2)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--cp_size", type=int, required=False, default=1)
parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)
parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu")
parser.add_argument("--name", type=str, required=False, default="llava_next_pretrain")
parser.add_argument("--wandb_project", type=str, required=False, default=None)
parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size")
parser.add_argument("--mbs", type=int, required=False, default=4, help="Micro batch size")
parser.add_argument("--lr", type=float, required=False, default=0.001, help="Learning rate")
parser.add_argument(
"--num_workers",
type=int,
required=False,
default=4,
help="The number of data loader workers per rank. May be 0 to disable worker processes",
)
parser.add_argument(
"--max_samples_per_sequence",
type=int,
required=False,
default=100,
help="If using Energon, the maximum number of samples per sequence to load from memory",
)
parser.add_argument(
"--shuffle_buffer_size",
type=int,
required=False,
default=100,
help="If using Energon, the size of the sample shuffle buffer (before task encoding)",
)
parser.add_argument("--use_packed_sequence", action="store_true")
args = parser.parse_args()
main(args)