# Hugging Face Hub artifact header (not code): uploaded by subhankarg via
# huggingface_hub, revision 0558aa4 (verified).
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload
import numpy as np
import torch
from torch import nn
from nemo.lightning.pytorch.utils import extract_dtypes
from nemo.utils import logging
SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module)
F = TypeVar("F", bound=Callable[..., Any])
@dataclass
class TransformCTX:
    """Context threaded through every state-dict transform during model conversion.

    Transforms read matched tensors from ``source_state`` and write converted
    tensors into ``target_state``; both dicts are mutated in place.
    """

    # Module (or _ModelState wrapper) the weights are converted FROM.
    source: nn.Module
    # State dict of `source`; transforms read matched values from here.
    source_state: dict
    # Module the weights are converted INTO.
    target: nn.Module
    # State dict of `target`; transforms write converted values here.
    target_state: dict
class _ModelState:
"""
Helper class for used for to modify state dict of a source model during model conversion.
"""
def __init__(self, state_dict, config=None):
self._state_dict = state_dict
self.config = config
def state_dict(self):
# pylint: disable=C0115,C0116
return self._state_dict
def to(self, dtype):
# pylint: disable=C0115,C0116
for k, v in self._state_dict.items():
if v.dtype != dtype:
logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)")
self._state_dict[k] = v.to(dtype)
@torch.no_grad
def apply_transforms(
    source: Union[nn.Module, _ModelState],
    target: TargetModuleT,
    mapping: Dict[str, str],
    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
    state_dict_ignored_entries: Optional[List] = None,
    cast_dtype: Optional[torch.dtype] = None,
) -> TargetModuleT:
    """
    Applies a series of transformations to adapt the state dictionary of a source module to
    match the structure of a target module's state dictionary.

    This function renames keys according to a provided mapping and modifies values using a list
    of transformation functions. Each transformation function typically is decorated
    with `io.state_transform`.

    Args:
        source (nn.Module): The source module from which parameters and buffers are taken.
        target (TargetModuleT): The target module to which parameters and buffers are adapted.
        mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary
            is mapped to a corresponding key in the target state dictionary.
        transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions
            that modify the `TransformCTX` object. If None, no transformations beyond key renaming
            are applied. Defaults to None.
        state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases
            where multiple entries in model's state_dict point to one entry in model's named_parameter.
            E.g., model has multiple pointers pointing to one shared parameters (`encoder.embed_tokens.weight`,
            `decoder.embed_tokens.weight` and `shared.weight` all points to `shared.weight`
            in T5 Huggingface implementation.). In these cases, ignore redundant entries.
            Defaults to None (nothing ignored).
        cast_dtype (Optional[torch.dtype]): cast the output state dict to a certain precision.

    Returns
    -------
        TargetModuleT: The modified target module with its state dictionary adjusted according to
        the specified mappings and transformations.

    Raises
    ------
        ValueError: If there's a mismatch in shape between corresponding source and target parameters
            or buffers.
        RuntimeError: If the target state dictionary contains keys that are not present in the source
            state dictionary after all transformations.

    Examples
    --------
    >>> source_module = nn.Linear(10, 5)
    >>> target_module = nn.Linear(10, 5)
    >>> mapping = {'weight': 'weights', 'bias': 'biases'}
    @io.state_transform(
        source_key="weight",
        target_key="weights"
    )
    def scale_weights(ctx):
        ctx.target_state['weights'] = ctx.source_state['weight'] * 2
        return ctx
    >>> transformed_target = apply_transforms(
    ...     source_module, target_module, mapping, [scale_weights]
    ... )
    >>> print(transformed_target.state_dict()['weights'])

    See Also
    --------
    - `TransformCTX`: For more details on the context object used in transformations.
    - `StateDictTransform`: For creating complex transformations.

    Note:
        This function is particularly useful when adapting models from different frameworks or
        when consolidating models with different architectural changes.
    """
    from megatron.core.transformer.module import MegatronModule

    # None defaults avoid the shared-mutable-default-argument pitfall.
    if transforms is None:
        transforms = []
    if state_dict_ignored_entries is None:
        state_dict_ignored_entries = []

    # Unwrap Megatron wrapper modules so we operate on the bare model.
    _source = source
    if hasattr(source, "module") and isinstance(source.module, MegatronModule):
        _source = source.module
    _target = target
    if hasattr(target, "module") and isinstance(target.module, MegatronModule):
        _target = target.module

    # Track dtypes to make sure they weren't modified during conversion.
    target_orig_dtypes = extract_dtypes(_target.named_parameters())
    target_state = _target.state_dict()
    ctx = TransformCTX(
        source=_source,
        source_state=_source.state_dict(),
        target=_target,
        target_state=target_state,
    )

    # Plain key renames first, then the user-supplied transform functions.
    for key, val in mapping.items():
        logging.debug(f"Mapping {key} -> {val}")
        ctx = StateDictTransform(key, val)(ctx)
    for transform in transforms:
        logging.debug(f"Transforming {transform.source_key} -> {transform.target_key}")
        ctx = transform(ctx)

    # Re-wrap every converted tensor as a Parameter; entries are popped from
    # target_state as they are consumed so leftovers can be detected below.
    _params: Dict[str, nn.Parameter] = {}
    for name, param in _target.named_parameters():
        if name in target_state:
            target_param = target_state[name]
            if param.data.shape != target_param.shape:
                raise ValueError(
                    f"Shape mismatch for parameter {name}: target shape {param.shape} vs "
                    f"converted source shape {target_param.shape}"
                )
            _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad)
            target_state.pop(name)
        else:
            # Previously a bare print(); routed through the project logger for consistency.
            logging.warning(f"Unexpected key: {name} not in checkpoint but in model.")

    for key, val in _params.items():
        # Walk the dotted path down to the owning submodule before registering.
        _module, _key = _target, key
        if "." in key:
            for part in key.split(".")[:-1]:
                _module = getattr(_module, part)
            _key = key.split(".")[-1]
        _module.register_parameter(_key, val)

    # Same treatment for buffers (registered with requires_grad=False).
    _buffers = {}
    for name, buffer in _target.named_buffers():
        if name in target_state:
            if buffer.shape != target_state[name].shape:
                raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}")
            _buffers[name] = nn.Parameter(target_state[name], requires_grad=False)
            target_state.pop(name)
    for key, val in _buffers.items():
        _module, _key = _target, key
        if "." in key:
            for part in key.split(".")[:-1]:
                _module = getattr(_module, part)
            _key = key.split(".")[-1]
        _module.register_buffer(_key, val)

    # Anything still left in target_state was produced by the conversion but has no
    # home in the model (ignoring `_extra_state` and explicitly ignored entries).
    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
    keys = [key for key in keys if key not in state_dict_ignored_entries]
    if keys:
        raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")

    # Meta tensors indicate parameters that no mapping/transform ever materialized.
    meta_tensor_keys = [name for name, param in target.named_parameters() if param.is_meta]
    assert not meta_tensor_keys, (
        f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion."
        f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?"
    )

    if cast_dtype:
        logging.info(f"Casting model to {cast_dtype}...")
        _target.to(cast_dtype)
        logging.info(f"Casting model to {cast_dtype} complete.")
    else:
        # No explicit cast requested: conversion must not have changed any dtype.
        assert target_orig_dtypes == extract_dtypes(_target.named_parameters()), (
            f"dtype mismatch between source and target state dicts. "
            f"Left side is { {k: v for k, v in target_orig_dtypes.items() if v!=torch.bfloat16} }, "
            f"Right side is "
            f"{ {k: v for k, v in extract_dtypes(_target.named_parameters()).items() if v!=torch.bfloat16} }"
        )

    # Re-attach the converted bare model to its Megatron wrapper, if there was one.
    if hasattr(target, "module") and isinstance(target.module, MegatronModule):
        target.module = _target
        return target
    return _target
def _default_transform(inp):
return inp
class StateDictTransform(Generic[F]):
    """
    A transformation class for state dictionaries, allowing for flexible key matching and
    transformation of values between source and target state dictionaries.

    Attributes
    ----------
    source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
        state dictionary to match. Wildcards (*) are supported.
    target_key: A string or tuple of strings specifying the keys in the target state dictionary
        to match. Wildcards (*) are supported.
    transform: A callable that performs the transformation on matched keys' values.

    Examples
    --------
    >>> def example_transform(ctx, *args):
    ...     return sum(args)
    >>> transform = StateDictTransform(
    ...     source_key="model.layers.*.self_attn.*_proj.weight",
    ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight",
    ...     transform=example_transform
    ... )
    """

    def __init__(
        self,
        source_key: Union[str, Tuple[str, ...], Dict[str, str]],
        target_key: Union[str, Tuple[str, ...]],
        transform: F = _default_transform,
    ):
        # With the identity default transform, a bare StateDictTransform is a key rename.
        self.source_key = source_key
        self.target_key = target_key
        self.transform = transform

    def __call__(self, ctx: TransformCTX) -> TransformCTX:
        """Match keys in ctx.source_state / ctx.target_state and write transformed
        values into ctx.target_state in place. Returns the (mutated) ctx."""
        source_key = self.source_key
        target_key = self.target_key
        source_dict, target_dict = ctx.source_state, ctx.target_state
        # NOTE(review): mutates numpy's *global* print options as a side effect —
        # presumably to keep debug logs of match arrays short; verify this is intended.
        np.set_printoptions(threshold=10)
        # The transform's parameter NAMES drive keyword dispatch; 'ctx' is special-cased.
        fn_params = dict(inspect.signature(self.transform).parameters)
        fn_params.pop("ctx", None)
        matched = False
        if isinstance(source_key, (dict, tuple)):
            # Multi-source case: each transform parameter gets its own source pattern.
            if isinstance(source_key, tuple):
                # Tuple patterns are paired with parameters in declaration order.
                source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)}
            else:
                source_key_dict = source_key
            source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()}
            target_matches = _match_keys(list(target_dict.keys()), target_key)
            param_names = list(filter(lambda x: x in source_matches_dict, fn_params))
            # 0-d match arrays (single hit) are wrapped in 1-element lists so zip works.
            source_matches = [
                source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()]
                for v in param_names
            ]
            target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]]
            # Iterate source groups and targets in lockstep; last element is the target key.
            for layer_names_group in zip(*(source_matches + target_matches)):
                # Wrap in a list if it's a single layer (ie non-expert)
                if isinstance(layer_names_group[0], str):
                    layer_names_group = [[x] for x in layer_names_group]
                for layer_names in zip(*layer_names_group):
                    target_dict[layer_names[-1]] = self.call_transform(
                        ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]]))
                    )
                logging.debug(f"Matched (transform)! {layer_names_group=}")
                matched = True
        else:
            source_keys = list(source_dict.keys())
            target_keys = list(target_dict.keys())
            source_matches = _match_keys(source_keys, source_key)
            # _match_keys returns an object array; a lone None means nothing matched.
            if source_matches.size == 1 and source_matches == np.array(None):
                raise ValueError(f"No matches found for source key: {source_key}")
            if isinstance(target_key, str):
                target_matches = _match_keys(target_keys, target_key)
                if target_matches.size == 1 and target_matches == np.array(None):
                    raise ValueError(f"No matches found for target key: {target_key}")
            else:
                if isinstance(target_key, dict):
                    raise ValueError("Target key must be a string or a tuple of strings.")
                # Tuple of target patterns: stack per-pattern matches along a new last axis.
                _matches = [_match_keys(target_keys, key) for key in target_key]
                target_matches = np.stack(_matches, axis=-1)
            # Determine if we are dealing with multiple source matches or multiple target matches
            multiple_sources = source_matches.ndim >= target_matches.ndim
            # A *args transform receives all matched values positionally.
            accepts_var_args = any(
                param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values()
            )
            if multiple_sources:
                # N source keys -> 1 target key (e.g. merging q/k/v into linear_qkv).
                for target_index, target_match in np.ndenumerate(target_matches):
                    try:
                        source_match = source_matches[target_index]
                    except IndexError as e:
                        # NOTE(review): "Enountered" typo lives in the runtime log string;
                        # left untouched here since a doc-only change must not edit it.
                        logging.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
                        raise e
                    if accepts_var_args:
                        source_values = [source_dict[k] for k in source_match]
                        target_dict[target_match] = self.call_transform(ctx, *source_values)
                    else:
                        _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match)
                        if len(fn_params) != len(_source_match_list):
                            raise ValueError(
                                f"Mismatch between source and target keys: {source_match} vs {target_match}"
                            )
                        # Matched source values are dispatched by the transform's parameter names.
                        kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)}
                        target_dict[target_match] = self.call_transform(ctx, **kwargs)
                    logging.debug(f"Matched (multi source)! {target_match=} {source_match=}")
                    matched = True
            else:
                # 1 source key -> N target keys (e.g. splitting linear_qkv into q/k/v).
                for source_index, source_match in np.ndenumerate(source_matches):
                    target_match = target_matches[source_index]
                    source_values = (
                        [source_dict[source_match]]
                        if np.isscalar(source_match)
                        else [source_dict[k] for k in source_match]
                    )
                    if accepts_var_args:
                        outputs = self.call_transform(ctx, *source_values)
                    else:
                        kwargs = {param: val for param, val in zip(fn_params, source_values)}
                        outputs = self.call_transform(ctx, **kwargs)
                    if isinstance(target_match, str):
                        target_dict[target_match] = outputs
                    else:
                        # Multiple target keys: the transform must return one output per key.
                        for i, t in enumerate(outputs):
                            target_dict[target_match[i]] = t
                    logging.debug(f"Matched (single source)! {target_match=} {source_match=}")
                    matched = True
        if not matched:
            logging.warning(f"No matches found for source key: {source_key=} {target_key=}")
        return ctx

    def call_transform(self, ctx: TransformCTX, *args, **kwargs):
        """Perform transform and check if the given args valid."""
        func_params = inspect.signature(self.transform).parameters
        expected_num_args = len([p for p in func_params if p not in ['self', 'ctx']])
        provided_num_args = len(args) + len(kwargs)
        accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values())
        # Arity check is skipped for *args transforms, which take any number of values.
        if not accepts_var_args and provided_num_args != expected_num_args:
            raise ValueError(
                f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}."
            )
        # 'ctx' is injected only when the transform declares it.
        if 'ctx' in func_params:
            return self.transform(ctx, *args, **kwargs)
        return self.transform(*args, **kwargs)
def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
escaped_pattern = ''
i = 0
wildcard_positions = []
while i < len(pattern):
if pattern[i : i + 2] == '**':
escaped_pattern += r'(.+)' # Match any characters including dots
wildcard_positions.append('**')
i += 2
elif pattern[i] == '*':
escaped_pattern += r'([^.]+)' # Match any characters except dots
wildcard_positions.append('*')
i += 1
else:
if pattern[i] == '.':
escaped_pattern += r'\.' # Escape the dot
else:
escaped_pattern += pattern[i]
i += 1
regex_pattern = re.compile("^" + escaped_pattern + "$")
num_wildcards = len(wildcard_positions)
wildcard_matches = [[] for _ in range(num_wildcards)]
for key in filter(lambda x: x is not None, keys):
match = regex_pattern.match(key)
if match:
for i, group in enumerate(match.groups()):
if group not in wildcard_matches[i]:
wildcard_matches[i].append(group)
# Sort the wildcard matches to maintain consistent ordering
for i in range(len(wildcard_matches)):
wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x)
# Determine the shape of the output array based on the unique matches for each wildcard
shape = [len(matches) for matches in wildcard_matches]
if len(wildcard_matches) == 0:
# If there is no wildcard matches, assuming it is a single match
shape = [1]
# Initialize an empty array with the determined shape
output_array = np.empty(shape, dtype=object)
# Populate the array with the keys, now that we have the correct shape and ordering
for key in filter(lambda x: x is not None, keys):
match = regex_pattern.match(key)
if match:
# Convert match groups to indices based on their position in wildcard_matches
indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())]
output_array[tuple(indices)] = key # Place the key in the array based on the indices
return output_array
# Typing overloads for `state_transform`: called without `fn` it returns a
# decorator that wraps a function into a StateDictTransform; called with `fn`
# it returns the StateDictTransform directly.
@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
) -> Callable[[F], StateDictTransform[F]]: ...


@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F
) -> StateDictTransform[F]: ...
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
    fn: Optional[F] = None,
):
    """
    Build a :class:`StateDictTransform` from source/target key patterns and a function.

    Usable either directly (``state_transform(src, tgt, fn)``) or as a decorator
    (``@state_transform(src, tgt)``) for concise state-dict transform definitions.

    Args:
        source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
            state dictionary to match. Wildcards (*) are supported.
        target_key: A string or tuple of strings specifying the keys in the target state dictionary
            to match. Wildcards (*) are supported.
        fn: An optional callable that performs the transformation on matched keys' values. If not
            provided, the decorator can be used to wrap a function definition.

    Returns
    -------
        A StateDictTransform instance if `fn` is provided, otherwise a decorator that
        takes a function and returns a StateDictTransform instance.

    Examples
    --------
    >>> @state_transform(
    ...     source_key="model.layers.*.self_attn.*_proj.weight",
    ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight"
    ... )
    ... def sum_transform(ctx, *args):
    ...     return sum(args)
    """

    def decorator(func) -> StateDictTransform:
        # Bind the key patterns to the wrapped function.
        return StateDictTransform(source_key, target_key, func)

    return decorator if fn is None else decorator(fn)
class TransformFns:
    """
    A collection of common functions used in state dict transformation.
    """

    @staticmethod
    def _qkv_slices(megatron_config):
        """
        Compute row-index tensors selecting the q, k and v rows of a tensor laid
        out in the interleaved [q_0..q_h, k, v] * num_query_groups head order
        (the layout produced by `merge_qkv` below).

        Returns (q_slice, k_slice, v_slice, qkv_total_dim); the slices index the
        first dimension of a [qkv_total_dim, head_size, ...] tensor.
        """
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        qkv_total_dim = head_num + 2 * num_query_groups
        # Each query group occupies (heads_per_group + 2) consecutive rows:
        # its q heads, then one k head, then one v head.
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
        return q_slice, k_slice, v_slice, qkv_total_dim

    @staticmethod
    def split_qkv(ctx: "TransformCTX", linear_qkv: torch.Tensor):
        """
        Split interleave-concatenated qkv to q, k, v.

        Example: export layer linear_qkv to HF {q|k|v}_proj
        """
        megatron_config = ctx.source.config
        head_size = megatron_config.kv_channels
        q_slice, k_slice, v_slice, qkv_total_dim = TransformFns._qkv_slices(megatron_config)
        linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1])
        # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size
        # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r
        hidden_size = linear_qkv.size(-1)
        q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
        k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
        v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()
        return q_proj, k_proj, v_proj

    @staticmethod
    def split_qkv_bias(ctx: "TransformCTX", qkv_bias: torch.Tensor):
        """
        Split interleave-concatenated qkv bias to separate q, k, v bias.

        Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
        """
        megatron_config = ctx.source.config
        head_size = megatron_config.kv_channels
        q_slice, k_slice, v_slice, qkv_total_dim = TransformFns._qkv_slices(megatron_config)
        qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size])
        q_bias = qkv_bias[q_slice].reshape(-1).cpu()
        k_bias = qkv_bias[k_slice].reshape(-1).cpu()
        v_bias = qkv_bias[v_slice].reshape(-1).cpu()
        return q_bias, k_bias, v_bias

    @staticmethod
    def merge_qkv_concat(ctx: "TransformCTX", qkv: torch.Tensor):
        """
        Merge naively concatenated q, k, v to interleave-concatenated qkv.

        Example: import HF qkv to layer linear_qkv
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        # HF layout is [all q rows | all k rows | all v rows] along dim 0.
        q, k, v = qkv.split([head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0)
        return TransformFns.merge_qkv(ctx, q, k, v)

    @staticmethod
    def merge_qkv(ctx: "TransformCTX", q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Merge q, k, v to interleave-concatenated qkv.

        Example: import HF {q|k|v}_proj to layer linear_qkv
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        hidden_size = megatron_config.hidden_size
        head_size = megatron_config.kv_channels
        old_tensor_shape = q.size()
        new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
        new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]
        q = q.view(*new_q_tensor_shape)
        k = k.view(*new_kv_tensor_shape)
        v = v.view(*new_kv_tensor_shape)
        # Interleave per query group: heads_per_group q heads, then one k, then one v.
        qkv_weights_l = []
        for i in range(num_query_groups):
            qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
            qkv_weights_l.append(k[i : i + 1, :, :])
            qkv_weights_l.append(v[i : i + 1, :, :])
        qkv_weights = torch.cat(qkv_weights_l)
        assert qkv_weights.ndim == 3, qkv_weights.shape
        assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
        assert qkv_weights.shape[1] == head_size, qkv_weights.shape
        assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape
        qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])
        return qkv_weights

    @staticmethod
    def merge_qkv_bias_concat(ctx: "TransformCTX", qkv_bias: torch.Tensor):
        """
        Merge naively concatenated q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF qkv bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        qb, kb, vb = qkv_bias.split(
            [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
        )
        return TransformFns.merge_qkv_bias(ctx, qb, kb, vb)

    @staticmethod
    def merge_qkv_bias(ctx: "TransformCTX", qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor):
        """
        Merge q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF {q|k|v}_proj bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels
        new_q_tensor_shape = (head_num, head_size)
        new_kv_tensor_shape = (num_query_groups, head_size)
        qb = qb.view(*new_q_tensor_shape)
        kb = kb.view(*new_kv_tensor_shape)
        vb = vb.view(*new_kv_tensor_shape)
        # Collect the interleaved slices and concatenate once. The previous
        # implementation grew the result with repeated torch.cat inside the loop,
        # copying the accumulator each iteration (quadratic in num_query_groups);
        # this also matches how merge_qkv builds its result.
        qkv_bias_l = []
        for i in range(num_query_groups):
            qkv_bias_l.append(qb[i * heads_per_group : (i + 1) * heads_per_group, :])
            qkv_bias_l.append(kb[i : i + 1, :])
            qkv_bias_l.append(vb[i : i + 1, :])
        qkv_bias = torch.cat(qkv_bias_l)
        qkv_bias = qkv_bias.reshape(
            [
                head_size * (head_num + 2 * num_query_groups),
            ]
        )
        return qkv_bias

    @staticmethod
    def merge_fc1(gate: torch.Tensor, up: torch.Tensor):
        """
        Merge gate and up proj into concatenated fc1.

        Example: import HF {gate|up}_proj to layer linear_fc1
        """
        return torch.cat((gate, up), dim=0)

    @staticmethod
    def split_fc1(linear_fc1: torch.Tensor):
        """
        Split concatenated fc1 to gate and up proj.

        Example: export layer linear_fc1 to HF {gate|up}_proj
        """
        gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)
        return gate_proj, up_proj

    @staticmethod
    def duplicate2(param: torch.Tensor):
        """
        Duplicate the source parameter to two target parameters.

        Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A
        """
        return param, param

    @staticmethod
    def duplicate3(param: torch.Tensor):
        """
        Duplicate the source parameter to three target parameters.

        Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A
        """
        return param, param, param

    @staticmethod
    def prune_padding(ctx: "TransformCTX", embedding: torch.Tensor):
        """
        Prune the embedding size to vocab size.

        Example: export embedding/output layer to HF with non-padded vocab size
        """
        megatron_config = ctx.target.config
        # Keep only the first vocab_size rows; the remainder is padding (see docstring).
        return embedding[: megatron_config.vocab_size, :]