# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example:
  python scripts/vlm/mllama_generate.py --load_from_hf
"""

import argparse
from typing import Optional

import requests
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from PIL import Image
from transformers import AutoProcessor

from nemo import lightning as nl
from nemo.collections import vlm
from nemo.collections.vlm.inference import generate as vlm_generate
from nemo.collections.vlm.inference import setup_inference_wrapper

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def load_image(image_url: str) -> Optional[Image.Image]:
    # pylint: disable=C0115,C0116
    try:
        response = requests.get(image_url, stream=True)
        response.raise_for_status()
        image = Image.open(response.raw)
        return image
    except requests.exceptions.RequestException as e:
        print(f"Error loading image from {image_url}: {e}")
        return None


def generate(model, processor, images, text, params):
    # pylint: disable=C0115,C0116
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": text}],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # Wrap the model in the Megatron inference wrapper before generation.
    model = setup_inference_wrapper(model, processor.tokenizer)

    prompts = [input_text]
    images = [images]
    result = vlm_generate(
        model,
        processor.tokenizer,
        processor.image_processor,
        prompts,
        images,
        inference_params=params,
    )

    generated_texts = list(result)[0].generated_text
    if torch.distributed.get_rank() == 0:
        print("======== GENERATED TEXT OUTPUT ========")
        print(f"{generated_texts}")
        print("=======================================")
    return generated_texts


def main(args) -> None:
    # pylint: disable=C0115,C0116
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp_size,
        ckpt_load_optimizer=False,
        ckpt_save_optimizer=False,
    )
    trainer = nl.Trainer(
        devices=args.tp_size,
        max_steps=1000,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        val_check_interval=1000,
        limit_val_batches=50,
    )

    processor = AutoProcessor.from_pretrained(args.processor_name)
    tokenizer = processor.tokenizer
    fabric = trainer.to_fabric()

    # Either import the weights from the Hugging Face hub or load a local checkpoint.
    if args.load_from_hf:
        model = fabric.import_model(f"hf://{model_id}", vlm.MLlamaModel)
    else:
        model = vlm.MLlamaModel(vlm.MLlamaConfig11BInstruct(), tokenizer=tokenizer)
        model = fabric.load_model(args.local_model_path, model)

    # Load the images; exit if any image can't be loaded.
    raw_images = [load_image(url) for url in args.image_url]
    if not raw_images or any(image is None for image in raw_images):
        return

    params = CommonInferenceParams(
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        num_tokens_to_generate=args.num_tokens_to_generate,
    )

    generate(model, processor, images=raw_images, text=args.prompt, params=params)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--load_from_hf",
        action="store_true",
        help="Flag to indicate whether to load the model from Hugging Face hub.",
    )
    parser.add_argument(
        "--local_model_path",
        type=str,
        default=None,
        help="Local path to the model if not loading from Hugging Face.",
    )
    parser.add_argument(
        "--processor_name",
        type=str,
        default="meta-llama/Llama-3.2-11B-Vision-Instruct",
        help="Name or path of processor",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="<|image|>\nDescribe the image.",
        help="Input prompt",
    )
    parser.add_argument(
        "--image_url",
        nargs='+',
        type=str,
        # pylint: disable=line-too-long
        default=[
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
        ],
        help="List of the image urls to use for inference.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="""Temperature to be used in
        megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.0,
        help="""top_p to be used in
        megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=1,
        help="""top_k to be used in
        megatron.core.inference.common_inference_params.CommonInferenceParams""",
    )
    parser.add_argument(
        "--num_tokens_to_generate",
        type=int,
        default=50,
        help="""Number of tokens to generate per prompt""",
    )
    parser.add_argument("--devices", type=int, required=False, default=1)
    parser.add_argument("--tp_size", type=int, required=False, default=1)
    parser.add_argument("--pp_size", type=int, required=False, default=1)
    parser.add_argument("--encoder_pp_size", type=int, required=False, default=0)

    args = parser.parse_args()
    main(args)
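
# Usage sketch: a few example invocations of this script. The checkpoint path and prompt
# below are hypothetical placeholders, not paths or defaults shipped with NeMo.
#
#   python scripts/vlm/mllama_generate.py --load_from_hf
#   python scripts/vlm/mllama_generate.py \
#       --local_model_path /path/to/imported_mllama_checkpoint \
#       --prompt "<|image|>\nWhat is shown in this picture?" \
#       --num_tokens_to_generate 100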