MagpieTTS_Internal_Demo / scripts /vlm /gemma3vl_import.py
subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma3VL checkpoint import."""
import argparse
from nemo.collections import llm, vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
# Supported models: maps the name accepted by the --model option to a tuple of
# (Hugging Face model id, language-model config class from nemo.collections.llm).
MODEL_DICT = {
"gemma3_vl_4b_it": ("google/gemma-3-4b-it", llm.Gemma3Config4B),
"gemma3_vl_27b_it": ("google/gemma-3-27b-it", llm.Gemma3Config27B),
}
def main(args: argparse.Namespace):
    """Build a Gemma3VL model and import its weights from Hugging Face.

    Args:
        args: Parsed command-line arguments. ``args.model`` selects an entry
            of ``MODEL_DICT``; ``args.overwrite`` is forwarded unchanged to
            ``llm.import_ckpt``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``MODEL_DICT``.
    """
    hf_repo_id, llm_config_cls = MODEL_DICT[args.model]
    hf_tokenizer = AutoTokenizer(hf_repo_id)

    # Language model configuration.
    llm_config = llm_config_cls()
    # The default cross_entropy_fusion_impl is `te`, which will not calculate
    # loss properly with label < 0.
    llm_config.cross_entropy_fusion_impl = "native"

    # Vision tower, plus the projector that bridges the vision hidden size to
    # the language model hidden size.
    vision_config = vlm.Gemma3VLVisionConfig()
    projector_config = vlm.Gemma3VLMultimodalProjectorConfig(
        input_size=vision_config.hidden_size,
        hidden_size=llm_config.hidden_size,
    )

    # Only the language model is left unfrozen.
    freeze_flags = {
        "freeze_language_model": False,
        "freeze_vision_model": True,
        "freeze_vision_projection": True,
    }
    model_config = vlm.Gemma3VLConfig(
        language_transformer_config=llm_config,
        vision_transformer_config=vision_config,
        vision_projection_config=projector_config,
        **freeze_flags,
    )

    gemma3vl_model = vlm.Gemma3VLModel(model_config, tokenizer=hf_tokenizer)
    checkpoint_source = f"hf://{hf_repo_id}"
    llm.import_ckpt(
        model=gemma3vl_model,
        source=checkpoint_source,
        overwrite=args.overwrite,
    )
def str2bool(value: str) -> bool:
    """Convert a command-line string into a boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Args:
        value: Raw option value. Matching is case-insensitive and ignores
            surrounding whitespace.

    Returns:
        ``True`` for ``1/true/t/yes/y/on``; ``False`` for
        ``0/false/f/no/n/off`` and for the empty string (which ``bool("")``
        also treated as false).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the checkpoint import script.

    Returns:
        A parser exposing ``--model`` (default ``gemma3_vl_4b_it``) and
        ``--overwrite`` (default ``False``).
    """
    parser = argparse.ArgumentParser(description="Gemma3VL checkpoint import.")
    parser.add_argument("--model", type=str, required=False, default="gemma3_vl_4b_it")
    # nargs="?" + const=True keeps `--overwrite True` working and also accepts
    # a bare `--overwrite`; str2bool makes `--overwrite False` actually false.
    parser.add_argument(
        "--overwrite",
        type=str2bool,
        nargs="?",
        const=True,
        required=False,
        default=False,
    )
    return parser


if __name__ == "__main__":
    parsed_args = build_parser().parse_args()
    main(parsed_args)