# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload

import numpy as np
import torch
from torch import nn

from nemo.lightning.pytorch.utils import extract_dtypes
from nemo.utils import logging

SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module)
F = TypeVar("F", bound=Callable[..., Any])

@dataclass
class TransformCTX:
    """Context object holding the source/target modules and their state dicts during a transform."""

    source: nn.Module
    source_state: dict
    target: nn.Module
    target_state: dict

class _ModelState:
    """Helper class used to modify the state dict of a source model during model conversion."""

    def __init__(self, state_dict, config=None):
        self._state_dict = state_dict
        self.config = config

    def state_dict(self):
        """Return the wrapped state dict."""
        return self._state_dict

    def to(self, dtype):
        """Cast all tensors in the wrapped state dict to `dtype`, warning on each conversion."""
        for k, v in self._state_dict.items():
            if v.dtype != dtype:
                logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)")
            self._state_dict[k] = v.to(dtype)
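
# A minimal usage sketch (illustrative only; `hf_model` and the mapping are
# hypothetical): wrap a bare state dict so `apply_transforms` can treat it
# like a source module:
#
#   source = _ModelState(hf_model.state_dict(), config=hf_model.config)
#   target = apply_transforms(source, target_model, mapping={...})
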
def apply_transforms(
    source: Union[nn.Module, _ModelState],
    target: TargetModuleT,
    mapping: Dict[str, str],
    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
    state_dict_ignored_entries: List = [],
    cast_dtype: Optional[torch.dtype] = None,
) -> TargetModuleT:
| """ | |
| Applies a series of transformations to adapt the state dictionary of a source module to | |
| match the structure of a target module's state dictionary. | |
| This function renames keys according to a provided mapping and modifies values using a list | |
| of transformation functions. Each transformation function typically is decorated | |
| with `io.state_transform`. | |
| Args: | |
| source (nn.Module): The source module from which parameters and buffers are taken. | |
| target (TargetModuleT): The target module to which parameters and buffers are adapted. | |
| mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary | |
| is mapped to a corresponding key in the target state dictionary. | |
| transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions | |
| that modify the `TransformCTX` object. If None, no transformations beyond key renaming | |
| are applied. Defaults to None. | |
| state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases | |
| where multiple entries in model's state_dict point to one entry in model's named_parameter. | |
| E.g., model has multiple pointers pointing to one shared parameters (`encoder.embed_tokens.weight`, | |
| `decoder.embed_tokens.weight` and `shared.weight` all points to `shared.weight | |
| in T5 Huggingface implementation.). In these cases, ignore redundant entries. | |
| cast_dtype Optional[torch.dtype]: case the output state dict to a certain precision. | |
| Returns | |
| ------- | |
| TargetModuleT: The modified target module with its state dictionary adjusted according to | |
| the specified mappings and transformations. | |
| Raises | |
| ------ | |
| ValueError: If there's a mismatch in shape between corresponding source and target parameters | |
| or buffers. | |
| RuntimeError: If the target state dictionary contains keys that are not present in the source | |
| state dictionary after all transformations. | |
| Examples | |
| -------- | |
| >>> source_module = nn.Linear(10, 5) | |
| >>> target_module = nn.Linear(10, 5) | |
| >>> mapping = {'weight': 'weights', 'bias': 'biases'} | |
| @io.state_transform( | |
| source_key="weight", | |
| target_key="weights" | |
| ) | |
| def scale_weights(ctx): | |
| ctx.target_state['weights'] = ctx.source_state['weight'] * 2 | |
| return ctx | |
| >>> transformed_target = apply_transforms( | |
| ... source_module, target_module, mapping, [scale_weights] | |
| ... ) | |
| >>> print(transformed_target.state_dict()['weights']) | |
| See Also | |
| -------- | |
| - `TransformCTX`: For more details on the context object used in transformations. | |
| - `StateDictTransform`: For creating complex transformations. | |
| Note: | |
| This function is particularly useful when adapting models from different frameworks or | |
| when consolidating models with different architectural changes. | |
| """ | |
    from megatron.core.transformer.module import MegatronModule

    # TODO: How can we improve this?
    _source = source
    if hasattr(source, "module") and isinstance(source.module, MegatronModule):
        _source = source.module
    _target = target
    if hasattr(target, "module") and isinstance(target.module, MegatronModule):
        _target = target.module

    # Track dtypes to make sure they weren't modified during conversion.
    target_orig_dtypes = extract_dtypes(_target.named_parameters())

    target_state = _target.state_dict()
    ctx = TransformCTX(
        source=_source,
        source_state=_source.state_dict(),
        target=_target,
        target_state=target_state,
    )

    for key, val in mapping.items():
        logging.debug(f"Mapping {key} -> {val}")
        ctx = StateDictTransform(key, val)(ctx)

    for transform in transforms:
        logging.debug(f"Transforming {transform.source_key} -> {transform.target_key}")
        ctx = transform(ctx)

    _params: Dict[str, nn.Parameter] = {}
    for name, param in _target.named_parameters():
        if name in target_state:
            target_param = target_state[name]
            if param.data.shape != target_param.shape:
                raise ValueError(
                    f"Shape mismatch for parameter {name}: target shape {param.shape} vs "
                    f"converted source shape {target_param.shape}"
                )
            _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad)
            target_state.pop(name)
        else:
            logging.warning(f"Unexpected key: {name} not in checkpoint but in model.")

    for key, val in _params.items():
        _module, _key = _target, key
        if "." in key:
            for part in key.split(".")[:-1]:
                _module = getattr(_module, part)
            _key = key.split(".")[-1]
        _module.register_parameter(_key, val)

    _buffers = {}
    for name, buffer in _target.named_buffers():
        if name in target_state:
            if buffer.shape != target_state[name].shape:
                raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}")
            _buffers[name] = nn.Parameter(target_state[name], requires_grad=False)
            target_state.pop(name)

    for key, val in _buffers.items():
        _module, _key = _target, key
        if "." in key:
            for part in key.split(".")[:-1]:
                _module = getattr(_module, part)
            _key = key.split(".")[-1]
        _module.register_buffer(_key, val)

    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
    keys = [key for key in keys if key not in state_dict_ignored_entries]
    if len(keys) != 0:
        raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")

    # TODO: Is this correct?
    # for key in target.state_dict():
    #     if key.endswith("_extra_state"):
    #         del target.state_dict()[key]
    meta_tensor_keys = []
    for name, param in target.named_parameters():
        if param.is_meta:
            meta_tensor_keys.append(name)
    assert not meta_tensor_keys, (
        f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion. "
        f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?"
    )

    if cast_dtype:
        logging.info(f"Casting model to {cast_dtype}...")
        _target.to(cast_dtype)
        logging.info(f"Casting model to {cast_dtype} complete.")
    else:
        assert target_orig_dtypes == extract_dtypes(_target.named_parameters()), (
            f"dtype mismatch between source and target state dicts. "
            f"Left side is { {k: v for k, v in target_orig_dtypes.items() if v != torch.bfloat16} }, "
            f"right side is "
            f"{ {k: v for k, v in extract_dtypes(_target.named_parameters()).items() if v != torch.bfloat16} }"
        )

    if hasattr(target, "module") and isinstance(target.module, MegatronModule):
        target.module = _target
        return target

    return _target

def _default_transform(inp):
    """Identity transform used when no transform function is provided."""
    return inp

class StateDictTransform(Generic[F]):
    """
    A transformation class for state dictionaries, allowing for flexible key matching and
    transformation of values between source and target state dictionaries.

    Attributes
    ----------
        source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
            state dictionary to match. Wildcards (*) are supported.
        target_key: A string or tuple of strings specifying the keys in the target state dictionary
            to match. Wildcards (*) are supported.
        transform: A callable that performs the transformation on matched keys' values.

    Examples
    --------
        >>> def example_transform(ctx, *args):
        ...     return sum(args)
        >>> transform = StateDictTransform(
        ...     source_key="model.layers.*.self_attn.*_proj.weight",
        ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight",
        ...     transform=example_transform
        ... )
    """
    def __init__(
        self,
        source_key: Union[str, Tuple[str, ...], Dict[str, str]],
        target_key: Union[str, Tuple[str, ...]],
        transform: F = _default_transform,
    ):
        self.source_key = source_key
        self.target_key = target_key
        self.transform = transform
    def __call__(self, ctx: TransformCTX) -> TransformCTX:
        source_key = self.source_key
        target_key = self.target_key
        source_dict, target_dict = ctx.source_state, ctx.target_state

        np.set_printoptions(threshold=10)

        fn_params = dict(inspect.signature(self.transform).parameters)
        fn_params.pop("ctx", None)

        matched = False
        if isinstance(source_key, (dict, tuple)):
            if isinstance(source_key, tuple):
                source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)}
            else:
                source_key_dict = source_key
            source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()}
            target_matches = _match_keys(list(target_dict.keys()), target_key)
            param_names = list(filter(lambda x: x in source_matches_dict, fn_params))
            source_matches = [
                source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()]
                for v in param_names
            ]
            target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]]
            for layer_names_group in zip(*(source_matches + target_matches)):
                # Wrap in a list if it's a single layer (i.e. non-expert)
                if isinstance(layer_names_group[0], str):
                    layer_names_group = [[x] for x in layer_names_group]
                for layer_names in zip(*layer_names_group):
                    target_dict[layer_names[-1]] = self.call_transform(
                        ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]]))
                    )
                    logging.debug(f"Matched (transform)! {layer_names_group=}")
                    matched = True
        else:
            source_keys = list(source_dict.keys())
            target_keys = list(target_dict.keys())
            source_matches = _match_keys(source_keys, source_key)
            if source_matches.size == 1 and source_matches == np.array(None):
                raise ValueError(f"No matches found for source key: {source_key}")

            if isinstance(target_key, str):
                target_matches = _match_keys(target_keys, target_key)
                if target_matches.size == 1 and target_matches == np.array(None):
                    raise ValueError(f"No matches found for target key: {target_key}")
            else:
                if isinstance(target_key, dict):
                    raise ValueError("Target key must be a string or a tuple of strings.")
                _matches = [_match_keys(target_keys, key) for key in target_key]
                target_matches = np.stack(_matches, axis=-1)

            # Determine if we are dealing with multiple source matches or multiple target matches
            multiple_sources = source_matches.ndim >= target_matches.ndim
            accepts_var_args = any(
                param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values()
            )

            if multiple_sources:
                for target_index, target_match in np.ndenumerate(target_matches):
                    try:
                        source_match = source_matches[target_index]
                    except IndexError as e:
                        logging.error(
                            f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}"
                        )
                        raise e

                    if accepts_var_args:
                        source_values = [source_dict[k] for k in source_match]
                        target_dict[target_match] = self.call_transform(ctx, *source_values)
                    else:
                        _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match)
                        if len(fn_params) != len(_source_match_list):
                            raise ValueError(
                                f"Mismatch between source and target keys: {source_match} vs {target_match}"
                            )

                        kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)}
                        target_dict[target_match] = self.call_transform(ctx, **kwargs)
                    logging.debug(f"Matched (multi source)! {target_match=} {source_match=}")
                    matched = True
            else:
                for source_index, source_match in np.ndenumerate(source_matches):
                    target_match = target_matches[source_index]
                    source_values = (
                        [source_dict[source_match]]
                        if np.isscalar(source_match)
                        else [source_dict[k] for k in source_match]
                    )
                    if accepts_var_args:
                        outputs = self.call_transform(ctx, *source_values)
                    else:
                        kwargs = {param: val for param, val in zip(fn_params, source_values)}
                        outputs = self.call_transform(ctx, **kwargs)

                    if isinstance(target_match, str):
                        target_dict[target_match] = outputs
                    else:
                        for i, t in enumerate(outputs):
                            target_dict[target_match[i]] = t
                    logging.debug(f"Matched (single source)! {target_match=} {source_match=}")
                    matched = True

        if not matched:
            logging.warning(f"No matches found for source key: {source_key=} {target_key=}")

        return ctx
    def call_transform(self, ctx: TransformCTX, *args, **kwargs):
        """Perform the transform after checking that the given arguments are valid."""
        func_params = inspect.signature(self.transform).parameters
        expected_num_args = len([p for p in func_params if p not in ['self', 'ctx']])
        provided_num_args = len(args) + len(kwargs)
        accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values())

        if not accepts_var_args and provided_num_args != expected_num_args:
            raise ValueError(
                f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}."
            )

        if 'ctx' in func_params:
            return self.transform(ctx, *args, **kwargs)

        return self.transform(*args, **kwargs)

def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
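    """
    Match `keys` against `pattern`, where `*` matches any run of characters except
    dots and `**` matches across dots. Returns an object array indexed by the
    unique values captured for each wildcard, in sorted order.

    A minimal illustrative example (hypothetical keys, not from any real model):

    >>> _match_keys(["layers.0.w", "layers.1.w"], "layers.*.w")
    array(['layers.0.w', 'layers.1.w'], dtype=object)
    """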
    escaped_pattern = ''
    i = 0
    wildcard_positions = []

    while i < len(pattern):
        if pattern[i : i + 2] == '**':
            escaped_pattern += r'(.+)'  # Match any characters including dots
            wildcard_positions.append('**')
            i += 2
        elif pattern[i] == '*':
            escaped_pattern += r'([^.]+)'  # Match any characters except dots
            wildcard_positions.append('*')
            i += 1
        else:
            if pattern[i] == '.':
                escaped_pattern += r'\.'  # Escape the dot
            else:
                escaped_pattern += pattern[i]
            i += 1

    regex_pattern = re.compile("^" + escaped_pattern + "$")
    num_wildcards = len(wildcard_positions)
    wildcard_matches = [[] for _ in range(num_wildcards)]

    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            for i, group in enumerate(match.groups()):
                if group not in wildcard_matches[i]:
                    wildcard_matches[i].append(group)

    # Sort the wildcard matches to maintain consistent ordering
    for i in range(len(wildcard_matches)):
        wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x)

    # Determine the shape of the output array based on the unique matches for each wildcard
    shape = [len(matches) for matches in wildcard_matches]
    if len(wildcard_matches) == 0:
        # If there are no wildcard matches, assume it is a single match
        shape = [1]

    # Initialize an empty array with the determined shape
    output_array = np.empty(shape, dtype=object)

    # Populate the array with the keys, now that we have the correct shape and ordering
    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            # Convert match groups to indices based on their position in wildcard_matches
            indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())]
            output_array[tuple(indices)] = key  # Place the key in the array based on the indices

    return output_array

@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
) -> Callable[[F], StateDictTransform[F]]: ...


@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F
) -> StateDictTransform[F]: ...

def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
    fn: Optional[F] = None,
):
| """ | |
| A decorator for creating StateDictTransform instances with specified source and target keys, | |
| and a transformation function. This allows for concise definition of state dictionary | |
| transformations. | |
| Args: | |
| source_key: A string, tuple of strings, or a dictionary specifying the keys in the source | |
| state dictionary to match. Wildcards (*) are supported. | |
| target_key: A string or tuple of strings specifying the keys in the target state dictionary | |
| to match. Wildcards (*) are supported. | |
| fn: An optional callable that performs the transformation on matched keys' values. If not | |
| provided, the decorator can be used to wrap a function definition. | |
| Returns | |
| ------- | |
| A StateDictTransform instance if `fn` is provided, otherwise returns a decorator that | |
| takes a function and returns a StateDictTransform instance. | |
| Examples | |
| -------- | |
| >>> @state_transform( | |
| ... source_key="model.layers.*.self_attn.*_proj.weight", | |
| ... target_key="decoder.layers.*.self_attention.linear_qkv.weight" | |
| ... ) | |
| ... def sum_transform(ctx, *args): | |
| ... return sum(args) | |
| """ | |
| def wrapper(fn) -> StateDictTransform: | |
| return StateDictTransform(source_key, target_key, fn) | |
| if fn is None: | |
| return wrapper | |
| return wrapper(fn) | |
class TransformFns:
    """
    A collection of common functions used in state dict transformations.
    """
    @staticmethod
    def split_qkv(ctx: TransformCTX, linear_qkv: torch.Tensor):
        """
        Split interleave-concatenated qkv into separate q, k, v tensors.

        Example: export layer linear_qkv to HF {q|k|v}_proj
        """
        megatron_config = ctx.source.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels
        qkv_total_dim = head_num + 2 * num_query_groups
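        # Illustrative layout (assumed example values, not from a real config):
        # with head_num=8 and num_query_groups=4, heads_per_group=2, so after
        # the reshape below the first axis is ordered per query group as
        #   [q0 q1 k v | q2 q3 k v | q4 q5 k v | q6 q7 k v]
        # i.e. q rows at offsets 0..heads_per_group-1, then k, then v within
        # each block of (heads_per_group + 2) rows.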
        linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1])
        # when converting the base model (linear_qkv), hidden size = megatron_config.hidden_size
        # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r
        hidden_size = linear_qkv.size(-1)
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
        k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
        v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

        return q_proj, k_proj, v_proj
    @staticmethod
    def split_qkv_bias(ctx: TransformCTX, qkv_bias: torch.Tensor):
        """
        Split interleave-concatenated qkv bias into separate q, k, v biases.

        Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
        """
        megatron_config = ctx.source.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels
        qkv_total_dim = head_num + 2 * num_query_groups

        qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size])
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_bias = qkv_bias[q_slice].reshape(-1).cpu()
        k_bias = qkv_bias[k_slice].reshape(-1).cpu()
        v_bias = qkv_bias[v_slice].reshape(-1).cpu()

        return q_bias, k_bias, v_bias
    @staticmethod
    def merge_qkv_concat(ctx: TransformCTX, qkv: torch.Tensor):
        """
        Merge naively concatenated q, k, v into interleave-concatenated qkv.

        Example: import HF qkv to layer linear_qkv
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels

        q, k, v = qkv.split([head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0)
        return TransformFns.merge_qkv(ctx, q, k, v)
    @staticmethod
    def merge_qkv(ctx: TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Merge q, k, v into interleave-concatenated qkv.

        Example: import HF {q|k|v}_proj to layer linear_qkv
        """
        megatron_config = ctx.target.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        hidden_size = megatron_config.hidden_size
        head_size = megatron_config.kv_channels

        old_tensor_shape = q.size()
        new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
        new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

        q = q.view(*new_q_tensor_shape)
        k = k.view(*new_kv_tensor_shape)
        v = v.view(*new_kv_tensor_shape)
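        # Interleave per query group: heads_per_group query heads, then one k
        # head, then one v head, yielding the [q...q k v | q...q k v | ...]
        # layout that split_qkv above inverts.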
        qkv_weights_l = []
        for i in range(num_query_groups):
            qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
            qkv_weights_l.append(k[i : i + 1, :, :])
            qkv_weights_l.append(v[i : i + 1, :, :])
        qkv_weights = torch.cat(qkv_weights_l)
        assert qkv_weights.ndim == 3, qkv_weights.shape
        assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
        assert qkv_weights.shape[1] == head_size, qkv_weights.shape
        assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

        qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])
        return qkv_weights
    @staticmethod
    def merge_qkv_bias_concat(ctx: TransformCTX, qkv_bias: torch.Tensor):
        """
        Merge naively concatenated q, k, v biases into an interleave-concatenated qkv bias.

        Example: import HF qkv bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels

        qb, kb, vb = qkv_bias.split(
            [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
        )
        return TransformFns.merge_qkv_bias(ctx, qb, kb, vb)
    @staticmethod
    def merge_qkv_bias(ctx: TransformCTX, qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor):
        """
        Merge q, k, v biases into an interleave-concatenated qkv bias.

        Example: import HF {q|k|v}_proj bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels

        new_q_tensor_shape = (head_num, head_size)
        new_kv_tensor_shape = (num_query_groups, head_size)

        qb = qb.view(*new_q_tensor_shape)
        kb = kb.view(*new_kv_tensor_shape)
        vb = vb.view(*new_kv_tensor_shape)

        qkv_bias = torch.empty((0, head_size)).type_as(qb)
        for i in range(num_query_groups):
            qkv_bias = torch.cat((qkv_bias, qb[i * heads_per_group : (i + 1) * heads_per_group, :]))
            qkv_bias = torch.cat((qkv_bias, kb[i : i + 1, :]))
            qkv_bias = torch.cat((qkv_bias, vb[i : i + 1, :]))
        qkv_bias = qkv_bias.reshape(
            [
                head_size * (head_num + 2 * num_query_groups),
            ]
        )
        return qkv_bias
    @staticmethod
    def merge_fc1(gate: torch.Tensor, up: torch.Tensor):
        """
        Merge gate and up projections into a concatenated fc1.

        Example: import HF {gate|up}_proj to layer linear_fc1
        """
        return torch.cat((gate, up), dim=0)
    @staticmethod
    def split_fc1(linear_fc1: torch.Tensor):
        """
        Split a concatenated fc1 into gate and up projections.

        Example: export layer linear_fc1 to HF {gate|up}_proj
        """
        gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)
        return gate_proj, up_proj
    @staticmethod
    def duplicate2(param: torch.Tensor):
        """
        Duplicate the source parameter into two target parameters.

        Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A
        """
        return param, param
    @staticmethod
    def duplicate3(param: torch.Tensor):
        """
        Duplicate the source parameter into three target parameters.

        Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A
        """
        return param, param, param
    @staticmethod
    def prune_padding(ctx: TransformCTX, embedding: torch.Tensor):
        """
        Prune the embedding size down to the vocab size.

        Example: export embedding/output layer to HF with a non-padded vocab size
        """
        megatron_config = ctx.target.config
        return embedding[: megatron_config.vocab_size, :]
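

# ---------------------------------------------------------------------------
# A minimal end-to-end sketch of how these pieces compose (module and key
# names below are hypothetical placeholders, not taken from a real model):
#
#   mapping = {
#       "model.embed_tokens.weight": "embedding.word_embeddings.weight",
#   }
#
#   @state_transform(
#       source_key=(
#           "model.layers.*.self_attn.q_proj.weight",
#           "model.layers.*.self_attn.k_proj.weight",
#           "model.layers.*.self_attn.v_proj.weight",
#       ),
#       target_key="decoder.layers.*.self_attention.linear_qkv.weight",
#   )
#   def _import_qkv(ctx, q, k, v):
#       return TransformFns.merge_qkv(ctx, q, k, v)
#
#   target_model = apply_transforms(source_model, target_model, mapping, [_import_qkv])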