# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import AutoProcessor, MllamaConfig
from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig

from nemo import lightning as nl
from nemo.collections import vlm


def split_qkv_weight(qkv_weight, model_config):
    """Split a fused NeMo attention QKV weight into separate HF q/k/v weights.

    The fused rows are read as ``num_query_groups`` consecutive groups, each holding
    ``heads_per_group`` query heads followed by one key head and one value head; every
    head is ``head_size`` rows of width ``hidden_size``.

    Args:
        qkv_weight: Fused NeMo ``linear_qkv`` weight. Its element count must equal
            ``num_query_groups * (heads_per_group + 2) * head_size * hidden_size``.
        model_config: NeMo transformer config providing ``hidden_size`` and
            ``num_attention_heads``, and optionally ``num_query_groups`` / ``kv_channels``
            (falsy values fall back to MHA and ``hidden_size // num_attention_heads``).

    Returns:
        ``[("q_proj", q), ("k_proj", k), ("v_proj", v)]`` where ``q`` has shape
        ``(num_attention_heads, head_size, hidden_size)`` and ``k``/``v`` have shape
        ``(num_query_groups, head_size, hidden_size)``. All three keep the dtype and
        device of ``qkv_weight`` and own their storage.
    """
    hidden_size = model_config.hidden_size
    head_num = model_config.num_attention_heads
    num_query_groups = model_config.num_query_groups or head_num
    head_size = model_config.kv_channels or (hidden_size // head_num)
    heads_per_group = head_num // num_query_groups

    # View the fused rows as (group, [q_0 .. q_{n-1}, k, v], head_size, hidden_size).
    grouped = qkv_weight.reshape(num_query_groups, heads_per_group + 2, head_size, hidden_size)
    # Clone (instead of filling torch.empty buffers) so the outputs keep the source dtype,
    # e.g. bf16, rather than the global default dtype, and do not alias each other.
    q_weight = grouped[:, :heads_per_group].clone(memory_format=torch.contiguous_format)
    q_weight = q_weight.reshape(head_num, head_size, hidden_size)
    k_weight = grouped[:, heads_per_group].clone(memory_format=torch.contiguous_format)
    v_weight = grouped[:, heads_per_group + 1].clone(memory_format=torch.contiguous_format)
    return [('q_proj', q_weight), ('k_proj', k_weight), ('v_proj', v_weight)]


def split_kv_weight(kv_weight, model_config):
    """Split a fused NeMo cross-attention KV weight into separate HF k/v weights.

    The fused rows are read as ``num_query_groups`` consecutive groups of one key head
    followed by one value head, each ``head_size`` rows of width ``hidden_size``.

    Args:
        kv_weight: Fused NeMo ``linear_kv`` weight.
        model_config: NeMo transformer config (same fields as :func:`split_qkv_weight`).

    Returns:
        ``[("k_proj", k), ("v_proj", v)]``, each of shape
        ``(num_query_groups, head_size, hidden_size)``, keeping the dtype and device of
        ``kv_weight`` and owning their storage.
    """
    hidden_size = model_config.hidden_size
    head_num = model_config.num_attention_heads
    num_query_groups = model_config.num_query_groups or head_num
    head_size = model_config.kv_channels or (hidden_size // head_num)

    # View the fused rows as (group, [k, v], head_size, hidden_size).
    grouped = kv_weight.reshape(num_query_groups, 2, head_size, hidden_size)
    k_weight = grouped[:, 0].clone(memory_format=torch.contiguous_format)
    v_weight = grouped[:, 1].clone(memory_format=torch.contiguous_format)
    return [('k_proj', k_weight), ('v_proj', v_weight)]


def split_gate_weight(gate_weight):
    """Split a fused NeMo ``linear_fc1`` (gated MLP) weight into HF gate/up projections.

    The first half of the rows (dim 0) becomes ``gate_proj`` and the second half
    ``up_proj``. ``torch.chunk`` returns views, so no data is copied.
    """
    gate_proj, up_proj = torch.chunk(gate_weight, 2, dim=0)
    return [('gate_proj', gate_proj), ('up_proj', up_proj)]


def convert_mllama_config(source_vision, source_text):
    """Build an HF ``MllamaConfig`` from the NeMo vision and text model configs.

    The HF text decoder has ``num_layers + num_cross_attention_layers`` layers; the j-th
    entry ``x`` of the NeMo fusion schedule becomes the HF cross-attention layer ``x + j``.

    NOTE(review): ``rope_scaling``, ``eos_token_id`` and ``torch_dtype`` are hard-coded
    below rather than read from the NeMo configs.
    """
    vision_config = MllamaVisionConfig(
        num_hidden_layers=source_vision.num_layers,
        hidden_size=source_vision.hidden_size,
        attention_heads=source_vision.num_attention_heads,
        image_size=source_vision.vision_chunk_size,
        max_num_tiles=source_vision.vision_max_num_chunks,
        torch_dtype="bfloat16",
    )

    fusion_schedule = source_text._init_fusion_schedule(source_text.num_cross_attention_layers)
    # Each cross-attention layer is inserted in front of its scheduled self-attention layer,
    # so every earlier cross-attention layer shifts it one HF index further.
    cross_attention_layers = [x + i for i, x in enumerate(fusion_schedule)]
    text_config = MllamaTextConfig(
        rope_theta=source_text.rotary_base,
        num_hidden_layers=source_text.num_layers + source_text.num_cross_attention_layers,
        cross_attention_layers=cross_attention_layers,
        hidden_size=source_text.hidden_size,
        intermediate_size=source_text.ffn_hidden_size,
        num_attention_heads=source_text.num_attention_heads,
        num_key_value_heads=source_text.num_query_groups,
        vocab_size=source_text.vocab_size,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        eos_token_id=[128001, 128008, 128009],
        torch_dtype="bfloat16",
    )
    return MllamaConfig(vision_config, text_config, torch_dtype="bfloat16")


# (HF suffix, NeMo suffix) pairs copied as-is for every local vision encoder layer.
_VISION_LAYER_KEYS = (
    ("self_attn.o_proj.weight", "self_attention.linear_proj.weight"),
    ("input_layernorm.bias", "input_layernorm.bias"),
    ("input_layernorm.weight", "input_layernorm.weight"),
    ("post_attention_layernorm.bias", "pre_mlp_layernorm.bias"),
    ("post_attention_layernorm.weight", "pre_mlp_layernorm.weight"),
    ("mlp.fc1.bias", "mlp.linear_fc1.bias"),
    ("mlp.fc1.weight", "mlp.linear_fc1.weight"),
    ("mlp.fc2.bias", "mlp.linear_fc2.bias"),
    ("mlp.fc2.weight", "mlp.linear_fc2.weight"),
)

# Global vision encoder layers carry the same weights plus ``gate_attn`` / ``gate_ffn``.
_VISION_GLOBAL_LAYER_KEYS = (
    _VISION_LAYER_KEYS[:1] + (("gate_attn", "gate_attn"), ("gate_ffn", "gate_ffn")) + _VISION_LAYER_KEYS[1:]
)

# (HF suffix, NeMo suffix) pairs copied as-is for every text cross-attention layer.
_TEXT_XATTN_LAYER_KEYS = (
    ("cross_attn.o_proj.weight", "cross_attention.linear_proj.weight"),
    ("cross_attn.q_proj.weight", "cross_attention.linear_q.weight"),
    ("cross_attn.k_norm.weight", "cross_attention.k_layernorm.weight"),
    ("input_layernorm.weight", "cross_attention.linear_q.layer_norm_weight"),
    ("cross_attn.q_norm.weight", "cross_attention.q_layernorm.weight"),
    ("post_attention_layernorm.weight", "mlp.linear_fc1.layer_norm_weight"),
    ("mlp.down_proj.weight", "mlp.linear_fc2.weight"),
    ("cross_attn_attn_gate", "gate_attn"),
    ("cross_attn_mlp_gate", "gate_ffn"),
)

# (HF suffix, NeMo suffix) pairs copied as-is for every text self-attention layer.
_TEXT_SELF_LAYER_KEYS = (
    ("self_attn.o_proj.weight", "self_attention.linear_proj.weight"),
    ("post_attention_layernorm.weight", "mlp.linear_fc1.layer_norm_weight"),
    ("mlp.down_proj.weight", "mlp.linear_fc2.weight"),
    ("input_layernorm.weight", "self_attention.linear_qkv.layer_norm_weight"),
)


def _build_text_layer_map(language_model_config):
    """Map every HF text-decoder layer index to the NeMo decoder module it comes from.

    HF keeps self- and cross-attention layers in one ``language_model.model.layers`` list,
    while NeMo keeps ``decoder.layers`` and ``decoder.xattn_layers`` apart. The HF
    cross-attention positions use the same fusion-schedule formula as
    :func:`convert_mllama_config`, so converted weights and converted config agree.

    NOTE(review): assumes ``decoder.xattn_layers`` is indexed by fusion-schedule value;
    this reproduces the previous fixed ``3 + k * frequency`` layout for the 32/8 layer setup.

    Returns:
        A list with one ``(is_cross_attention, nemo_layer_prefix)`` tuple per HF layer.

    Raises:
        ValueError: if the fusion schedule does not contain exactly
            ``num_cross_attention_layers`` entries.
    """
    num_cross = language_model_config.num_cross_attention_layers
    fusion_schedule = list(language_model_config._init_fusion_schedule(num_cross))
    if len(fusion_schedule) != num_cross:
        raise ValueError(
            f"Fusion schedule has {len(fusion_schedule)} entries but num_cross_attention_layers is {num_cross}"
        )

    # The j-th cross-attention layer sits right before NeMo self-attention layer x = schedule[j],
    # i.e. at HF index x + j.
    hf_cross_to_nemo = {x + j: x for j, x in enumerate(fusion_schedule)}
    layer_map = []
    next_self_layer = 0
    for hf_index in range(language_model_config.num_layers + num_cross):
        if hf_index in hf_cross_to_nemo:
            layer_map.append((True, f"language_model.decoder.xattn_layers.{hf_cross_to_nemo[hf_index]}"))
        else:
            layer_map.append((False, f"language_model.decoder.layers.{next_self_layer}"))
            next_self_layer += 1
    return layer_map


def _build_key_map(vision_model_config, text_layer_map):
    """Return ``(hf_key, nemo_key)`` pairs for every weight copied without reshaping."""
    v = "vision_model.vision_encoder"
    key_map = [
        ("vision_model.class_embedding", f"{v}.class_embedding"),
        ("vision_model.gated_positional_embedding.embedding", f"{v}.positional_embedding"),
        (
            "vision_model.gated_positional_embedding.tile_embedding.weight",
            f"{v}.gated_tile_positional_embedding.weight",
        ),
        ("vision_model.gated_positional_embedding.gate", f"{v}.gated_positional_embedding_gate"),
        ("vision_model.layernorm_post.bias", f"{v}.ln_post.bias"),
        ("vision_model.layernorm_post.weight", f"{v}.ln_post.weight"),
        ("vision_model.layernorm_pre.bias", f"{v}.ln_pre.bias"),
        ("vision_model.layernorm_pre.weight", f"{v}.ln_pre.weight"),
        ("vision_model.post_tile_positional_embedding.embedding.weight", f"{v}.post_tile_pos_embed.embedding.weight"),
        ("vision_model.post_tile_positional_embedding.gate", f"{v}.post_tile_pos_embed.gate"),
        ("vision_model.pre_tile_positional_embedding.embedding.weight", f"{v}.pre_tile_pos_embed.embedding.weight"),
        ("vision_model.pre_tile_positional_embedding.gate", f"{v}.pre_tile_pos_embed.gate"),
        ("multi_modal_projector.bias", "vision_model.vision_projection.encoder.bias"),
        ("multi_modal_projector.weight", "vision_model.vision_projection.encoder.weight"),
        ("language_model.model.norm.weight", "language_model.decoder.final_layernorm.weight"),
        ("language_model.lm_head.weight", "language_model.output_layer.weight"),
    ]
    for i in range(vision_model_config.num_layers):
        key_map.extend(
            (f"vision_model.transformer.layers.{i}.{hf}", f"{v}.transformer.layers.{i}.{nemo}")
            for hf, nemo in _VISION_LAYER_KEYS
        )
    for i in range(vision_model_config.num_global_layers):
        key_map.extend(
            (f"vision_model.global_transformer.layers.{i}.{hf}", f"{v}.global_transformer.layers.{i}.{nemo}")
            for hf, nemo in _VISION_GLOBAL_LAYER_KEYS
        )
    for i, (is_cross_attention, nemo_layer) in enumerate(text_layer_map):
        suffixes = _TEXT_XATTN_LAYER_KEYS if is_cross_attention else _TEXT_SELF_LAYER_KEYS
        key_map.extend((f"language_model.model.layers.{i}.{hf}", f"{nemo_layer}.{nemo}") for hf, nemo in suffixes)
    return key_map


def _convert_vision_qkv_weight(state_dict, vision_model_config):
    """Split the fused QKV weight of every local and global vision encoder layer."""
    hidden_size = vision_model_config.hidden_size
    new_state_dict = {}
    stacks = (
        ("transformer", vision_model_config.num_layers),
        ("global_transformer", vision_model_config.num_global_layers),
    )
    for stack, num_layers in stacks:
        for i in range(num_layers):
            qkv_weight = state_dict[f"vision_model.vision_encoder.{stack}.layers.{i}.self_attention.linear_qkv.weight"]
            for name, weight in split_qkv_weight(qkv_weight, vision_model_config):
                new_key = f"vision_model.{stack}.layers.{i}.self_attn.{name}.weight"
                new_state_dict[new_key] = weight.reshape(-1, hidden_size)
    return new_state_dict


def _convert_patch_embedding(state_dict):
    """Reshape the NeMo linear patch projection into the HF conv patch-embedding weight.

    The 3 channels / 14x14 patch are hard-coded; :func:`convert_mllama_config` does not set
    ``num_channels`` / ``patch_size``, so these are assumed to match the HF defaults.
    """
    conv1_weight = state_dict["vision_model.vision_encoder.conv1._linear.weight"]
    return {"vision_model.patch_embedding.weight": conv1_weight.reshape(conv1_weight.shape[0], 3, 14, 14)}


def _convert_language_qkv_weight(state_dict, language_model_config, text_layer_map):
    """Split fused self-attention QKV and cross-attention KV weights of the text decoder."""
    hidden_size = language_model_config.hidden_size
    new_state_dict = {}
    for i, (is_cross_attention, nemo_layer) in enumerate(text_layer_map):
        if is_cross_attention:
            fused = state_dict[f"{nemo_layer}.cross_attention.linear_kv.weight"]
            attn_name, pieces = "cross_attn", split_kv_weight(fused, language_model_config)
        else:
            fused = state_dict[f"{nemo_layer}.self_attention.linear_qkv.weight"]
            attn_name, pieces = "self_attn", split_qkv_weight(fused, language_model_config)
        for name, weight in pieces:
            new_key = f"language_model.model.layers.{i}.{attn_name}.{name}.weight"
            new_state_dict[new_key] = weight.reshape(-1, hidden_size)
    return new_state_dict


def _convert_gate(state_dict, text_layer_map):
    """Split every text layer's fused ``linear_fc1`` into HF ``gate_proj`` / ``up_proj``."""
    new_state_dict = {}
    for i, (_, nemo_layer) in enumerate(text_layer_map):
        fc1_weight = state_dict[f"{nemo_layer}.mlp.linear_fc1.weight"]
        for name, weight in split_gate_weight(fc1_weight):
            new_state_dict[f"language_model.model.layers.{i}.mlp.{name}.weight"] = weight
    return new_state_dict


def _convert_embedding(state_dict):
    """Concatenate NeMo word embeddings and learnable embeddings (dim 0) into HF ``embed_tokens``."""
    word_embeddings = state_dict["language_model.embedding.word_embeddings.weight"]
    learnable_embedding = state_dict["language_model.learnable_embedding.weight"]
    return {"language_model.model.embed_tokens.weight": torch.cat((word_embeddings, learnable_embedding), dim=0)}


def convert_mllama_nemo_to_hf(checkpoint_path, processor_name, model_config=None):
    """Convert a NeMo MLlama checkpoint into an HF Mllama state dict and config.

    Args:
        checkpoint_path: NeMo checkpoint path, loaded with ``fabric.load_model``.
        processor_name: Name or path passed to ``AutoProcessor.from_pretrained``; only its
            tokenizer is used, to construct the NeMo model.
        model_config: NeMo ``vlm`` MLlama config describing the checkpoint architecture.
            Defaults to ``vlm.MLlamaConfig11BInstruct()``.

    Returns:
        ``(state_dict, config)``: a dict keyed by HF Mllama parameter names and the
        ``MllamaConfig`` produced by :func:`convert_mllama_config`.

    Raises:
        ValueError: if the text model's fusion schedule is inconsistent with
            ``num_cross_attention_layers``.
        KeyError: if an expected NeMo weight is missing from the checkpoint.
    """
    processor = AutoProcessor.from_pretrained(processor_name)
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        ckpt_load_optimizer=False,
        ckpt_save_optimizer=False,
    )
    trainer = nl.Trainer(
        devices=1,
        max_steps=1000,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        val_check_interval=1000,
        limit_val_batches=50,
    )
    fabric = trainer.to_fabric()

    if model_config is None:
        model_config = vlm.MLlamaConfig11BInstruct()
    model = vlm.MLlamaModel(model_config, tokenizer=processor.tokenizer)
    config = model.config
    vision_model_config = config.vision_model_config
    language_model_config = config.language_model_config

    model = fabric.load_model(checkpoint_path, model)
    # Unwrap the modules placed around the MLlama model so state_dict keys start at
    # "vision_model." / "language_model." (four levels for this strategy/precision setup;
    # TODO(review): confirm if the strategy or plugins change).
    model = model.module.module.module.module
    state_dict = model.state_dict()
    del model

    text_layer_map = _build_text_layer_map(language_model_config)

    new_state_dict = {hf_key: state_dict[nemo_key] for hf_key, nemo_key in _build_key_map(vision_model_config, text_layer_map)}
    new_state_dict.update(_convert_vision_qkv_weight(state_dict, vision_model_config))
    new_state_dict.update(_convert_patch_embedding(state_dict))
    new_state_dict.update(_convert_language_qkv_weight(state_dict, language_model_config, text_layer_map))
    new_state_dict.update(_convert_gate(state_dict, text_layer_map))
    new_state_dict.update(_convert_embedding(state_dict))

    return new_state_dict, convert_mllama_config(vision_model_config, language_model_config)