# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This script is just an example of using NeMo checkpoints
# for generating outputs and is subject to change without notice.

from argparse import ArgumentParser

import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams

import nemo.lightning as nl
from nemo.collections.llm import api

"""
torchrun --nproc-per-node=8 /opt/NeMo/scripts/llm/generate.py \
    --model_path= \
    --tp=8 \
    --devices=8 \
    --num_tokens_to_generate=40 \
    --temperature=0.001 \
    --top_p=0.0 \
    --top_k=1 \
    --fp8
"""


def get_args():
    """
    Parse the command line arguments.
    """
    parser = ArgumentParser(description="""Run generation on a few sample prompts given the checkpoint path.""")
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        default=[
            "Q: How are you?",
            "Q: How big is the universe?",
            "Q: How is the weather?",
            "Q: How many stars are there?",
            "Paris is known for its ",
            "On a hot sunny day, you should ",
            "Q: How many planets are in the solar system?",
            "Q: How old are you?",
        ],
        help="List of prompt strings",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="""Path to NeMo 2 checkpoint""",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="""Tensor parallel size""",
    )
    parser.add_argument(
        "--pp",
        type=int,
        default=1,
        help="""Pipeline parallel size""",
    )
    parser.add_argument(
        "--ep",
        type=int,
        default=1,
        help="""Expert parallel size""",
    )
    parser.add_argument(
        "--etp",
        type=int,
        default=None,
        help="""Expert tensor parallel size""",
    )
    parser.add_argument(
        "--devices",
        type=int,
        default=1,
        help="""Number of GPUs to use on a single node""",
    )
    parser.add_argument(
        "--nodes",
        type=int,
        default=1,
        help="""Number of nodes to use""",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="""Temperature to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="""top_p to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="""top_k to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--add_BOS",
        action="store_true",
        help="""Whether to add the BOS token to the prompt""",
    )
    parser.add_argument(
        "--num_tokens_to_generate",
        type=int,
        default=25,
        help="""Number of tokens to generate per prompt""",
    )
    parser.add_argument(
        "--fp8",
        action="store_true",
        help="""Whether to run inference in FP8 precision""",
    )
    parser.add_argument(
        "--fp8_recipe",
        type=str,
        default="tensorwise",
        help="""FP8 recipe, can be 'tensorwise', 'delayed', or 'mxfp8'""",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="""Maximum batch size for inference""",
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=1234,
        help="""Random seed for generation""",
    )
    parser.add_argument(
        "--legacy_ckpt",
        action="store_true",
        help="""Load ckpt saved with TE < 1.14""",
    )
    parser.add_argument(
        "--disable_flash_decode",
        action="store_true",
        help="""Disable flash decode for models that do not support it""",
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()

    if args.fp8:
        assert len(args.prompts) % 8 == 0, "Batch size should be divisible by 8 for FP8 inference"

    if args.etp is None and args.ep > 1:
        # Unless ETP is explicitly given, disable ETP if using EP. Otherwise ETP = TP.
        args.etp = 1

    # Parallelism configuration for inference; optimizer state is not needed.
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp,
        pipeline_model_parallel_size=args.pp,
        expert_model_parallel_size=args.ep,
        expert_tensor_parallel_size=args.etp,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )

    # bf16 execution, optionally with FP8 enabled via the chosen recipe.
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=args.devices,
        num_nodes=args.nodes,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            pipeline_dtype=torch.bfloat16,
            autocast_enabled=False,
            grad_reduce_in_fp32=False,
            fp8="hybrid" if args.fp8 else None,
            fp8_recipe=args.fp8_recipe if args.fp8 else None,
            fp8_amax_history_len=1,
            fp8_amax_compute_algo="max" if args.fp8 else "most_recent",
        ),
    )

    # Load ckpt saved with TE < 1.14
    if args.legacy_ckpt:
        trainer.strategy.ckpt_load_strictness = False

    prompts = args.prompts

    results = api.generate(
        path=args.model_path,
        prompts=prompts,
        trainer=trainer,
        add_BOS=args.add_BOS,
        inference_params=CommonInferenceParams(
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            num_tokens_to_generate=args.num_tokens_to_generate,
            return_log_probs=False,
            top_n_logprobs=0,
        ),
        text_only=True,
        max_batch_size=args.max_batch_size,
        random_seed=args.random_seed,
        enable_flash_decode=not args.disable_flash_decode,
    )

    # Print prompt/output pairs on rank 0 only.
    if torch.distributed.get_rank() == 0:
        for i, r in enumerate(results):
            print(prompts[i])
            print("*" * 50)
            print(r)
            print("\n\n")
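
# Illustrative usage sketch (not part of the original docstring example): the
# expert-parallel flags above could be combined for a mixture-of-experts
# checkpoint on a single 8-GPU node, e.g.:
#
#   torchrun --nproc-per-node=8 /opt/NeMo/scripts/llm/generate.py \
#       --model_path=<path/to/nemo2/checkpoint> \
#       --tp=2 \
#       --ep=4 \
#       --devices=8 \
#       --num_tokens_to_generate=64
#
# The checkpoint path here is a placeholder, and whether a given TP/PP/EP
# combination is valid depends on the model and the total number of GPUs.
# When --ep > 1 and --etp is not set, the __main__ block above falls back to
# an expert tensor parallel size of 1.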