Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Gemma3VL checkpoint import.""" | |
| import argparse | |
| from nemo.collections import llm, vlm | |
| from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer | |
| MODEL_DICT = { | |
| "gemma3_vl_4b_it": ("google/gemma-3-4b-it", llm.Gemma3Config4B), | |
| "gemma3_vl_27b_it": ("google/gemma-3-27b-it", llm.Gemma3Config27B), | |
| } | |
| def main(args: argparse.Namespace): | |
| hf_model_name, language_config_class = MODEL_DICT[args.model] | |
| tokenizer = AutoTokenizer(hf_model_name) | |
| language_transformer_config = language_config_class() | |
| # The default cross_entropy_fusion_impl is `te`, which will not calculate | |
| # loss properly with label < 0. | |
| language_transformer_config.cross_entropy_fusion_impl = "native" | |
| vision_transformer_config = vlm.Gemma3VLVisionConfig() | |
| vision_projection_config = vlm.Gemma3VLMultimodalProjectorConfig( | |
| input_size=vision_transformer_config.hidden_size, | |
| hidden_size=language_transformer_config.hidden_size, | |
| ) | |
| gemma3vl_config = vlm.Gemma3VLConfig( | |
| language_transformer_config=language_transformer_config, | |
| vision_transformer_config=vision_transformer_config, | |
| vision_projection_config=vision_projection_config, | |
| freeze_language_model=False, | |
| freeze_vision_model=True, | |
| freeze_vision_projection=True, | |
| ) | |
| model = vlm.Gemma3VLModel(gemma3vl_config, tokenizer=tokenizer) | |
| llm.import_ckpt(model=model, source=f"hf://{hf_model_name}", overwrite=args.overwrite) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Gemma3VL checkpoint import.") | |
| parser.add_argument("--model", type=str, required=False, default="gemma3_vl_4b_it") | |
| parser.add_argument("--overwrite", type=bool, required=False, default=False) | |
| parsed_args = parser.parse_args() | |
| main(parsed_args) | |