Upload model.py
Browse files
model.py
CHANGED
|
@@ -23,7 +23,7 @@ class RMSNorm(torch.nn.Module):
|
|
| 23 |
return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
|
| 24 |
|
| 25 |
|
| 26 |
-
def precompute_pos_cis(dim: int, end: int, theta: float =
|
| 27 |
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 28 |
t = torch.arange(end, device=freqs.device) # type: ignore
|
| 29 |
freqs = torch.outer(t, freqs).float() # type: ignore
|
|
@@ -295,8 +295,9 @@ class MiniMindLM(PreTrainedModel):
|
|
| 295 |
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
| 296 |
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
| 297 |
self.tok_embeddings.weight = self.output.weight
|
| 298 |
-
self.register_buffer("pos_cis",
|
| 299 |
-
|
|
|
|
| 300 |
self.OUT = CausalLMOutputWithPast()
|
| 301 |
|
| 302 |
def forward(self,
|
|
@@ -328,13 +329,13 @@ class MiniMindLM(PreTrainedModel):
|
|
| 328 |
stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
|
| 329 |
# 流式生成
|
| 330 |
if stream:
|
| 331 |
-
return self.
|
| 332 |
|
| 333 |
# 直接生成
|
| 334 |
generated = []
|
| 335 |
for i in range(input_ids.size(0)):
|
| 336 |
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
|
| 337 |
-
out = self.
|
| 338 |
tokens_list = [tokens[:, -1:] for tokens in out]
|
| 339 |
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
|
| 340 |
full_sequence = torch.cat([non_pad, gen], dim=-1)
|
|
@@ -348,14 +349,14 @@ class MiniMindLM(PreTrainedModel):
|
|
| 348 |
]
|
| 349 |
return torch.cat(generated, dim=0)
|
| 350 |
|
| 351 |
-
def
|
| 352 |
start, first_seq, past_kvs = input_ids.shape[1], True, None
|
| 353 |
while input_ids.shape[1] < max_new_tokens - 1:
|
| 354 |
if first_seq or not use_cache:
|
| 355 |
-
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
|
| 356 |
else:
|
| 357 |
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
|
| 358 |
-
start_pos=input_ids.shape[1] - 1)
|
| 359 |
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
|
| 360 |
logits[:, list(set(input_ids.tolist()[0]))] /= rp
|
| 361 |
logits /= (temperature + 1e-9)
|
|
|
|
| 23 |
return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
|
| 24 |
|
| 25 |
|
| 26 |
+
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
| 27 |
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 28 |
t = torch.arange(end, device=freqs.device) # type: ignore
|
| 29 |
freqs = torch.outer(t, freqs).float() # type: ignore
|
|
|
|
| 295 |
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
| 296 |
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
| 297 |
self.tok_embeddings.weight = self.output.weight
|
| 298 |
+
self.register_buffer("pos_cis",
|
| 299 |
+
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
| 300 |
+
persistent=False)
|
| 301 |
self.OUT = CausalLMOutputWithPast()
|
| 302 |
|
| 303 |
def forward(self,
|
|
|
|
| 329 |
stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
|
| 330 |
# 流式生成
|
| 331 |
if stream:
|
| 332 |
+
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
|
| 333 |
|
| 334 |
# 直接生成
|
| 335 |
generated = []
|
| 336 |
for i in range(input_ids.size(0)):
|
| 337 |
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
|
| 338 |
+
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
|
| 339 |
tokens_list = [tokens[:, -1:] for tokens in out]
|
| 340 |
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
|
| 341 |
full_sequence = torch.cat([non_pad, gen], dim=-1)
|
|
|
|
| 349 |
]
|
| 350 |
return torch.cat(generated, dim=0)
|
| 351 |
|
| 352 |
+
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
|
| 353 |
start, first_seq, past_kvs = input_ids.shape[1], True, None
|
| 354 |
while input_ids.shape[1] < max_new_tokens - 1:
|
| 355 |
if first_seq or not use_cache:
|
| 356 |
+
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
|
| 357 |
else:
|
| 358 |
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
|
| 359 |
+
start_pos=input_ids.shape[1] - 1, **args)
|
| 360 |
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
|
| 361 |
logits[:, list(set(input_ids.tolist()[0]))] /= rp
|
| 362 |
logits /= (temperature + 1e-9)
|