# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import AutoProcessor, MllamaConfig
from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig
from nemo import lightning as nl
from nemo.collections import vlm
def split_qkv_weight(qkv_weight, model_config):
"""Split attention qkv from nemo to hf format"""
hidden_size = model_config.hidden_size
head_num = model_config.num_attention_heads
num_query_groups = model_config.num_query_groups or head_num
head_size = model_config.kv_channels or (hidden_size // head_num)
heads_per_group = head_num // num_query_groups
qkv_weight = qkv_weight.reshape(-1, head_size, hidden_size)
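    # Megatron's fused linear_qkv weight is grouped by query group: each group holds
    # heads_per_group query heads followed by one key head and one value head.
    # The loop below walks that layout and copies the slices into separate q/k/v tensors in HF order.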
    q_weight = torch.empty((head_num, head_size, hidden_size), dtype=qkv_weight.dtype, device=qkv_weight.device)
    k_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype, device=qkv_weight.device)
    v_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype, device=qkv_weight.device)
qkv_index = 0
for i in range(num_query_groups):
q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[
qkv_index : qkv_index + heads_per_group, :, :
]
qkv_index += heads_per_group
k_weight[i, :, :] = qkv_weight[qkv_index, :, :]
qkv_index += 1
v_weight[i, :, :] = qkv_weight[qkv_index, :, :]
qkv_index += 1
return [('q_proj', q_weight), ('k_proj', k_weight), ('v_proj', v_weight)]
def split_kv_weight(kv_weight, model_config):
"""Split cross attention qkv from nemo to hf format"""
hidden_size = model_config.hidden_size
head_num = model_config.num_attention_heads
num_query_groups = model_config.num_query_groups or head_num
head_size = model_config.kv_channels or (hidden_size // head_num)
kv_weight = kv_weight.reshape(-1, head_size, hidden_size)
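    # The fused linear_kv weight alternates one key head and one value head per query group;
    # unpack them into separate k/v tensors.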
    k_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=kv_weight.dtype, device=kv_weight.device)
    v_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=kv_weight.dtype, device=kv_weight.device)
kv_index = 0
for i in range(num_query_groups):
k_weight[i, :, :] = kv_weight[kv_index, :, :]
kv_index += 1
v_weight[i, :, :] = kv_weight[kv_index, :, :]
kv_index += 1
return [('k_proj', k_weight), ('v_proj', v_weight)]
def split_gate_weight(gate_weight):
    """Split NeMo's fused MLP linear_fc1 weight into HF gate_proj and up_proj."""
    # linear_fc1 stacks the SwiGLU gate and up projections along dim 0; chunk them apart.
    gate_weight = torch.chunk(gate_weight, 2, dim=0)
    return [('gate_proj', gate_weight[0]), ('up_proj', gate_weight[1])]
def convert_mllama_config(source_vision, source_text):
"""Convert nemo mllama config to hf config"""
vision_config = MllamaVisionConfig(
num_hidden_layers=source_vision.num_layers,
hidden_size=source_vision.hidden_size,
attention_heads=source_vision.num_attention_heads,
image_size=source_vision.vision_chunk_size,
max_num_tiles=source_vision.vision_max_num_chunks,
torch_dtype="bfloat16",
)
cross_attention_layers = [
x + i for i, x in enumerate(source_text._init_fusion_schedule(source_text.num_cross_attention_layers))
]
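    # NeMo's fusion schedule indexes cross-attention positions among the self-attention
    # layers only; HF numbers every decoder layer (self- and cross-attention) in one list,
    # so each index is shifted by the number of cross-attention layers inserted before it.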
text_config = MllamaTextConfig(
rope_theta=source_text.rotary_base,
num_hidden_layers=source_text.num_layers + source_text.num_cross_attention_layers,
cross_attention_layers=cross_attention_layers,
hidden_size=source_text.hidden_size,
intermediate_size=source_text.ffn_hidden_size,
num_attention_heads=source_text.num_attention_heads,
num_key_value_heads=source_text.num_query_groups,
vocab_size=source_text.vocab_size,
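        # rope_scaling and eos_token_id are hard-coded to the Llama 3.2 defaults rather
        # than read from the NeMo config.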
rope_scaling={
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
eos_token_id=[128001, 128008, 128009],
torch_dtype="bfloat16",
)
return MllamaConfig(vision_config, text_config, torch_dtype="bfloat16")
def convert_mllama_nemo_to_hf(checkpoint_path, processor_name):
"""Convert nemo mllama to hf state dict and config"""
processor = AutoProcessor.from_pretrained(processor_name)
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
ckpt_load_optimizer=False,
ckpt_save_optimizer=False,
)
trainer = nl.Trainer(
devices=1,
max_steps=1000,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
val_check_interval=1000,
limit_val_batches=50,
)
fabric = trainer.to_fabric()
tokenizer = processor.tokenizer
model = vlm.MLlamaModel(vlm.MLlamaConfig11BInstruct(), tokenizer=tokenizer)
config = model.config
vision_model_config = config.vision_model_config
language_model_config = config.language_model_config
model = fabric.load_model(checkpoint_path, model)
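    # Unwrap the nested Fabric/Megatron module wrappers to reach the underlying MLlamaModel.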
model = model.module.module.module.module
state_dict = model.state_dict()
del model
v = "vision_model.vision_encoder"
key_map = [
("vision_model.class_embedding", f"{v}.class_embedding"),
("vision_model.gated_positional_embedding.embedding", f"{v}.positional_embedding"),
(
"vision_model.gated_positional_embedding.tile_embedding.weight",
f"{v}.gated_tile_positional_embedding.weight",
),
("vision_model.gated_positional_embedding.gate", f"{v}.gated_positional_embedding_gate"),
("vision_model.layernorm_post.bias", f"{v}.ln_post.bias"),
("vision_model.layernorm_post.weight", f"{v}.ln_post.weight"),
("vision_model.layernorm_pre.bias", f"{v}.ln_pre.bias"),
("vision_model.layernorm_pre.weight", f"{v}.ln_pre.weight"),
("vision_model.post_tile_positional_embedding.embedding.weight", f"{v}.post_tile_pos_embed.embedding.weight"),
("vision_model.post_tile_positional_embedding.gate", f"{v}.post_tile_pos_embed.gate"),
("vision_model.pre_tile_positional_embedding.embedding.weight", f"{v}.pre_tile_pos_embed.embedding.weight"),
("vision_model.pre_tile_positional_embedding.gate", f"{v}.pre_tile_pos_embed.gate"),
("multi_modal_projector.bias", "vision_model.vision_projection.encoder.bias"),
("multi_modal_projector.weight", "vision_model.vision_projection.encoder.weight"),
("language_model.model.norm.weight", "language_model.decoder.final_layernorm.weight"),
("language_model.lm_head.weight", "language_model.output_layer.weight"),
]
for i in range(vision_model_config.num_layers):
key_map.extend(
[
(
f"vision_model.transformer.layers.{i}.self_attn.o_proj.weight",
f"{v}.transformer.layers.{i}.self_attention.linear_proj.weight",
),
(
f"vision_model.transformer.layers.{i}.input_layernorm.bias",
f"{v}.transformer.layers.{i}.input_layernorm.bias",
),
(
f"vision_model.transformer.layers.{i}.input_layernorm.weight",
f"{v}.transformer.layers.{i}.input_layernorm.weight",
),
(
f"vision_model.transformer.layers.{i}.post_attention_layernorm.bias",
f"{v}.transformer.layers.{i}.pre_mlp_layernorm.bias",
),
(
f"vision_model.transformer.layers.{i}.post_attention_layernorm.weight",
f"{v}.transformer.layers.{i}.pre_mlp_layernorm.weight",
),
(
f"vision_model.transformer.layers.{i}.mlp.fc1.bias",
f"{v}.transformer.layers.{i}.mlp.linear_fc1.bias",
),
(
f"vision_model.transformer.layers.{i}.mlp.fc1.weight",
f"{v}.transformer.layers.{i}.mlp.linear_fc1.weight",
),
(
f"vision_model.transformer.layers.{i}.mlp.fc2.bias",
f"{v}.transformer.layers.{i}.mlp.linear_fc2.bias",
),
(
f"vision_model.transformer.layers.{i}.mlp.fc2.weight",
f"{v}.transformer.layers.{i}.mlp.linear_fc2.weight",
),
]
)
for i in range(vision_model_config.num_global_layers):
key_map.extend(
[
(
f"vision_model.global_transformer.layers.{i}.self_attn.o_proj.weight",
f"{v}.global_transformer.layers.{i}.self_attention.linear_proj.weight",
),
(
f"vision_model.global_transformer.layers.{i}.gate_attn",
f"{v}.global_transformer.layers.{i}.gate_attn",
),
(
f"vision_model.global_transformer.layers.{i}.gate_ffn",
f"{v}.global_transformer.layers.{i}.gate_ffn",
),
(
f"vision_model.global_transformer.layers.{i}.input_layernorm.bias",
f"{v}.global_transformer.layers.{i}.input_layernorm.bias",
),
(
f"vision_model.global_transformer.layers.{i}.input_layernorm.weight",
f"{v}.global_transformer.layers.{i}.input_layernorm.weight",
),
(
f"vision_model.global_transformer.layers.{i}.post_attention_layernorm.bias",
f"{v}.global_transformer.layers.{i}.pre_mlp_layernorm.bias",
),
(
f"vision_model.global_transformer.layers.{i}.post_attention_layernorm.weight",
f"{v}.global_transformer.layers.{i}.pre_mlp_layernorm.weight",
),
(
f"vision_model.global_transformer.layers.{i}.mlp.fc1.bias",
f"{v}.global_transformer.layers.{i}.mlp.linear_fc1.bias",
),
(
f"vision_model.global_transformer.layers.{i}.mlp.fc1.weight",
f"{v}.global_transformer.layers.{i}.mlp.linear_fc1.weight",
),
(
f"vision_model.global_transformer.layers.{i}.mlp.fc2.bias",
f"{v}.global_transformer.layers.{i}.mlp.linear_fc2.bias",
),
(
f"vision_model.global_transformer.layers.{i}.mlp.fc2.weight",
f"{v}.global_transformer.layers.{i}.mlp.linear_fc2.weight",
),
]
)
cross_attention_frequency = language_model_config.num_layers // language_model_config.num_cross_attention_layers
    total_num_layers = language_model_config.num_layers + language_model_config.num_cross_attention_layers
prefix = "language_model.decoder"
    for i in range(total_num_layers):
cross_num = (i - 3) // (cross_attention_frequency + 1)
if (i - 3) % (cross_attention_frequency + 1) == 0:
xattn_index = cross_num * cross_attention_frequency + 3
key_map.extend(
[
(
f"language_model.model.layers.{i}.cross_attn.o_proj.weight",
f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_proj.weight",
),
(
f"language_model.model.layers.{i}.cross_attn.q_proj.weight",
f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_q.weight",
),
(
f"language_model.model.layers.{i}.cross_attn.k_norm.weight",
f"{prefix}.xattn_layers.{xattn_index}.cross_attention.k_layernorm.weight",
),
(
f"language_model.model.layers.{i}.input_layernorm.weight",
f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_q.layer_norm_weight",
),
(
f"language_model.model.layers.{i}.cross_attn.q_norm.weight",
f"{prefix}.xattn_layers.{xattn_index}.cross_attention.q_layernorm.weight",
),
(
f"language_model.model.layers.{i}.post_attention_layernorm.weight",
f"{prefix}.xattn_layers.{xattn_index}.mlp.linear_fc1.layer_norm_weight",
),
(
f"language_model.model.layers.{i}.mlp.down_proj.weight",
f"{prefix}.xattn_layers.{xattn_index}.mlp.linear_fc2.weight",
),
(
f"language_model.model.layers.{i}.cross_attn_attn_gate",
f"{prefix}.xattn_layers.{xattn_index}.gate_attn",
),
(
f"language_model.model.layers.{i}.cross_attn_mlp_gate",
f"{prefix}.xattn_layers.{xattn_index}.gate_ffn",
),
]
)
else:
attn_index = i - cross_num - 1
key_map.extend(
[
(
f"language_model.model.layers.{i}.self_attn.o_proj.weight",
f"{prefix}.layers.{attn_index}.self_attention.linear_proj.weight",
),
(
f"language_model.model.layers.{i}.post_attention_layernorm.weight",
f"{prefix}.layers.{attn_index}.mlp.linear_fc1.layer_norm_weight",
),
(
f"language_model.model.layers.{i}.mlp.down_proj.weight",
f"{prefix}.layers.{attn_index}.mlp.linear_fc2.weight",
),
(
f"language_model.model.layers.{i}.input_layernorm.weight",
f"{prefix}.layers.{attn_index}.self_attention.linear_qkv.layer_norm_weight",
),
]
)
new_state_dict = {}
for new_key, old_key in key_map:
new_state_dict[new_key] = state_dict[old_key]
def convert_vision_qkv_weight(state_dict, vision_model_config):
hidden_size = vision_model_config.hidden_size
new_state_dict = {}
for i in range(vision_model_config.num_layers):
qkv_weights = state_dict[
f"vision_model.vision_encoder.transformer.layers.{i}.self_attention.linear_qkv.weight"
]
for name, weight in split_qkv_weight(qkv_weights, vision_model_config):
new_key = f'vision_model.transformer.layers.{i}.self_attn.{name}.weight'
new_state_dict[new_key] = weight.reshape(-1, hidden_size)
for i in range(vision_model_config.num_global_layers):
qkv_weights = state_dict[
f"vision_model.vision_encoder.global_transformer.layers.{i}.self_attention.linear_qkv.weight"
]
for name, weight in split_qkv_weight(qkv_weights, vision_model_config):
new_key = f'vision_model.global_transformer.layers.{i}.self_attn.{name}.weight'
new_state_dict[new_key] = weight.reshape(-1, hidden_size)
return new_state_dict
    def convert_patch_embedding(state_dict):
        """Reshape NeMo's linear patch embedding into HF's Conv2d-style patch_embedding weight."""
conv1_weight = state_dict["vision_model.vision_encoder.conv1._linear.weight"]
return {"vision_model.patch_embedding.weight": conv1_weight.reshape(conv1_weight.shape[0], 3, 14, 14)}
def convert_language_qkv_weight(state_dict, language_model_config):
hidden_size = language_model_config.hidden_size
new_state_dict = {}
        for i in range(total_num_layers):
cross_num = (i - 3) // (cross_attention_frequency + 1)
if (i - 3) % (cross_attention_frequency + 1) == 0:
xattn_index = cross_num * cross_attention_frequency + 3
kv_weights = state_dict[f"{prefix}.xattn_layers.{xattn_index}.cross_attention.linear_kv.weight"]
for name, weight in split_kv_weight(kv_weights, language_model_config):
new_key = f"language_model.model.layers.{i}.cross_attn.{name}.weight"
new_state_dict[new_key] = weight.reshape(-1, hidden_size)
else:
attn_index = i - cross_num - 1
qkv_weights = state_dict[f"{prefix}.layers.{attn_index}.self_attention.linear_qkv.weight"]
for name, weight in split_qkv_weight(qkv_weights, language_model_config):
new_key = f"language_model.model.layers.{i}.self_attn.{name}.weight"
new_state_dict[new_key] = weight.reshape(-1, hidden_size)
return new_state_dict
def convert_gate(state_dict):
new_state_dict = {}
        for i in range(total_num_layers):
cross_num = (i - 3) // (cross_attention_frequency + 1)
if (i - 3) % (cross_attention_frequency + 1) == 0:
xattn_index = cross_num * cross_attention_frequency + 3
gate_weight = state_dict[f"{prefix}.xattn_layers.{xattn_index}.mlp.linear_fc1.weight"]
else:
attn_index = i - cross_num - 1
gate_weight = state_dict[f"{prefix}.layers.{attn_index}.mlp.linear_fc1.weight"]
for name, weight in split_gate_weight(gate_weight):
new_key = f"language_model.model.layers.{i}.mlp.{name}.weight"
new_state_dict[new_key] = weight
return new_state_dict
def convert_embedding(state_dict):
word_embeddings = state_dict["language_model.embedding.word_embeddings.weight"]
learnable_embedding = state_dict["language_model.learnable_embedding.weight"]
return {"language_model.model.embed_tokens.weight": torch.cat((word_embeddings, learnable_embedding), dim=0)}
new_state_dict.update(convert_vision_qkv_weight(state_dict, vision_model_config))
    new_state_dict.update(convert_patch_embedding(state_dict))
new_state_dict.update(convert_language_qkv_weight(state_dict, language_model_config))
new_state_dict.update(convert_gate(state_dict))
new_state_dict.update(convert_embedding(state_dict))
return new_state_dict, convert_mllama_config(vision_model_config, language_model_config)
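

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original conversion logic): the checkpoint path,
    # processor name, and output directory below are placeholders, and loading the converted
    # state dict into transformers' MllamaForConditionalGeneration is one possible way to
    # persist the result; adjust to your setup.
    from transformers import MllamaForConditionalGeneration

    state_dict, hf_config = convert_mllama_nemo_to_hf(
        checkpoint_path="/path/to/nemo_checkpoint",  # placeholder
        processor_name="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder
    )
    hf_model = MllamaForConditionalGeneration(hf_config)
    # strict=False keeps this sketch tolerant of any buffers or keys not covered by the mapping.
    hf_model.load_state_dict(state_dict, strict=False)
    hf_model.save_pretrained("/path/to/output_hf_model")  # placeholder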