from typing import Callable, Optional, Tuple, Union
from dataclasses import dataclass
import functools
import torch
from torch import nn
import torch.nn.init as init
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
# from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
GradientCheckpointingLayer,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
# from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils import auto_docstring, can_return_tuple, logging
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ImportError:
print("Flash Attention is not installed. Please install it to use xQwenForCausalLM with Flash Attention.")
# from transformers.masking_utils import causal_mask_mapping
try:
from fla.layers.gated_deltaproduct import GatedDeltaProduct
fla_available = True
except ImportError:
fla_available = False
from fla.modules import ShortConvolution
from fla.modules.feature_map import HedgehogFeatureMap
from .configuration_xqwen import xQwenConfig
logger = logging.get_logger(__name__)
from xlstm.xlstm_large.model import (
mLSTMStateType,
soft_cap,
# mLSTMLayer,
mLSTMLayerConfig,
mLSTMBackendConfig,
mLSTMLayerStateType,
mLSTMBackend,
MultiHeadLayerNorm
)
class xLSTMCache:
"""
Cache / RNN State handler for xLSTM.
Args:
config: xLSTMConfig
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
"""
def __init__(
self, config, batch_size: int, dtype: torch.dtype = torch.bfloat16, device: Optional[str] = None
):
self.seqlen_offset = torch.tensor(0, dtype=torch.int64, device=device)
self.dtype = dtype
self.config = config
self.qk_head_dim = self.config.head_dim
self.v_head_dim = self.config.head_dim
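        # Per-layer recurrent state for the mLSTM backend, as a tuple of
        #   C: [batch, num_heads, qk_head_dim, v_head_dim]  (matrix memory / cell state)
        #   n: [batch, num_heads, qk_head_dim]              (normalizer state)
        #   m: [batch, num_heads, 1]                        (max-stabilizer state)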
self.rnn_state: mLSTMStateType = {
layer: (
torch.zeros(
[batch_size, config.num_heads, self.qk_head_dim, self.v_head_dim], dtype=dtype, device=device
),
torch.zeros([batch_size, config.num_heads, self.qk_head_dim], dtype=dtype, device=device),
torch.zeros([batch_size, config.num_heads, 1], dtype=dtype, device=device),
)
for layer in range(config.num_hidden_layers)
}
self.rnn_state_initial = True
def reset(self):
self.rnn_state = {
layer: (
torch.zeros_like(self.rnn_state[layer][0]),
torch.zeros_like(self.rnn_state[layer][1]),
torch.zeros_like(self.rnn_state[layer][2]),
)
for layer in self.rnn_state
}
self.rnn_state_initial = True
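# Model outputs extended with `cache_params`, which carries the recurrent xLSTM state across
# forward calls (analogous to `past_key_values` for attention layers).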
@dataclass
class xQwenModelOutputWithPast(BaseModelOutputWithPast):
cache_params: Optional[xLSTMCache] = None
@dataclass
class xQwenCausalLMOutput(CausalLMOutputWithPast):
cache_params: Optional[xLSTMCache] = None
@use_kernel_forward_from_hub("RMSNorm")
class xQwenRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
xQwenRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class xQwenMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
if self.config.mlp_dropout > 0.0:
self.dropout = nn.Dropout(config.mlp_dropout)
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
if self.config.mlp_dropout > 0.0:
down_proj = self.dropout(down_proj)
return down_proj
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class xQwenAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: xQwenConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.q_norm = xQwenRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
self.k_norm = xQwenRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window, # diff with Llama
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class mLSTMLayer(nn.Module):
def __init__(self, config: mLSTMLayerConfig):
super().__init__()
self.config = config
# self.head_dim = config.embedding_dim // config.num_heads
self.head_dim = self.config.head_dim
self.num_key_value_groups = config.num_heads // config.num_key_value_heads
self.v_dim = int(config.embedding_dim * config.v_dim_factor)
self.qk_dim = int(config.embedding_dim * config.qk_dim_factor)
if self.config.weight_mode == "single":
self.q = nn.Linear(
in_features=self.config.hidden_size,
out_features=self.config.num_heads * self.head_dim,
bias=self.config.use_bias,
)
self.k = nn.Linear(
in_features=self.config.hidden_size,
out_features=config.num_key_value_heads * self.head_dim,
bias=self.config.use_bias,
)
self.v = nn.Linear(
in_features=self.config.hidden_size,
out_features=config.num_key_value_heads * self.head_dim,
bias=self.config.use_bias,
)
self.ogate_preact = nn.Linear(
in_features=self.config.hidden_size,
out_features=self.head_dim * self.config.num_heads,
# out_features=self.config.hidden_size,
bias=self.config.use_bias,
)
self.igate_preact = nn.Linear(
# in_features=self.head_dim * self.config.num_heads,
in_features=self.config.hidden_size,
out_features=self.config.num_heads,
bias=True,
)
self.fgate_preact = nn.Linear(
# in_features=self.head_dim * self.config.num_heads,
in_features=self.config.hidden_size,
out_features=self.config.num_heads,
bias=True,
)
elif self.config.weight_mode == "fused":
self.qkv_opreact = nn.Linear(
in_features=self.config.hidden_size,
out_features=2 * self.qk_dim + 2 * self.v_dim,
bias=self.config.use_bias,
)
self.ifgate_preact = nn.Linear(
in_features=self.config.hidden_size,
out_features=2 * self.config.num_heads,
bias=True,
)
self.ogate_act_fn = nn.Sigmoid()
self.mlstm_backend = mLSTMBackend(config=self.config.mlstm_backend_config())
self.multihead_norm = MultiHeadLayerNorm(
num_heads=self.config.num_heads,
head_dim=self.head_dim,
eps=self.config.norm_eps,
use_weight=True,
use_bias=self.config.use_bias,
force_float32_reductions=self.config.norm_reduction_force_float32,
)
self.out_proj = nn.Linear(
in_features=self.head_dim * self.config.num_heads,
out_features=self.config.hidden_size,
bias=self.config.use_bias,
)
if self.config.use_sliding_window:
self.block_mask = None
self.swa_attention = None
if self.config.swa_modulation == "dynamic":
self.swa_alpha = nn.Parameter(
torch.tensor(
0.5, dtype=torch.float32, requires_grad=True
)
)
if self.config.use_short_conv:
self.q_conv1d = ShortConvolution(
hidden_size=self.config.hidden_size,
kernel_size=self.config.conv_size,
bias=False,
activation='silu'
)
self.k_conv1d = ShortConvolution(
hidden_size=self.config.hidden_size,
kernel_size=self.config.conv_size,
bias=False,
activation='silu'
)
self.v_conv1d = ShortConvolution(
hidden_size=self.config.hidden_size,
kernel_size=self.config.conv_size,
bias=False,
activation='silu'
)
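        # The Hedgehog feature map (from the `fla` package) learns a softmax-mimicking kernel
        # for q and k; its output is twice the input head dimension, which is why
        # xQwenRotaryEmbedding below doubles `head_dim` when hedgehog is enabled.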
if self.config.use_hedgehog:
self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_dim)
self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_dim)
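    # Lazily build and cache a flex-attention block mask for the configured sliding window
    # (optionally with a "memory" prefix of always-attended tokens); the mask is rebuilt
    # whenever the query length changes.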
def set_swa_block_mask(self, q_len, mem_window=4):
block_mask = self.get_swa_block(with_memory=self.config.swa_with_memory, mem_window=mem_window)
self.block_mask = create_block_mask(block_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len)
self.swa_attention = functools.partial(
flex_attention, block_mask=self.block_mask
)
self.q_len = q_len
def get_swa_block(self, with_memory=False, mem_window=None):
if with_memory:
assert mem_window is not None, "mem_window must be specified for sliding window with memory"
def swa_with_memory(b, h, q_idx, kv_idx):
""" Sliding window causal attention with memory.
Add mask so model always attents to first m tokens in the sequence.
"""
causal_mask = q_idx >= kv_idx
window_mask = (q_idx - kv_idx) <= self.config.sliding_window
memory_mask = kv_idx < mem_window
return (causal_mask & window_mask) | memory_mask
return swa_with_memory
def sliding_window_causal(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
window_mask = (q_idx - kv_idx) <= self.config.sliding_window
return causal_mask & window_mask
return sliding_window_causal
def forward(
self, x: torch.Tensor,
state: mLSTMLayerStateType | None = None,
output_attentions: bool = False,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
B, S, _ = x.shape
if self.config.weight_mode == "single":
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.config.use_short_conv:
q, _ = self.q_conv1d(q)
k, _ = self.k_conv1d(k)
v, _ = self.v_conv1d(v)
o_preact = self.ogate_preact(x)
i_preact = soft_cap(
self.igate_preact(x), cap_value=self.config.gate_soft_cap
)
f_preact = soft_cap(
self.fgate_preact(x), cap_value=self.config.gate_soft_cap
)
elif self.config.weight_mode == "fused":
qkv_opreact = self.qkv_opreact(x)
q, k, v, o_preact = torch.tensor_split(
qkv_opreact,
(
self.qk_dim,
2 * self.qk_dim,
2 * self.qk_dim + self.v_dim,
),
dim=-1,
)
if_preact = soft_cap(
self.ifgate_preact(x), cap_value=self.config.gate_soft_cap
)
i_preact, f_preact = torch.tensor_split(
if_preact, (self.config.num_heads,), dim=-1
)
q = q.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
k = k.reshape(B, S, self.config.num_key_value_heads, -1).transpose(1, 2)
v = v.reshape(B, S, self.config.num_key_value_heads, -1).transpose(1, 2)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.use_hedgehog:
q = self.feature_map_q(q)
k = self.feature_map_k(k)
if self.config.use_sliding_window:
sq, sk, sv = q, k, v
# assert position_ids is not None, "position_ids must be provided for sliding window attention"
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0)
cos, sin = position_embeddings
            sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin)
i_preact = i_preact.transpose(1, 2)
f_preact = f_preact.transpose(1, 2)
if state is None:
c_initial, n_initial, m_initial = None, None, None
else:
c_initial, n_initial, m_initial = state
h, state = self.mlstm_backend(
q=q,
k=k,
v=v,
i=i_preact,
f=f_preact,
c_initial=c_initial,
n_initial=n_initial,
m_initial=m_initial,
)
h = h.transpose(1, 2)
h_norm = self.multihead_norm(h)
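        # Hybrid token mixing: when sliding-window attention is enabled, the flex-attention
        # branch output `y` is blended with the normalized mLSTM output `h_norm`, either with
        # fixed 0.5/0.5 weights ("static") or via a learnable scalar `swa_alpha` ("dynamic").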
if self.config.use_sliding_window:
if sq.dtype == torch.float32:
sq, sk, sv = sq.to(torch.float16), sk.to(torch.float16), sv.to(torch.float16)
q_len = sq.size(-2)
if self.block_mask is None or self.swa_attention is None:
self.set_swa_block_mask(q_len, mem_window=self.config.sliding_window_memory)
elif self.q_len != q_len:
self.set_swa_block_mask(q_len, mem_window=self.config.sliding_window_memory)
y = self.swa_attention(sq, sk, sv).transpose(1, 2)
            # y = _flash_attention_forward(  # Reshape to the expected shape for Flash Attention
# sq.transpose(1, 2),
# sk.transpose(1, 2),
# sv.transpose(1, 2),
# attention_mask,
# q_len,
# position_ids=position_ids,
# dropout=0.0,
# sliding_window=self.config.sliding_window,
# use_top_left_mask=False,
# is_causal=True,
# target_dtype=torch.float32,
# )
            # TODO: Independent normalization for the sliding-window branch?
y = self.multihead_norm(y)
if self.config.swa_modulation == "static":
out = 0.5 * y + 0.5 * h_norm
elif self.config.swa_modulation == "dynamic":
if self.config.swa_modulation_bounded:
out = y + torch.tanh(self.swa_alpha) * h_norm
else:
out = y + self.swa_alpha * h_norm
else:
out = y
# raise ValueError("Unknown sliding window modulation type: {}".format(self.config.swa_modulation))
else:
out = h_norm
out = out.reshape(B, S, -1)
out = self.ogate_act_fn(o_preact) * out
y = self.out_proj(out)
return y, state
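# Reference mapping of layer-type names to token-mixer classes; xQwenDecoderLayer below
# branches on these names explicitly (and additionally supports "gdp_attention").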
token_mixer_type = {
"qwen_attention": xQwenAttention,
"xlstm_attention": mLSTMLayer,
}
def build_mlstm_config(config):
    # NOTE: the full xQwenConfig is passed through unchanged because mLSTMLayer reads extra
    # fields (head_dim, hidden_size, num_key_value_heads, use_short_conv, use_sliding_window,
    # ...) directly from it; the explicit mLSTMLayerConfig construction below is unreachable
    # and kept only for reference.
    return config
return mLSTMLayerConfig(
embedding_dim=config.embedding_dim,
num_heads=config.num_heads,
use_bias=config.use_bias,
norm_eps=config.rms_norm_eps,
norm_reduction_force_float32=config.norm_reduction_force_float32,
qk_dim_factor=1,
v_dim_factor=1,
num_key_value_heads=config.num_key_value_heads,
gate_soft_cap=config.gate_soft_cap,
weight_mode="single",
mlstm_backend=mLSTMBackendConfig(
chunkwise_kernel=config.chunkwise_kernel,
sequence_kernel=config.sequence_kernel,
step_kernel=config.step_kernel,
mode=config.mode,
chunk_size=config.chunk_size,
return_last_states=config.return_last_states,
autocast_kernel_dtype=config.autocast_kernel_dtype,
eps=config.eps,
inference_state_dtype=config.inference_state_dtype,
),
)
def build_gdp(config):
assert fla_available, "GatedDeltaProduct requires fla package to be installed."
# config.hidden_size = 512
return GatedDeltaProduct(
hidden_size=config.hidden_size,
expand_v=1,
head_dim=config.hidden_size // config.num_attention_heads,
num_heads=config.num_attention_heads,
use_output_gate=False,
use_short_conv=True,
use_forget_gate=True,
num_householder=2
)
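# Decoder block: per-layer choice of token mixer (softmax attention, mLSTM, or GatedDeltaProduct)
# driven by `config.layer_types`, followed by the standard gated MLP with pre-norm residuals.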
class xQwenDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: xQwenConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.attention_type = config.layer_types[layer_idx]
if self.attention_type == "qwen_attention":
self.self_attn = xQwenAttention(config=config, layer_idx=layer_idx)
elif self.attention_type == "xlstm_attention":
self.self_attn = mLSTMLayer(build_mlstm_config(config))
elif self.attention_type == "gdp_attention":
self.self_attn = build_gdp(config)
else:
raise ValueError("Unsupported attention type: {}".format(self.attention_type))
self.mlp = xQwenMLP(config)
self.input_layernorm = xQwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = xQwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
state: mLSTMStateType | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if output_attentions:
return None, self.self_attn(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
position_ids=position_ids,
position_embeddings=position_embeddings,
)
# Self Attention
hidden_states, *state = self.self_attn(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
position_ids=position_ids,
position_embeddings=position_embeddings,
state=state,
)
if len(state) == 1:
state = state[0] # unpack the single state tuple
else:
state = None
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        # When `output_attentions` is requested, the attention weights are returned via the
        # early-return path above, so only the hidden states are packed here.
        return outputs, state
@auto_docstring
class xQwenPreTrainedModel(PreTrainedModel):
config_class = xQwenConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["xQwenDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, xQwenRMSNorm):
module.weight.data.fill_(1.0)
class xQwenRotaryEmbedding(nn.Module):
def __init__(self, config: xQwenConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
self.config = config
# Hedgehog feature map doubles the hidden size for q and k
if self.config.use_sliding_window and self.config.use_hedgehog:
# self.config.hidden_size *= 2
self.config.head_dim *= 2
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@auto_docstring
class xQwenModel(xQwenPreTrainedModel):
config_class = xQwenConfig
def __init__(self, config: xQwenConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[xQwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = xQwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = xQwenRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_params: Optional[xLSTMCache] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# It may already have been prepared by e.g. `generate`
# if not isinstance(causal_mask_mapping := attention_mask, dict):
# # Prepare mask arguments
# mask_kwargs = {
# "config": self.config,
# "input_embeds": inputs_embeds,
# "attention_mask": attention_mask,
# "cache_position": cache_position,
# "past_key_values": past_key_values,
# }
# # Create the masks
# causal_mask_mapping = {
# "full_attention": create_causal_mask(**mask_kwargs),
# }
# # The sliding window alternating layers are not always activated depending on the config
# if self.has_sliding_layers:
# causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
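        # NOTE: caching is force-disabled below; when enabled, the recurrent xLSTM state would be
        # carried in `cache_params` rather than in `past_key_values`.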
use_cache = False
if use_cache:
if cache_params is None:
cache_params = xLSTMCache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
else:
cache_params = None
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
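        # Chunked inference path: for long inputs at inference time, process the sequence in
        # chunks of `max_inference_chunksize`, carrying the per-layer recurrent state in
        # `cache_params` between chunks.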
inference_with_cache = (
not self.training
and self.config.max_inference_chunksize < hidden_states.shape[1]
and not output_hidden_states
)
if inference_with_cache:
all_hidden_states = None
offset = 0
with torch.no_grad():
if cache_params is None:
cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0])
final_state = torch.zeros_like(hidden_states)
while offset < hidden_states.shape[1]:
hidden_states_chunk = hidden_states[
:, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
]
                    for i, layer in enumerate(self.layers[: self.config.num_hidden_layers]):
                        layer_outputs, rnn_state = layer(
                            hidden_states_chunk,
                            state=cache_params.rnn_state[i],
                        )
                        hidden_states_chunk = layer_outputs[0]
                        for state_idx in range(len(cache_params.rnn_state[i])):
                            local_rnn_state = rnn_state[state_idx]
                            cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
                        cache_params.rnn_state_initial = False
                    final_state[
                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
                    ] = hidden_states_chunk
offset += self.config.max_inference_chunksize
hidden_states = final_state
else:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs, rnn_state = decoder_layer(
hidden_states,
# attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
state=cache_params.rnn_state[i] if cache_params is not None else None,
**flash_attn_kwargs,
)
if cache_params:
for state_idx in range(len(cache_params.rnn_state[i])):
local_rnn_state = rnn_state[state_idx]
cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
cache_params.rnn_state_initial = False
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if use_cache:
cache_params.seqlen_offset += inputs_embeds.shape[1]
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return xQwenModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
            cache_params=cache_params if use_cache else None,
)
# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring
class xQwenForCausalLM(xQwenPreTrainedModel, GenerationMixin):
config_class = xQwenConfig
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = xQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from transformers import AutoTokenizer, xQwenForCausalLM
>>> model = xQwenForCausalLM.from_pretrained("Qwen/xQwen-8B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/xQwen-8B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
# return CausalLMOutputWithPast(
return xQwenCausalLMOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cache_params=getattr(outputs, "cache_params", None),
        )
def copy_from_teacher(self, teacher, copy_qkv: bool = True):
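        """Initialize this hybrid model from a pretrained transformer teacher.

        Copies the embeddings, final norm, LM head, and per-layer MLPs and layer norms, and
        (optionally) the teacher's q/k/v/o projections into the mLSTM layers' q/k/v/out_proj,
        then (re-)initializes the input/forget gate parameters.
        """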
assert len(self.model.layers) == len(teacher.model.layers)
self.model.embed_tokens.weight.data.copy_(teacher.get_input_embeddings().weight.data)
self.model.norm.weight.data.copy_(teacher.model.norm.weight.data)
self.lm_head.weight.data.copy_(teacher.get_output_embeddings().weight.data)
for self_layer, teacher_layer in zip(self.model.layers, teacher.model.layers):
# self_layer.token_mixer.load_state_dict(teacher_layer.token_mixer.state_dict())
self_layer.mlp.load_state_dict(teacher_layer.mlp.state_dict())
self_layer.input_layernorm.load_state_dict(teacher_layer.input_layernorm.state_dict())
self_layer.post_attention_layernorm.load_state_dict(teacher_layer.post_attention_layernorm.state_dict())
if copy_qkv:
self_layer.self_attn.q.load_state_dict(teacher_layer.self_attn.q_proj.state_dict())
self_layer.self_attn.out_proj.load_state_dict(teacher_layer.self_attn.o_proj.state_dict())
v_proj_unrolled = teacher_layer.self_attn.v_proj
k_proj_unrolled = teacher_layer.self_attn.k_proj
self_layer.self_attn.v.load_state_dict(v_proj_unrolled.state_dict())
self_layer.self_attn.k.load_state_dict(k_proj_unrolled.state_dict())
            self_layer.self_attn.igate_preact.bias.data.fill_(torch.log(torch.tensor(2.0)))
            # NOTE: the original filled `igate_preact.bias` twice (the second value overwriting
            # the first); the second init presumably targets the forget gate.
            self_layer.self_attn.fgate_preact.bias.data.fill_(-torch.log(torch.tensor(2.0)))
# Init weight with small values
init.xavier_uniform_(self_layer.self_attn.igate_preact.weight.data)
self_layer.self_attn.igate_preact.weight.data *= 0.1
init.xavier_uniform_(self_layer.self_attn.fgate_preact.weight.data)
self_layer.self_attn.fgate_preact.weight.data *= 0.1
def prepare_inputs_for_generation(
self,
input_ids,
inputs_embeds=None,
use_cache=None,
cache_params = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
# Overwritten -- uses `cache_params` as opposed to `past_key_values`
        # Unlike Mamba, this currently does not support passing additional convolution states
        # via `inputs_embeds`.
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
raise ValueError(
"`cache_position` should not be None as it should have been initialized in "
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
# If the first cache position is non-zero, we assume we are in generation mode.
# Thus, the cache_params state is assumed to be the state before the last token
# (lastly generated token), and all previous tokens are already ingested.
# This should as well support generation from scratch with the [BOS] token inserted first.
# if is_torchdynamo_compiling() or cache_position[0] > 0:
if cache_params is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
attention_mask = None
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
}
)
return model_inputs
class xQwenForSequenceClassification(xQwenPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
self.model = xQwenModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs
):
transformer_outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
**kwargs,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
class xQwenForTokenClassification(xQwenPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
# Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
self.model = xQwenModel(config)
if getattr(config, "classifier_dropout", None) is not None:
classifier_dropout = config.classifier_dropout
elif getattr(config, "hidden_dropout", None) is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs,
) -> TokenClassifierOutput:
outputs: BaseModelOutputWithPast = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
**kwargs,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.score(sequence_output)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.config)
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register(xQwenConfig.model_type, xQwenConfig)
AutoModel.register(xQwenConfig, xQwenModel)
AutoModelForCausalLM.register(xQwenConfig, xQwenForCausalLM)
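# Minimal usage sketch (the checkpoint path below is hypothetical):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("path/to/xqwen-checkpoint", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("path/to/xqwen-checkpoint", trust_remote_code=True)
#   inputs = tokenizer("Hello", return_tensors="pt")
#   outputs = model.generate(**inputs, max_new_tokens=20)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))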
__all__ = [
"xQwenForCausalLM",
"xQwenModel",
"xQwenPreTrainedModel",
"xQwenForSequenceClassification",
"xQwenForTokenClassification",
]