# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
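"""Training recipe factories for DiT and MovieGen diffusion models, built on NeMo and nemo_run.

Each factory below returns a ``run.Partial`` wrapping ``llm.train``, configured with a model
config, data module, Megatron trainer/parallelism settings, optimizer, and logging. The script
is launched through ``run.cli.main`` at the bottom of the file.
"""
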
import os

import lightning.pytorch as pl
import nemo_run as run
import torch
from lightning.pytorch.loggers import WandbLogger
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.enums import AttnMaskType

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule
from nemo.collections.diffusion.data.diffusion_fake_datamodule import VideoLatentFakeDataModule
from nemo.collections.diffusion.data.diffusion_taskencoder import BasicDiffusionTaskEncoder
from nemo.collections.diffusion.models.model import (
    DiT7BConfig,
    DiTConfig,
    DiTLConfig,
    DiTLlama1BConfig,
    DiTLlama5BConfig,
    DiTLlama30BConfig,
    DiTModel,
    DiTXLConfig,
    ECDiTLlama1BConfig,
)
from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.utils.exp_manager import TimingCallback


def multimodal_datamodule() -> pl.LightningDataModule:
    """Multimodal Datamodule Initialization"""
    data_module = DiffusionDataModule(
        seq_length=2048,
        task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048),
        micro_batch_size=1,
        global_batch_size=32,
    )
    return data_module


def simple_datamodule() -> pl.LightningDataModule:
    """Simple Datamodule Initialization"""
    data_module = EnergonMultiModalDataModule(
        seq_length=2048,
        micro_batch_size=1,
        global_batch_size=32,
        num_workers=16,
        tokenizer=None,
        image_processor=None,
        task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048),
    )
    return data_module


def multimodal_fake_datamodule() -> pl.LightningDataModule:
    """Multimodal Mock Datamodule Initialization"""
    data_module = VideoLatentFakeDataModule(
        seq_length=None,  # Set to None to detect the sequence length automatically.
        task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048),
        micro_batch_size=1,
        global_batch_size=32,
    )
    return data_module


def peft(args) -> ModelTransform:
    """Parameter-Efficient Fine-Tuning (LoRA)"""
    return llm.peft.LoRA(
        target_modules=['linear_qkv', 'linear_proj'],  # , 'linear_fc1', 'linear_fc2'
        dim=args.lora_dim,
    )
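
# Illustrative usage (not wired up by default): a recipe can enable LoRA by assigning
# the transform, e.g. ``recipe.model_transform = peft(args)``, where ``args`` is
# assumed to expose a ``lora_dim`` attribute.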


def pretrain() -> run.Partial:
    """Base Pretraining Config"""
    return run.Partial(
        llm.train,
        model=run.Config(
            DiTModel,
            config=run.Config(DiTConfig),
        ),
        data=multimodal_datamodule(),
        trainer=run.Config(
            nl.Trainer,
            devices='auto',
            num_nodes=int(os.environ.get('SLURM_NNODES', 1)),
            accelerator="gpu",
            strategy=run.Config(
                nl.MegatronStrategy,
                tensor_model_parallel_size=1,
                pipeline_model_parallel_size=1,
                context_parallel_size=1,
                sequence_parallel=False,
                pipeline_dtype=torch.bfloat16,
                ddp=run.Config(
                    DistributedDataParallelConfig,
                    check_for_nan_in_grad=True,
                    grad_reduce_in_fp32=True,
                    overlap_grad_reduce=True,
                    overlap_param_gather=True,
                ),
            ),
            plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
            num_sanity_val_steps=0,
            limit_val_batches=1,
            val_check_interval=1000,
            max_epochs=10000,
            log_every_n_steps=1,
            callbacks=[
                run.Config(
                    ModelCheckpoint,
                    monitor='global_step',
                    filename='{global_step}',
                    every_n_train_steps=1000,
                    save_top_k=3,
                    mode='max',
                ),
                run.Config(PreemptionCallback),
                run.Config(TimingCallback),
                run.Config(
                    MegatronCommOverlapCallback,
                    tp_comm_overlap=False,
                ),
            ],
        ),
        log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None)),
        optim=run.Config(
            nl.MegatronOptimizerModule,
            config=run.Config(
                OptimizerConfig,
                lr=1e-4,
                bf16=True,
                params_dtype=torch.bfloat16,
                use_distributed_optimizer=True,
                weight_decay=0,
            ),
        ),
        tokenizer=None,
        resume=run.Config(
            nl.AutoResume,
            resume_if_exists=True,
            resume_ignore_no_checkpoint=True,
            resume_past_end=True,
        ),
        model_transform=None,
    )
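
# The recipes below all derive from ``pretrain()`` and override attributes on the
# returned ``run.Partial`` in place, e.g.:
#
#     recipe = pretrain()
#     recipe.model.config = run.Config(DiTXLConfig)
#     recipe.trainer.strategy.tensor_model_parallel_size = 2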


def pretrain_xl() -> run.Partial:
    """DiT-XL Pretraining Recipe"""
    recipe = pretrain()
    recipe.model.config = run.Config(DiTXLConfig)
    return recipe


def pretrain_l() -> run.Partial:
    """DiT-L Pretraining Recipe"""
    recipe = pretrain()
    recipe.model.config = run.Config(DiTLConfig)
    return recipe


def set_use_megatron_fsdp(recipe):
    """Enable Megatron FSDP on the recipe's DDP config.

    Falls back to the older ``use_custom_fsdp`` flag name when the newer
    ``use_megatron_fsdp`` attribute is not available.
    """
    try:
        recipe.trainer.strategy.ddp.use_megatron_fsdp = True
    except AttributeError:
        recipe.trainer.strategy.ddp.use_custom_fsdp = True


def train_mock() -> run.Partial:
    """DiT Mock Pretraining Recipe"""
    recipe = pretrain()
    recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1)
    recipe.data = multimodal_fake_datamodule()
    recipe.model.config.num_layers = 16
    recipe.data.seq_length = 73728
    recipe.data.task_encoder.seq_length = 73728
    recipe.trainer.strategy.tensor_model_parallel_size = 4
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.strategy.context_parallel_size = 2
    recipe.data.micro_batch_size = 1
    recipe.data.global_batch_size = 1
    recipe.trainer.limit_val_batches = 0
    recipe.trainer.val_check_interval = 1.0
    recipe.data.model_config = recipe.model.config
    recipe.log.log_dir = 'nemo_experiments/train_mock'
    set_use_megatron_fsdp(recipe=recipe)
    recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
    recipe.trainer.strategy.ddp.overlap_param_gather = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.model.config.use_cpu_initialization = True
    return recipe
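
# The ``mock_*`` recipes below are short profiling/benchmark runs: they train on
# synthetic latents from ``VideoLatentFakeDataModule`` for a fixed number of steps,
# disable checkpointing, and attach ``NsysCallback`` to mark steps 10-11 for Nsight
# Systems profiling.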


def mock_ditllama5b_8k() -> run.Partial:
    """DiT-5B Mock Recipe"""
    recipe = pretrain()
    recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1)
    recipe.data = multimodal_fake_datamodule()
    recipe.data.seq_length = recipe.data.task_encoder.seq_length = 8192
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.strategy.context_parallel_size = 1
    recipe.data.micro_batch_size = 1
    recipe.data.global_batch_size = 32
    recipe.trainer.limit_val_batches = 0
    recipe.trainer.val_check_interval = 1.0
    recipe.data.model_config = recipe.model.config
    recipe.log.log_dir = 'nemo_experiments/mock_ditllama5b_8k'
    recipe.model.config.attn_mask_type = AttnMaskType.no_mask
    set_use_megatron_fsdp(recipe=recipe)
    recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
    recipe.trainer.strategy.ddp.overlap_param_gather = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.model.config.use_cpu_initialization = True
    recipe.trainer.max_steps = 15
    recipe.trainer.callbacks.pop(0)  # drop the ModelCheckpoint callback
    recipe.trainer.enable_checkpointing = False
    recipe.trainer.callbacks.append(
        run.Config(
            NsysCallback,
            start_step=10,
            end_step=11,
        )
    )
    recipe.resume = None
    return recipe


def mock_dit7b_8k() -> run.Partial:
    """DiT-7B Mock Recipe"""
    recipe = mock_ditllama5b_8k()
    recipe.model.config = run.Config(DiT7BConfig, max_frames=1)
    recipe.data.model_config = recipe.model.config
    recipe.model.config.attn_mask_type = AttnMaskType.no_mask
    recipe.model.config.use_cpu_initialization = True
    recipe.log.log_dir = 'nemo_experiments/mock_dit7b_8k'
    return recipe


def pretrain_7b() -> run.Partial:
    """DiT-7B Pretraining Recipe"""
    recipe = pretrain()
    recipe.model.config = run.Config(DiT7BConfig)
    recipe.data.global_batch_size = 4608
    recipe.data.micro_batch_size = 9
    recipe.data.num_workers = 15
    recipe.data.use_train_split_for_val = True
    recipe.data.seq_length = 260
    recipe.data.task_encoder.seq_length = 260
    recipe.trainer.val_check_interval = 1000
    recipe.log.log_dir = 'nemo_experiments/dit7b'
    recipe.optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9)
    recipe.optim.config.weight_decay = 0.1
    recipe.optim.config.adam_beta1 = 0.9
    recipe.optim.config.adam_beta2 = 0.95
    return recipe
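
# Packing note: ``pretrain_7b_pack`` packs roughly 9 samples of ~256 tokens into one
# 2304-token sequence per micro-batch (``qkv_format='thd'``), so the global batch size
# is divided by 9 to keep the number of samples per step approximately the same as in
# ``pretrain_7b``.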


def pretrain_7b_pack() -> run.Partial:
    """DiT-7B Pretraining Recipe with Packing"""
    recipe = pretrain_7b()
    recipe.data.global_batch_size = 4608 // 9
    recipe.data.micro_batch_size = 1
    recipe.data.num_workers = 15
    recipe.data.use_train_split_for_val = True
    recipe.data.seq_length = 256 * 9
    recipe.data.packing_buffer_size = 1000
    recipe.data.task_encoder.seq_length = None
    recipe.data.task_encoder.max_seq_length = recipe.data.seq_length
    recipe.model.config.qkv_format = 'thd'
    return recipe


def pretrain_7b_256p_joint() -> run.Partial:
    """DiT-7B Pretraining Recipe 256p Stage 1"""
    recipe = pretrain_7b()
    recipe.data.global_batch_size = 256  # 768
    recipe.data.micro_batch_size = 1
    recipe.data.seq_length = 8192
    recipe.data.task_encoder.seq_length = 8192
    recipe.model.config.seq_length = 8192
    recipe.optim.config.lr = 6e-5
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    # recipe.resume.restore_config = run.Config(RestoreConfig, path='', load_optim_state=True)
    recipe.log.log_dir = 'nemo_experiments/pretrain_7b_256p_joint'
    return recipe


def pretrain_7b_256p_joint_pack() -> run.Partial:
    """DiT-7B Pretraining Recipe 256p Stage 1 with Packing"""
    recipe = pretrain_7b_256p_joint()
    recipe.data.global_batch_size = 128
    recipe.data.micro_batch_size = 1
    recipe.data.num_workers = 10
    recipe.data.seq_length = recipe.model.config.seq_length = recipe.data.task_encoder.max_seq_length = 10240
    recipe.data.task_encoder.seq_length = None
    recipe.data.packing_buffer_size = 1000
    recipe.data.virtual_epoch_length = 0
    recipe.model.config.qkv_format = 'thd'
    return recipe


def pretrain_ditllama5b() -> run.Partial:
    """MovieGen 5B Training"""
    recipe = pretrain_7b()
    recipe.data.micro_batch_size = 12
    recipe.model.config = run.Config(DiTLlama5BConfig)
    recipe.log.log_dir = 'nemo_experiments/ditllama5b'
    return recipe


def pretrain_ditllama30b() -> run.Partial:
    """MovieGen 30B Stage 1 Training"""
    recipe = pretrain_ditllama5b()
    recipe.model.config = run.Config(DiTLlama30BConfig)
    recipe.data.global_batch_size = 9216
    recipe.data.micro_batch_size = 6
    recipe.data.task_encoder.aethetic_score = 4.0
    recipe.data.seq_length = 256
    recipe.data.task_encoder.seq_length = 256
    recipe.data.virtual_epoch_length = 0
    recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage1_mock'
    set_use_megatron_fsdp(recipe=recipe)
    recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
    recipe.trainer.strategy.ddp.overlap_param_gather = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.model.config.use_cpu_initialization = True
    return recipe


def pretrain_ditllama30b_stage2_mock() -> run.Partial:
    """MovieGen 30B Stage 2 Training"""
    recipe = pretrain_ditllama5b()
    recipe.model.config = run.Config(DiTLlama30BConfig)
    recipe.data = multimodal_fake_datamodule()
    recipe.data.model_config = recipe.model.config
    recipe.data.seq_length = 8192
    recipe.data.task_encoder.seq_length = 8192
    recipe.data.global_batch_size = 256
    recipe.data.micro_batch_size = 1
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.trainer.strategy.context_parallel_size = 4
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.limit_val_batches = 0
    recipe.trainer.val_check_interval = 1.0
    recipe.data.model_config = recipe.model.config
    recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage2_mock'
    set_use_megatron_fsdp(recipe=recipe)
    recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
    recipe.trainer.strategy.ddp.overlap_param_gather = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.model.config.use_cpu_initialization = True
    return recipe


def pretrain_ditllama30b_stage3_mock() -> run.Partial:
    """MovieGen 30B Stage 3 Training"""
    recipe = pretrain_ditllama5b()
    recipe.model.config = run.Config(DiTLlama30BConfig)
    recipe.data = multimodal_fake_datamodule()
    recipe.data.model_config = recipe.model.config
    recipe.data.seq_length = 73728
    recipe.data.task_encoder.seq_length = 73728
    recipe.data.global_batch_size = 256
    recipe.data.micro_batch_size = 1
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.trainer.strategy.context_parallel_size = 8
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.limit_val_batches = 0
    recipe.trainer.val_check_interval = 1.0
    recipe.data.model_config = recipe.model.config
    recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock'
    set_use_megatron_fsdp(recipe=recipe)
    recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
    recipe.trainer.strategy.ddp.overlap_param_gather = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.model.config.use_cpu_initialization = True
    return recipe


def pretrain_ditllama5b_stage3_mock_with_pp() -> run.Partial:
    """MovieGen 5B Stage 3 Training with Pipeline Parallelism"""
    recipe = pretrain_ditllama5b()
    recipe.data = multimodal_fake_datamodule()
    recipe.data.model_config = recipe.model.config
    recipe.data.seq_length = 8192
    recipe.data.task_encoder.seq_length = 8192
    recipe.data.global_batch_size = 1
    recipe.data.micro_batch_size = 1
    recipe.trainer.strategy.tensor_model_parallel_size = 2
    recipe.trainer.strategy.pipeline_model_parallel_size = 2
    recipe.trainer.strategy.context_parallel_size = 2
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.limit_val_batches = 0
    recipe.trainer.val_check_interval = 1.0
    recipe.data.model_config = recipe.model.config
    recipe.log.log_dir = 'nemo_experiments/ditllama5b_stage3_mock_with_pp'
    return recipe


def pretrain_ditllama30b_stage3_mock_with_pp() -> run.Partial:
    """MovieGen 30B Stage 3 Training with Pipeline Parallelism"""
    recipe = pretrain_ditllama5b()
    recipe.model.config = run.Config(DiTLlama30BConfig)
    recipe.data = multimodal_fake_datamodule()
    recipe.data.model_config = recipe.model.config
    recipe.data.seq_length = 73728
    recipe.data.task_encoder.seq_length = 73728
    recipe.data.global_batch_size = 256
    recipe.data.micro_batch_size = 1
    recipe.trainer.strategy.tensor_model_parallel_size = 4
    recipe.trainer.strategy.pipeline_model_parallel_size = 4
    recipe.trainer.strategy.context_parallel_size = 8
    recipe.trainer.strategy.sequence_parallel = True
    recipe.trainer.limit_val_batches = 0
    recipe.trainer.val_check_interval = 1.0
    recipe.data.model_config = recipe.model.config
    recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock_with_pp'
    return recipe


def pretrain_ditllama1b() -> run.Partial:
    """MovieGen 1B Stage 1 Training"""
    recipe = pretrain_ditllama5b()
    recipe.model.config = run.Config(DiTLlama1BConfig)
    recipe.data.task_encoder.aethetic_score = 4.0
    recipe.data.seq_length = 256
    recipe.data.task_encoder.seq_length = 256
    recipe.model.config.seq_length = 256
    recipe.data.global_batch_size = 1536
    recipe.data.micro_batch_size = 96
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.log.log_dir = 'nemo_experiments/ditllama1b'
    recipe.trainer.val_check_interval = 3000
    recipe.trainer.callbacks[0].every_n_train_steps = 3000
    recipe.trainer.callbacks[0].monitor = 'global_step'
    recipe.trainer.callbacks[0].save_top_k = 3
    recipe.trainer.callbacks[0].mode = 'max'
    return recipe


def pretrain_ditllama3b() -> run.Partial:
    """MovieGen 3B Stage 1 Training"""
    recipe = pretrain_ditllama1b()
    recipe.data.micro_batch_size = 48
    # ~3B variant built by widening and deepening the 1B config.
    recipe.model.config = run.Config(
        DiTLlama1BConfig,
        hidden_size=3072,
        num_layers=28,
        num_attention_heads=24,
        ffn_hidden_size=8192,
    )
    recipe.log.log_dir = 'nemo_experiments/ditllama3b'
    return recipe


def pretrain_ecditllama1b() -> run.Partial:
    """EC-DiT 1B Training"""
    recipe = pretrain_ditllama1b()
    recipe.data.task_encoder.aethetic_score = 5.0
    recipe.data.micro_batch_size = 72
    recipe.data.global_batch_size = 2304
    recipe.model.config = run.Config(ECDiTLlama1BConfig)
    recipe.log.log_dir = 'nemo_experiments/ecditllama1b'
    recipe.trainer.val_check_interval = 3000
    set_use_megatron_fsdp(recipe=recipe)
    recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
    recipe.trainer.strategy.ddp.overlap_param_gather = True
    recipe.trainer.strategy.ddp.overlap_grad_reduce = True
    recipe.model.config.use_cpu_initialization = True
    return recipe
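
# DreamBooth fine-tuning: restores model weights via ``RestoreConfig``, lowers the
# learning rate to 1e-6, and trains for a fixed 1000 steps with TP=8.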


def dreambooth() -> run.Partial:
    """Dreambooth Fine Tuning"""
    recipe = pretrain()
    recipe.optim.config.lr = 1e-6
    recipe.data = multimodal_datamodule()
    recipe.model.config = run.Config(DiTConfig)
    recipe.trainer.max_steps = 1000
    recipe.trainer.strategy.tensor_model_parallel_size = 8
    recipe.trainer.strategy.sequence_parallel = True
    recipe.resume.restore_config = run.Config(RestoreConfig)
    recipe.resume.resume_if_exists = False
    return recipe
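
# Illustrative launch (the exact CLI flags are an assumption based on nemo_run's CLI;
# adjust to your environment):
#   torchrun --nproc-per-node=8 <this_script>.py --factory pretrain_7b --yes
# ``--factory`` selects one of the recipe factories above; when omitted,
# ``default_factory=dreambooth`` below is used.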


if __name__ == "__main__":
    OOM_DEBUG = False
    if OOM_DEBUG:
        torch.cuda.memory._record_memory_history(
            True,
            # Keep 100,000 alloc/free events from before the snapshot.
            trace_alloc_max_entries=100000,
            # Record stack information for the trace events.
            trace_alloc_record_context=True,
        )

    run.cli.main(llm.train, default_factory=dreambooth)