# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example:
  python scripts/vlm/gemma3vl_generate.py --local_model_path="path/to/converted_nemo_checkpoint"
"""

import argparse
from pathlib import Path

import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

import nemo.lightning as nl
from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.vlm import Gemma3VLModel
from nemo.collections.vlm.inference.base import _setup_trainer_and_restore_model
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.utils.get_rank import get_last_rank


class SingleBatchIterator:
    def __init__(self, pixel_values, input_ids, position_ids):
        self.batch = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            position_ids=position_ids,
        )
        self._yielded = False

    def __iter__(self):
        return self

    def __next__(self):
        if self._yielded:
            raise StopIteration
        self._yielded = True
        return self.batch


def gemma3_forward_step(data_iterator, model, **kwargs) -> torch.Tensor:
    batch = next(data_iterator)
    forward_args = {
        "input_ids": batch["input_ids"],
        "position_ids": batch["position_ids"],
        "pixel_values": batch.get("pixel_values", None),
        "loss_mask": batch.get("loss_mask", None),
        "labels": batch.get("labels", None),
    }

    def loss_func(x, **kwargs):
        return x

    return model(**forward_args), loss_func


def main(args) -> None:
    # pylint: disable=C0115,C0116,C0301
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp,
        pipeline_model_parallel_size=args.pp,
        sequence_parallel=args.tp > 1,
        ckpt_include_optimizer=False,
        ckpt_load_strictness="log_all",
        pipeline_dtype=torch.bfloat16,
    )
    trainer = nl.Trainer(
        devices=min(args.tp * args.pp, 8),
        num_nodes=max(args.tp * args.pp // 8, 1),
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        enable_checkpointing=False,
    )

    if args.local_model_path:
        path = Path(args.local_model_path)
        model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
        _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
    else:
        fabric = trainer.to_fabric()
        model = fabric.import_model("hf://google/gemma-3-4b-it", Gemma3VLModel)
        model = model.module.cuda()
    model.eval()

    from transformers import AutoProcessor

    model_id = 'google/gemma-3-4b-it'
    processor = AutoProcessor.from_pretrained(model_id)
    gemma_tokenizer = AutoTokenizer(model_id)
    hf_tokenizer = gemma_tokenizer.tokenizer

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG",
                },
                {"type": "text", "text": "What animal is on the candy?"},
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
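    # Note (added for clarity, not part of the original script): `inputs` is the dict-like
    # output of the Hugging Face processor; for Gemma 3 it typically contains "input_ids",
    # "attention_mask", and "pixel_values" (exact keys can vary with the installed
    # transformers version), and the indexing below assumes those keys are present.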
    input_ids = inputs["input_ids"].cuda()
    # add additional dim to (B, N, C, H, W)
    pixel_values = inputs["pixel_values"].cuda().unsqueeze(0).to(dtype=torch.bfloat16)
    position_ids = (
        torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
    )

    generated_ids = input_ids.clone()
    stop_tokens = [1, 106]  # <eos> and <end_of_turn> in the Gemma 3 tokenizer

    # Greedy generation loop
    for step in range(20):
        with torch.no_grad():
            if torch.distributed.get_rank() == 0:
                print(step)
            fwd_bwd_function = get_forward_backward_func()
            iterator = SingleBatchIterator(pixel_values, input_ids, position_ids)

            output = fwd_bwd_function(
                forward_step_func=gemma3_forward_step,
                data_iterator=iterator,
                model=model,
                num_microbatches=1,
                forward_only=True,
                seq_length=input_ids.size(1),
                micro_batch_size=1,
                collect_non_loss_data=True,
            )
            if isinstance(output, list) and len(output) > 0:
                output = output[0]

            if parallel_state.is_pipeline_last_stage():
                world_size = parallel_state.get_tensor_model_parallel_world_size()
                gathered_tensors = [torch.zeros_like(output) for _ in range(world_size)]
                # All-gather operation
                dist.all_gather(gathered_tensors, output, group=parallel_state.get_tensor_model_parallel_group())

                # Concatenate along last dimension (dim=2)
                output = torch.cat(gathered_tensors, dim=2)

                next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True)
            else:
                next_token_ids = torch.ones((1, 1), device=generated_ids.device, dtype=generated_ids.dtype)

            torch.distributed.broadcast(next_token_ids, get_last_rank())
            generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)

            input_ids = generated_ids
            position_ids = (
                torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .expand_as(input_ids)
            )

            # If the generated token is the end of sequence token, stop generating
            if next_token_ids.item() in stop_tokens:
                break

    generated_texts = hf_tokenizer.decode(list(generated_ids[0]))
    if torch.distributed.get_rank() == 0:
        print("======== GENERATED TEXT OUTPUT ========")
        print(f"{generated_texts}")
        print("=======================================")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Gemma3 Multimodal Inference")
    parser.add_argument(
        "--local_model_path",
        type=str,
        default=None,
        help="Local path to the model if not loading from Hugging Face.",
    )
    # type=int so CLI values compare and multiply correctly (e.g. args.tp > 1, args.tp * args.pp)
    parser.add_argument('--tp', type=int, default=1)
    parser.add_argument('--pp', type=int, default=1)

    args = parser.parse_args()
    main(args)
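
# Multi-GPU launch (illustrative sketch, not part of the original script): with a converted
# NeMo checkpoint, tensor parallelism can typically be used by launching one process per GPU
# (e.g. via torchrun) and passing --tp accordingly; adjust the path and GPU count to your setup:
#   torchrun --nproc-per-node=2 scripts/vlm/gemma3vl_generate.py \
#       --local_model_path="path/to/converted_nemo_checkpoint" --tp 2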