Upload model.py with huggingface_hub
model.py CHANGED
@@ -145,7 +145,7 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
-            torch.nn.init.
+            torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
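For reference, the replacement initializer can be exercised on its own. A minimal standalone sketch of what the new line does (the layer width 768 is an arbitrary placeholder, not taken from this model's config):

import torch
import torch.nn as nn

layer = nn.Linear(768, 768)
# He/Kaiming-normal init, as in the updated _init_weights:
# weight std = gain / sqrt(fan_in), with the gain computed for ReLU
torch.nn.init.kaiming_normal_(layer.weight, a=0, mode='fan_in', nonlinearity='relu')
if layer.bias is not None:
    torch.nn.init.zeros_(layer.bias)  # biases are zeroed, as in the diff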
@@ -285,90 +285,80 @@ class GPT(nn.Module):
         flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
         mfu = flops_achieved / flops_promised
         return mfu
 
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, strategy='sampling', beam_size=3, eos_token_id=0, repetition_penalty=1.0):
-        """
-        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
-        the sequence max_new_tokens times, feeding the predictions back into the model each time.
-        Strategy can be 'greedy', 'sampling' or 'top-k'.
-        """
-        # check strategy valid
         assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
 
-                all_candidates = []
-
-                for i, seq in enumerate(beam_seqs):
-                    # Get next token probabilities
-                    idx_cond = seq if seq.size(1) <= self.config.block_size else seq[:, -self.config.block_size:]
-                    logits, __ = self(idx_cond)
-                    logits = logits[:, -1, :]
-                    probs = F.log_softmax(logits, dim=-1) # Use log probs to avoid numerical instability
-
-                    # Get top sequences for this beam (we could use more than beam_size here for diversity)
-                    scores, indices = torch.topk(probs, beam_size)
-                    for j in range(beam_size):
-                        candidate_seq = torch.cat([seq, indices[:, j:j+1]], dim=1)
-                        candidate_score = beam_scores[:, i] + scores[:, j]
+        batch_size = idx.size(0)
+        if strategy == 'beam_search':
+            # Initialize beams
+            beam_seqs = [idx.clone() for _ in range(beam_size)]
+            beam_scores = torch.zeros((batch_size, beam_size), device=idx.device)
+            completed_seqs = []
+
+            for _ in range(max_new_tokens):
+                all_candidates = []
+                for i in range(beam_size):
+                    idx_cond = beam_seqs[i] if beam_seqs[i].size(1) <= self.config.block_size else beam_seqs[i][:, -self.config.block_size:]
+                    logits, _ = self(idx_cond)
+                    logits = logits[:, -1, :] / temperature
+                    if repetition_penalty != 1.0:
+                        for j in range(idx_cond.size(1)):
+                            logits[:, idx_cond[:, j]] /= repetition_penalty
+                    probs = F.log_softmax(logits, dim=-1)
+                    scores, indices = torch.topk(probs, beam_size, dim=-1)
+
+                    for j in range(beam_size):
+                        candidate_seq = torch.cat([beam_seqs[i], indices[:, j:j+1]], dim=1)
+                        candidate_score = beam_scores[:, i] + scores[:, j]
+                        if indices[0, j] == eos_token_id:
+                            completed_seqs.append((candidate_score, candidate_seq))
+                        else:
                             all_candidates.append((candidate_score, candidate_seq))
 
+                # add random noise when sorting because generated beam_search sequences remain unchanged if they share the same prefix
+                all_candidates.sort(key=lambda x: x[0].mean().item() + torch.rand(1).item() * 5e-1, reverse=True)
+
+                beam_seqs = [all_candidates[i][1] for i in range(min(beam_size, len(all_candidates)))]
+                beam_scores = torch.stack([all_candidates[i][0] for i in range(min(beam_size, len(all_candidates)))], dim=1)
+                if len(completed_seqs) >= beam_size:
+                    break
+
+            if not completed_seqs:
+                completed_seqs = [(beam_scores[:, i], beam_seqs[i]) for i in range(beam_size)]
+
+            completed_seqs.sort(key=lambda x: x[0].mean().item(), reverse=True)
+            return completed_seqs[0][1]
+
+
+        else:
+            for _ in range(max_new_tokens):
+                idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+                logits, _ = self(idx_cond)
+                logits = logits[:, -1, :] / temperature
+
+                if repetition_penalty != 1.0:
+                    for i in range(idx.size(0)):
+                        for j in range(idx.size(1)):
+                            logits[i, idx[i, j]] /= repetition_penalty
+
+                if strategy == 'greedy_search':
+                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
+
+                elif strategy == 'sampling':
                     probs = F.softmax(logits, dim=-1)
                     idx_next = torch.multinomial(probs, num_samples=1)
+
+                elif strategy == 'top_k':
+                    if top_k is not None:
+                        logits, indices = torch.topk(logits, min(top_k, logits.size(-1)))
+                        probs = F.softmax(logits, dim=-1)
+                        idx_next = torch.multinomial(probs, num_samples=1)
+                        idx_next = torch.gather(indices, dim=-1, index=idx_next)
+
                 if idx_next == eos_token_id:
                     break
                 idx = torch.cat((idx, idx_next), dim=1)
 
         return idx if idx[0][0] != eos_token_id else idx[:, 1:]
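A hypothetical usage sketch of the reworked generate(); the model instance, the prompt tokens, and the hyperparameter values below are illustrative placeholders, not part of the commit:

import torch

# model: a trained GPT instance (placeholder)
prompt = torch.tensor([[1, 42, 7]], dtype=torch.long)  # (batch=1, t) prompt token ids
greedy_out = model.generate(prompt, max_new_tokens=50, strategy='greedy_search')
sampled_out = model.generate(prompt, max_new_tokens=50, temperature=0.8, strategy='sampling')
topk_out = model.generate(prompt, max_new_tokens=50, top_k=40, strategy='top_k')
beam_out = model.generate(prompt, max_new_tokens=50, strategy='beam_search', beam_size=3, repetition_penalty=1.2)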
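The top_k branch samples in the reduced k-token space and then maps the draw back to vocabulary ids with gather; a self-contained toy version of just that step (the vocabulary size is a placeholder):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 50257)  # toy next-token logits; 50257 is a placeholder vocab size
k = 40
vals, idxs = torch.topk(logits, min(k, logits.size(-1)))  # (1, k) top logits and their vocab ids
probs = F.softmax(vals, dim=-1)  # renormalize over the k survivors
pick = torch.multinomial(probs, num_samples=1)  # sample an index within the k-subset
next_token = torch.gather(idxs, dim=-1, index=pick)  # map back to a vocabulary id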
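Likewise, the beam update ranks candidates by accumulated log-probability; a toy illustration of that scoring step, with random logits standing in for the model's output:

import torch
import torch.nn.functional as F

beam_size = 3
logits = torch.randn(1, 10)  # toy next-token logits for a single beam
log_probs = F.log_softmax(logits, dim=-1)  # log space avoids underflow as scores accumulate
scores, indices = torch.topk(log_probs, beam_size, dim=-1)
beam_score = torch.zeros(1)  # running score of this beam so far
candidates = [(beam_score + scores[:, j], indices[:, j]) for j in range(beam_size)]
best_score, best_token = max(candidates, key=lambda c: c[0].item())  # highest total log-prob wins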