PatrickHaller committed on
Commit 1b2dd8a · verified · 1 Parent(s): 8680fab

Upload modeling_xqwen.py with huggingface_hub

Files changed (1):
  1. modeling_xqwen.py +1296 -0
modeling_xqwen.py ADDED
from typing import Callable, Optional, Tuple, Union
from dataclasses import dataclass
import functools

import torch
from torch import nn
import torch.nn.init as init
from torch.nn import functional as F

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
# from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
# from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils import auto_docstring, can_return_tuple, logging
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

try:
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ImportError:
    print("Flash Attention is not installed. Please install it to use xQwenForCausalLM with Flash Attention.")

# from transformers.masking_utils import causal_mask_mapping

try:
    from fla.layers.gated_deltaproduct import GatedDeltaProduct
    # ShortConvolution and HedgehogFeatureMap come from the same package, so
    # they are only imported when `fla` is actually available.
    from fla.modules import ShortConvolution
    from fla.modules.feature_map import HedgehogFeatureMap
    fla_available = True
except ImportError:
    fla_available = False

from .configuration_xqwen import xQwenConfig

logger = logging.get_logger(__name__)

from xlstm.xlstm_large.model import (
    mLSTMStateType,
    soft_cap,
    # mLSTMLayer,
    mLSTMLayerConfig,
    mLSTMBackendConfig,
    mLSTMLayerStateType,
    mLSTMBackend,
    MultiHeadLayerNorm,
)


class xLSTMCache:
    """
    Cache / RNN state handler for xLSTM.

    Args:
        config: xLSTMConfig
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype
    """

    def __init__(
        self, config, batch_size: int, dtype: torch.dtype = torch.bfloat16, device: Optional[str] = None
    ):
        self.seqlen_offset = torch.tensor(0, dtype=torch.int64, device=device)
        self.dtype = dtype
        self.config = config
        self.qk_head_dim = self.config.head_dim
        self.v_head_dim = self.config.head_dim

        # Per layer: (cell state C, normalizer state n, stabilizer state m),
        # matching the c/n/m initial states consumed by the mLSTM backend below.
        self.rnn_state: mLSTMStateType = {
            layer: (
                torch.zeros(
                    [batch_size, config.num_heads, self.qk_head_dim, self.v_head_dim], dtype=dtype, device=device
                ),
                torch.zeros([batch_size, config.num_heads, self.qk_head_dim], dtype=dtype, device=device),
                torch.zeros([batch_size, config.num_heads, 1], dtype=dtype, device=device),
            )
            for layer in range(config.num_hidden_layers)
        }
        self.rnn_state_initial = True

    def reset(self):
        self.rnn_state = {
            layer: (
                torch.zeros_like(self.rnn_state[layer][0]),
                torch.zeros_like(self.rnn_state[layer][1]),
                torch.zeros_like(self.rnn_state[layer][2]),
            )
            for layer in self.rnn_state
        }
        self.rnn_state_initial = True

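# A minimal usage sketch for the cache above (illustrative only; `cfg` stands
# in for a loaded xQwenConfig, and the shapes follow the zeros-initialization
# in `xLSTMCache.__init__`):
#
#     # cache = xLSTMCache(cfg, batch_size=2, dtype=torch.bfloat16, device="cuda")
#     # c, n, m = cache.rnn_state[0]  # states of layer 0
#     # assert c.shape == (2, cfg.num_heads, cfg.head_dim, cfg.head_dim)
#     # cache.reset()                 # zero all layer states in place
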
@dataclass
class xQwenModelOutputWithPast(BaseModelOutputWithPast):
    cache_params: Optional[xLSTMCache] = None


@dataclass
class xQwenCausalLMOutput(CausalLMOutputWithPast):
    cache_params: Optional[xLSTMCache] = None


@use_kernel_forward_from_hub("RMSNorm")
class xQwenRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        xQwenRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

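# For reference, the forward pass above computes (in float32) the RMS
# normalization
#
#     y = w * x / sqrt(mean(x**2, dim=-1) + eps)
#
# i.e. no mean subtraction, which is what distinguishes it from standard
# LayerNorm.
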
class xQwenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        if self.config.mlp_dropout > 0.0:
            self.dropout = nn.Dropout(config.mlp_dropout)

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        if self.config.mlp_dropout > 0.0:
            down_proj = self.dropout(down_proj)
        return down_proj

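# The forward pass above is the standard gated ("SwiGLU"-style) MLP:
#
#     down_proj(act(gate_proj(x)) * up_proj(x))
#
# with `act` resolved via ACT2FN[config.hidden_act] (typically SiLU).
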
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

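# A shape sketch for the RoPE helpers above (illustrative, not part of the
# model): with q, k of shape [batch, heads, seq, head_dim] and cos/sin of
# shape [batch, seq, head_dim], `unsqueeze_dim=1` broadcasts over the heads:
#
#     # cos, sin = rotary_emb(hidden_states, position_ids)
#     # q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)  # same shapes as q, k
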
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

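# Grouped-query attention in one line: with 8 query heads and 2 KV heads,
# `repeat_kv(k, n_rep=4)` expands [B, 2, S, D] -> [B, 8, S, D] so each KV
# head serves 4 query heads (a sketch mirroring the expand/reshape above):
#
#     # k = torch.randn(1, 2, 16, 64)
#     # assert repeat_kv(k, 4).shape == (1, 8, 16, 64)
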
class xQwenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: xQwenConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = xQwenRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = xQwenRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

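# Note the per-head QK normalization above: q_norm/k_norm operate on
# head_dim-sized vectors after the reshape, so (a shape sketch) the q_proj
# output [B, S, H*D] is viewed as [B, S, H, D], normalized over D, and then
# transposed to [B, H, S, D] before RoPE is applied.
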
class mLSTMLayer(nn.Module):
    def __init__(self, config: mLSTMLayerConfig):
        super().__init__()
        self.config = config

        # self.head_dim = config.embedding_dim // config.num_heads
        self.head_dim = self.config.head_dim
        self.num_key_value_groups = config.num_heads // config.num_key_value_heads

        self.v_dim = int(config.embedding_dim * config.v_dim_factor)
        self.qk_dim = int(config.embedding_dim * config.qk_dim_factor)
        if self.config.weight_mode == "single":
            self.q = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.config.num_heads * self.head_dim,
                bias=self.config.use_bias,
            )
            self.k = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=config.num_key_value_heads * self.head_dim,
                bias=self.config.use_bias,
            )
            self.v = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=config.num_key_value_heads * self.head_dim,
                bias=self.config.use_bias,
            )

            self.ogate_preact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.head_dim * self.config.num_heads,
                # out_features=self.config.hidden_size,
                bias=self.config.use_bias,
            )
            self.igate_preact = nn.Linear(
                # in_features=self.head_dim * self.config.num_heads,
                in_features=self.config.hidden_size,
                out_features=self.config.num_heads,
                bias=True,
            )
            self.fgate_preact = nn.Linear(
                # in_features=self.head_dim * self.config.num_heads,
                in_features=self.config.hidden_size,
                out_features=self.config.num_heads,
                bias=True,
            )
        elif self.config.weight_mode == "fused":
            self.qkv_opreact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=2 * self.qk_dim + 2 * self.v_dim,
                bias=self.config.use_bias,
            )
            self.ifgate_preact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=2 * self.config.num_heads,
                bias=True,
            )

        self.ogate_act_fn = nn.Sigmoid()
        self.mlstm_backend = mLSTMBackend(config=self.config.mlstm_backend_config())

        self.multihead_norm = MultiHeadLayerNorm(
            num_heads=self.config.num_heads,
            head_dim=self.head_dim,
            eps=self.config.norm_eps,
            use_weight=True,
            use_bias=self.config.use_bias,
            force_float32_reductions=self.config.norm_reduction_force_float32,
        )
        self.out_proj = nn.Linear(
            in_features=self.head_dim * self.config.num_heads,
            out_features=self.config.hidden_size,
            bias=self.config.use_bias,
        )

        if self.config.use_sliding_window:
            self.block_mask = None
            self.swa_attention = None
            if self.config.swa_modulation == "dynamic":
                self.swa_alpha = nn.Parameter(
                    torch.tensor(0.5, dtype=torch.float32, requires_grad=True)
                )

        if self.config.use_short_conv:
            self.q_conv1d = ShortConvolution(
                hidden_size=self.config.hidden_size,
                kernel_size=self.config.conv_size,
                bias=False,
                activation='silu',
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.config.hidden_size,
                kernel_size=self.config.conv_size,
                bias=False,
                activation='silu',
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.config.hidden_size,
                kernel_size=self.config.conv_size,
                bias=False,
                activation='silu',
            )

        if self.config.use_hedgehog:
            self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_dim)
            self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_dim)

    def set_swa_block_mask(self, q_len, mem_window=4):
        block_mask = self.get_swa_block(with_memory=self.config.swa_with_memory, mem_window=mem_window)
        self.block_mask = create_block_mask(block_mask, B=None, H=None, Q_LEN=q_len, KV_LEN=q_len)
        self.swa_attention = functools.partial(flex_attention, block_mask=self.block_mask)
        self.q_len = q_len

    def get_swa_block(self, with_memory=False, mem_window=None):
        if with_memory:
            assert mem_window is not None, "mem_window must be specified for sliding window with memory"

            def swa_with_memory(b, h, q_idx, kv_idx):
                """Sliding window causal attention with memory.

                Adds a mask so the model always attends to the first
                `mem_window` tokens of the sequence.
                """
                causal_mask = q_idx >= kv_idx
                window_mask = (q_idx - kv_idx) <= self.config.sliding_window
                memory_mask = kv_idx < mem_window
                return (causal_mask & window_mask) | memory_mask

            return swa_with_memory

        def sliding_window_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            window_mask = (q_idx - kv_idx) <= self.config.sliding_window
            return causal_mask & window_mask

        return sliding_window_causal

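    # To see what the masks above admit: with sliding_window=2 and
    # mem_window=1, query position 5 may attend to kv positions {0} (memory)
    # plus {3, 4, 5} (causal window), while plain sliding_window_causal admits
    # only {3, 4, 5}. A sketch for inspecting the mask offline (assumes the
    # flex_attention utilities imported at the top of this file):
    #
    #     # mask_fn = layer.get_swa_block(with_memory=True, mem_window=1)
    #     # bm = create_block_mask(mask_fn, B=None, H=None, Q_LEN=8, KV_LEN=8)
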
    def forward(
        self,
        x: torch.Tensor,
        state: mLSTMLayerStateType | None = None,
        output_attentions: bool = False,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, mLSTMLayerStateType | None]:
        assert x.ndim == 3, f"Input must have shape [B, S, D], got {x.shape}"
        B, S, _ = x.shape
        if self.config.weight_mode == "single":
            q = self.q(x)
            k = self.k(x)
            v = self.v(x)

            if self.config.use_short_conv:
                q, _ = self.q_conv1d(q)
                k, _ = self.k_conv1d(k)
                v, _ = self.v_conv1d(v)

            o_preact = self.ogate_preact(x)
            i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
            f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)

        elif self.config.weight_mode == "fused":
            qkv_opreact = self.qkv_opreact(x)
            q, k, v, o_preact = torch.tensor_split(
                qkv_opreact,
                (
                    self.qk_dim,
                    2 * self.qk_dim,
                    2 * self.qk_dim + self.v_dim,
                ),
                dim=-1,
            )

            if_preact = soft_cap(self.ifgate_preact(x), cap_value=self.config.gate_soft_cap)
            i_preact, f_preact = torch.tensor_split(if_preact, (self.config.num_heads,), dim=-1)

        q = q.reshape(B, S, self.config.num_heads, -1).transpose(1, 2)
        k = k.reshape(B, S, self.config.num_key_value_heads, -1).transpose(1, 2)
        v = v.reshape(B, S, self.config.num_key_value_heads, -1).transpose(1, 2)

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        if self.config.use_hedgehog:
            q = self.feature_map_q(q)
            k = self.feature_map_k(k)

        if self.config.use_sliding_window:
            sq, sk, sv = q, k, v
            # assert position_ids is not None, "position_ids must be provided for sliding window attention"
            if position_ids is None:
                position_ids = torch.arange(S, device=x.device).unsqueeze(0)

            cos, sin = position_embeddings
            sq, sk = apply_rotary_pos_emb(sq, sk, cos, sin)

        i_preact = i_preact.transpose(1, 2)
        f_preact = f_preact.transpose(1, 2)
        if state is None:
            c_initial, n_initial, m_initial = None, None, None
        else:
            c_initial, n_initial, m_initial = state

        h, state = self.mlstm_backend(
            q=q,
            k=k,
            v=v,
            i=i_preact,
            f=f_preact,
            c_initial=c_initial,
            n_initial=n_initial,
            m_initial=m_initial,
        )

        h = h.transpose(1, 2)
        h_norm = self.multihead_norm(h)

        if self.config.use_sliding_window:
            if sq.dtype == torch.float32:
                sq, sk, sv = sq.to(torch.float16), sk.to(torch.float16), sv.to(torch.float16)

            q_len = sq.size(-2)

            if self.block_mask is None or self.swa_attention is None:
                self.set_swa_block_mask(q_len, mem_window=self.config.sliding_window_memory)
            elif self.q_len != q_len:
                self.set_swa_block_mask(q_len, mem_window=self.config.sliding_window_memory)

            y = self.swa_attention(sq, sk, sv).transpose(1, 2)

            # y = _flash_attention_forward(  # Reshape to the expected shape for Flash Attention
            #     sq.transpose(1, 2),
            #     sk.transpose(1, 2),
            #     sv.transpose(1, 2),
            #     attention_mask,
            #     q_len,
            #     position_ids=position_ids,
            #     dropout=0.0,
            #     sliding_window=self.config.sliding_window,
            #     use_top_left_mask=False,
            #     is_causal=True,
            #     target_dtype=torch.float32,
            # )

            # TODO: Independent normalization for sliding window?
            y = self.multihead_norm(y)
            if self.config.swa_modulation == "static":
                out = 0.5 * y + 0.5 * h_norm
            elif self.config.swa_modulation == "dynamic":
                if self.config.swa_modulation_bounded:
                    out = y + torch.tanh(self.swa_alpha) * h_norm
                else:
                    out = y + self.swa_alpha * h_norm
            else:
                out = y
                # raise ValueError("Unknown sliding window modulation type: {}".format(self.config.swa_modulation))
        else:
            out = h_norm

        out = out.reshape(B, S, -1)
        out = self.ogate_act_fn(o_preact) * out
        y = self.out_proj(out)
        return y, state

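# The branch combination above, written out: for "static" modulation the two
# paths are averaged, out = 0.5 * y_swa + 0.5 * h_mlstm; for "dynamic" the
# mLSTM path is scaled by a learned alpha, optionally squashed through tanh
# to keep the scale in (-1, 1): out = y_swa + tanh(alpha) * h_mlstm.
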
token_mixer_type = {
    "qwen_attention": xQwenAttention,
    "xlstm_attention": mLSTMLayer,
}


def build_mlstm_config(config):
    # NOTE: the model config is currently passed through unchanged; the
    # mLSTMLayerConfig construction below is unreachable and kept for reference.
    return config
    return mLSTMLayerConfig(
        embedding_dim=config.embedding_dim,
        num_heads=config.num_heads,
        use_bias=config.use_bias,
        norm_eps=config.rms_norm_eps,
        norm_reduction_force_float32=config.norm_reduction_force_float32,
        qk_dim_factor=1,
        v_dim_factor=1,
        num_key_value_heads=config.num_key_value_heads,
        gate_soft_cap=config.gate_soft_cap,
        weight_mode="single",
        mlstm_backend=mLSTMBackendConfig(
            chunkwise_kernel=config.chunkwise_kernel,
            sequence_kernel=config.sequence_kernel,
            step_kernel=config.step_kernel,
            mode=config.mode,
            chunk_size=config.chunk_size,
            return_last_states=config.return_last_states,
            autocast_kernel_dtype=config.autocast_kernel_dtype,
            eps=config.eps,
            inference_state_dtype=config.inference_state_dtype,
        ),
    )


def build_gdp(config):
    assert fla_available, "GatedDeltaProduct requires the `fla` package to be installed."
    # config.hidden_size = 512
    return GatedDeltaProduct(
        hidden_size=config.hidden_size,
        expand_v=1,
        head_dim=config.hidden_size // config.num_attention_heads,
        num_heads=config.num_attention_heads,
        use_output_gate=False,
        use_short_conv=True,
        use_forget_gate=True,
        num_householder=2,
    )

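# The per-layer token mixer is selected through `config.layer_types`, consumed
# in the decoder layer below. A hypothetical hybrid stack alternating softmax
# attention and mLSTM layers might be configured like this (illustrative only;
# the field name follows this file's usage of xQwenConfig):
#
#     # cfg.layer_types = ["qwen_attention", "xlstm_attention"] * (cfg.num_hidden_layers // 2)
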
class xQwenDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: xQwenConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attention_type = config.layer_types[layer_idx]
        if self.attention_type == "qwen_attention":
            self.self_attn = xQwenAttention(config=config, layer_idx=layer_idx)
        elif self.attention_type == "xlstm_attention":
            self.self_attn = mLSTMLayer(build_mlstm_config(config))
        elif self.attention_type == "gdp_attention":
            self.self_attn = build_gdp(config)
        else:
            raise ValueError("Unsupported attention type: {}".format(self.attention_type))

        self.mlp = xQwenMLP(config)
        self.input_layernorm = xQwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = xQwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        state: mLSTMStateType | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if output_attentions:
            return None, self.self_attn(
                hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                position_ids=position_ids,
                position_embeddings=position_embeddings,
            )

        # Self Attention
        hidden_states, *state = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            state=state,
        )

        if len(state) == 1:
            state = state[0]  # unpack the single state tuple
        else:
            state = None

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        # `output_attentions` is fully handled by the early return above, so no
        # attention weights need to be appended here.
        return outputs, state

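# The block structure above is the usual pre-norm residual pattern, with the
# token mixer chosen per layer:
#
#     h = x + mixer(input_layernorm(x))
#     y = h + mlp(post_attention_layernorm(h))
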

@auto_docstring
class xQwenPreTrainedModel(PreTrainedModel):
    config_class = xQwenConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["xQwenDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, xQwenRMSNorm):
            module.weight.data.fill_(1.0)

class xQwenRotaryEmbedding(nn.Module):
    def __init__(self, config: xQwenConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        # Hedgehog feature map doubles the hidden size for q and k
        if self.config.use_sliding_window and self.config.use_hedgehog:
            # self.config.hidden_size *= 2
            self.config.head_dim *= 2

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

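# For the "default" rope type, the frequencies produced by the init function
# follow the usual geometric schedule (a reference formula, not code from this
# file):
#
#     inv_freq[i] = rope_theta ** (-2 * i / head_dim),  i = 0, ..., head_dim // 2 - 1
#
# so cos/sin have shape [batch, seq_len, head_dim] after the concat above.
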

@auto_docstring
class xQwenModel(xQwenPreTrainedModel):
    config_class = xQwenConfig

    def __init__(self, config: xQwenConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [xQwenDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = xQwenRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = xQwenRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        cache_params: Optional[xLSTMCache] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        # if not isinstance(causal_mask_mapping := attention_mask, dict):
        #     # Prepare mask arguments
        #     mask_kwargs = {
        #         "config": self.config,
        #         "input_embeds": inputs_embeds,
        #         "attention_mask": attention_mask,
        #         "cache_position": cache_position,
        #         "past_key_values": past_key_values,
        #     }
        #     # Create the masks
        #     causal_mask_mapping = {
        #         "full_attention": create_causal_mask(**mask_kwargs),
        #     }
        #     # The sliding window alternating layers are not always activated depending on the config
        #     if self.has_sliding_layers:
        #         causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        # NOTE: caching is force-disabled here; the mask preparation above is
        # commented out and the cache path is not exercised in this version.
        use_cache = False
        if use_cache:
            if cache_params is None:
                cache_params = xLSTMCache(
                    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
                )
        else:
            cache_params = None

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        inference_with_cache = (
            not self.training
            and self.config.max_inference_chunksize < hidden_states.shape[1]
            and not output_hidden_states
        )
        if inference_with_cache:
            all_hidden_states = None
            offset = 0
            with torch.no_grad():
                if cache_params is None:
                    cache_params = xLSTMCache(config=self.config, batch_size=hidden_states.shape[0])
                final_state = torch.zeros_like(hidden_states)
                while offset < hidden_states.shape[1]:
                    hidden_states_chunk = hidden_states[
                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
                    ]

                    for i, layer in enumerate(self.layers[: self.config.num_hidden_layers]):
                        hidden_states_chunk, rnn_state = layer(
                            hidden_states_chunk,
                            state=cache_params.rnn_state[i],
                        )
                        for state_idx in range(len(cache_params.rnn_state[i])):
                            local_rnn_state = rnn_state[state_idx]
                            cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
                        cache_params.rnn_state_initial = False
                    final_state[
                        :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
                    ] = hidden_states_chunk
                    offset += self.config.max_inference_chunksize
                hidden_states = final_state
        else:
            for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs, rnn_state = decoder_layer(
                    hidden_states,
                    # attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    state=cache_params.rnn_state[i] if cache_params is not None else None,
                    **flash_attn_kwargs,
                )

                if cache_params:
                    for state_idx in range(len(cache_params.rnn_state[i])):
                        local_rnn_state = rnn_state[state_idx]
                        cache_params.rnn_state[i][state_idx].copy_(local_rnn_state)
                    cache_params.rnn_state_initial = False

                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return xQwenModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cache_params=cache_params if use_cache else None,
        )


# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class xQwenForCausalLM(xQwenPreTrainedModel, GenerationMixin):
    config_class = xQwenConfig
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = xQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, xQwenForCausalLM

        >>> model = xQwenForCausalLM.from_pretrained("Qwen/xQwen-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/xQwen-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # return CausalLMOutputWithPast(
        return xQwenCausalLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cache_params=getattr(outputs, "cache_params", None),
        )

    def copy_from_teacher(self, teacher, copy_qkv: bool = True):
        assert len(self.model.layers) == len(teacher.model.layers)

        self.model.embed_tokens.weight.data.copy_(teacher.get_input_embeddings().weight.data)
        self.model.norm.weight.data.copy_(teacher.model.norm.weight.data)
        self.lm_head.weight.data.copy_(teacher.get_output_embeddings().weight.data)

        for self_layer, teacher_layer in zip(self.model.layers, teacher.model.layers):
            # self_layer.token_mixer.load_state_dict(teacher_layer.token_mixer.state_dict())
            self_layer.mlp.load_state_dict(teacher_layer.mlp.state_dict())
            self_layer.input_layernorm.load_state_dict(teacher_layer.input_layernorm.state_dict())
            self_layer.post_attention_layernorm.load_state_dict(teacher_layer.post_attention_layernorm.state_dict())

            if copy_qkv:
                self_layer.self_attn.q.load_state_dict(teacher_layer.self_attn.q_proj.state_dict())
                self_layer.self_attn.out_proj.load_state_dict(teacher_layer.self_attn.o_proj.state_dict())

                v_proj_unrolled = teacher_layer.self_attn.v_proj
                k_proj_unrolled = teacher_layer.self_attn.k_proj

                self_layer.self_attn.v.load_state_dict(v_proj_unrolled.state_dict())
                self_layer.self_attn.k.load_state_dict(k_proj_unrolled.state_dict())

            self_layer.self_attn.igate_preact.bias.data.fill_(torch.log(torch.tensor(2.0)))
            # The second fill targets the forget gate (the original filled
            # `igate_preact` twice, which only overwrote the line above).
            self_layer.self_attn.fgate_preact.bias.data.fill_(-torch.log(torch.tensor(2.0)))

            # Init weight with small values
            init.xavier_uniform_(self_layer.self_attn.igate_preact.weight.data)
            self_layer.self_attn.igate_preact.weight.data *= 0.1
            init.xavier_uniform_(self_layer.self_attn.fgate_preact.weight.data)
            self_layer.self_attn.fgate_preact.weight.data *= 0.1

    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        use_cache=None,
        cache_params=None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Overwritten -- uses `cache_params` as opposed to `past_key_values`
        # Does not support using additional convolution states via inputs_embeds
        # as opposed to Mamba, currently.
        if use_cache:
            # `cache_position` should have been initialized in `generate`
            if cache_position is None:
                raise ValueError(
                    "`cache_position` should not be None as it should have been initialized in "
                    "`model.generate`, you are responsible for passing in a valid `cache_position` if "
                    "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
                )
            # If the first cache position is non-zero, we assume we are in generation mode.
            # Thus, the cache_params state is assumed to be the state before the last token
            # (lastly generated token), and all previous tokens are already ingested.
            # This should as well support generation from scratch with the [BOS] token inserted first.

            # if is_torchdynamo_compiling() or cache_position[0] > 0:
            if cache_params is not None:
                input_ids = input_ids[:, -1:]
                if inputs_embeds is not None:
                    inputs_embeds = inputs_embeds[:, -1:]

                attention_mask = None

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "cache_params": cache_params,
                "use_cache": use_cache,
                "cache_position": cache_position,
            }
        )
        return model_inputs

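# A hedged sketch of how `copy_from_teacher` might be used when distilling
# from a pretrained Qwen-style transformer (the teacher checkpoint name and
# config construction are illustrative, not prescribed by this file):
#
#     # teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")
#     # student = xQwenForCausalLM(cfg)  # cfg: xQwenConfig with matching layer count
#     # student.copy_from_teacher(teacher, copy_qkv=True)
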
class xQwenForSequenceClassification(xQwenPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
        self.model = xQwenModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

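# A worked example of the pooling index above: with pad_token_id = 0 and
# input_ids = [[5, 7, 9, 0, 0]], the non-pad mask is [1, 1, 1, 0, 0], the
# index-weighted argmax picks position 2, and pooled_logits takes the logits
# of the last real token. Left padding works too: [[0, 0, 5, 7, 9]] yields
# position 4.
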

class xQwenForTokenClassification(xQwenPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
        self.model = xQwenModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

AutoConfig.register(xQwenConfig.model_type, xQwenConfig)
AutoModel.register(xQwenConfig, xQwenModel)
AutoModelForCausalLM.register(xQwenConfig, xQwenForCausalLM)


__all__ = [
    "xQwenForCausalLM",
    "xQwenModel",
    "xQwenPreTrainedModel",
    "xQwenForSequenceClassification",
    "xQwenForTokenClassification",
]
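
# With the Auto* registrations above, checkpoints that ship this file can be
# loaded via trust_remote_code (a usage sketch; the repository id is a
# placeholder):
#
#     # from transformers import AutoModelForCausalLM
#     # model = AutoModelForCausalLM.from_pretrained("<repo-or-path>", trust_remote_code=True)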