# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: This script is just an example of using NeMo checkpoints
# for generating outputs and is subject to change without notice.

from argparse import ArgumentParser

import torch
import torch.distributed

from megatron.core.inference.common_inference_params import CommonInferenceParams

import nemo.lightning as nl
from nemo.collections.llm import api

"""
torchrun --nproc-per-node=8 /opt/NeMo/scripts/llm/generate.py \
    --model_path=<PATH_TO_NEMO2_MODEL> \
    --tp=8 \
    --devices=8 \
    --num_tokens_to_generate=40 \
    --temperature=0.001 \
    --top_p=0.0 \
    --top_k=1 \
    --fp8
"""
def get_args():
"""
Parse the command line arguments.
"""
parser = ArgumentParser(description="""Run generation on a few sample prompts given the checkpoint path.""")
parser.add_argument(
"--prompts",
type=str,
nargs="+",
default=[
"Q: How are you?",
"Q: How big is the universe?",
"Q: How is the weather?",
"Q: How many stars are there?",
"Paris is know for its ",
"In a hot sunny day, you should ",
"Q: How many planets are in the solar system?",
"Q: How old are you?",
],
help="List of prompt strings",
)
parser.add_argument(
"--model_path",
type=str,
required=True,
help="""Path to NeMo 2 checkpoint""",
)
parser.add_argument(
"--tp",
type=int,
default=1,
help="""Tensor parallel size""",
)
parser.add_argument(
"--pp",
type=int,
default=1,
help="""Pipeline parallel size""",
)
parser.add_argument(
"--ep",
type=int,
default=1,
help="""Expert parallel size""",
)
parser.add_argument(
"--etp",
type=int,
default=None,
help="""Expert tensor parallel size""",
)
parser.add_argument(
"--devices",
type=int,
default=1,
help="""Number of GPUs to use on a single node""",
)
parser.add_argument(
"--nodes",
type=int,
default=1,
help="""Number of nodes to use""",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Temperature to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
)
parser.add_argument(
"--top_p",
type=float,
default=0.95,
help="""top_p to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
)
parser.add_argument(
"--top_k",
type=int,
default=0,
help="""top_k to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
)
parser.add_argument(
"--add_BOS",
action="store_true",
help="""Whether to add BOS token to the prompt""",
)
parser.add_argument(
"--num_tokens_to_generate",
type=int,
default=25,
help="""Number of tokens to generate per prompt""",
)
parser.add_argument(
"--fp8",
action="store_true",
help="""Whether to run inference in FP8 precision""",
)
    parser.add_argument(
        "--fp8_recipe",
        type=str,
        default="tensorwise",
        choices=["tensorwise", "delayed", "mxfp8"],
        help="""FP8 recipe; one of 'tensorwise', 'delayed', or 'mxfp8'""",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="""Maximum batch size for inference""",
    )
    parser.add_argument(
        "--random_seed",
        type=int,
        default=1234,
        help="""Random seed for generation""",
    )
    parser.add_argument(
        "--legacy_ckpt",
        action="store_true",
        help="""Load ckpt saved with TE < 1.14""",
    )
    parser.add_argument(
        "--disable_flash_decode",
        action="store_true",
        help="""Disable flash decode for models that do not support it""",
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = get_args()

    if args.fp8:
        assert len(args.prompts) % 8 == 0, "Number of prompts (the batch here) should be divisible by 8 for FP8 inference"
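        # Background note (an assumption, not stated in the original script): Transformer
        # Engine FP8 GEMM kernels generally require tensor dimensions that are multiples
        # of 8/16, which is likely why the number of prompts is constrained above.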
    if args.etp is None and args.ep > 1:
        # Unless ETP is explicitly given, disable ETP if using EP. Otherwise ETP = TP.
        args.etp = 1
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp,
        pipeline_model_parallel_size=args.pp,
        expert_model_parallel_size=args.ep,
        expert_tensor_parallel_size=args.etp,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )
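    # Sizing note (general Megatron-LM convention, added here as background): the
    # product tp * pp must evenly divide the total world size (devices * nodes);
    # any remaining ranks form the data-parallel dimension.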
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=args.devices,
        num_nodes=args.nodes,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            pipeline_dtype=torch.bfloat16,
            autocast_enabled=False,
            grad_reduce_in_fp32=False,
            fp8="hybrid" if args.fp8 else None,
            fp8_recipe=args.fp8_recipe if args.fp8 else None,
            fp8_amax_history_len=1,
            fp8_amax_compute_algo="max" if args.fp8 else "most_recent",
        ),
    )
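    # Background on the FP8 setting above: in Transformer Engine, the "hybrid" recipe
    # uses E4M3 for forward activations/weights and E5M2 for gradients; during
    # inference only the forward (E4M3) path is exercised.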
    # Load ckpt saved with TE < 1.14
    if args.legacy_ckpt:
        trainer.strategy.ckpt_load_strictness = False
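        # Interpretation (not stated in the original): non-strict loading tolerates
        # state-dict key differences between Transformer Engine versions.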
    prompts = args.prompts
    results = api.generate(
        path=args.model_path,
        prompts=prompts,
        trainer=trainer,
        add_BOS=args.add_BOS,
        inference_params=CommonInferenceParams(
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            num_tokens_to_generate=args.num_tokens_to_generate,
            return_log_probs=False,
            top_n_logprobs=0,
        ),
        text_only=True,
        max_batch_size=args.max_batch_size,
        random_seed=args.random_seed,
        enable_flash_decode=not args.disable_flash_decode,
    )
    # Print results on rank 0 only, so output is not duplicated across workers.
    if torch.distributed.get_rank() == 0:
        for i, r in enumerate(results):
            print(prompts[i])
            print("*" * 50)
            print(r)
            print("\n\n")