Fix RuntimeError: pad attn scores back to the original query sequence length, instead of the unpadded sequence length (which amounts to no change).
#17
by Birchlabs
- modeling_flash_llama.py +1 -1
modeling_flash_llama.py
```diff
@@ -378,7 +378,7 @@ class LlamaAttention(nn.Module):
 
         attn_output = attn_outputs[0] if output_attentions else attn_outputs
         attn_output = pad_input(
-            attn_output, indices_q, bsz,
+            attn_output, indices_q, bsz, q_len
         ).reshape(bsz, q_len, h_size)
         attn_weights = attn_outputs[2] if output_attentions else None
 
```
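For context, here is a minimal pure-PyTorch sketch of the unpad/pad round-trip this code path relies on. The two helpers are simplified stand-ins for `unpad_input`/`pad_input` from `flash_attn.bert_padding` (their exact behavior here is an assumption for illustration). The point of the fix: `pad_input` must be given the original `q_len` so its output has `bsz * q_len` rows; otherwise the tensor stays at the unpadded token count and the following `.reshape(bsz, q_len, h_size)` raises a RuntimeError.

```python
import torch

def unpad_input(x, attention_mask):
    # Simplified stand-in: flatten (bsz, seqlen, ...) and keep only the
    # positions where attention_mask is 1 (the non-pad tokens).
    bsz, seqlen = attention_mask.shape
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    return x.reshape(bsz * seqlen, *x.shape[2:])[indices], indices

def pad_input(x_unpad, indices, batch, seqlen):
    # Simplified stand-in: scatter (total_tokens, ...) back into a
    # zero-padded buffer of (batch * seqlen, ...) rows.
    out = torch.zeros(batch * seqlen, *x_unpad.shape[1:], dtype=x_unpad.dtype)
    out[indices] = x_unpad
    return out

bsz, q_len, h_size = 2, 5, 8
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 8 real tokens, 2 pad
hidden = torch.randn(bsz, q_len, h_size)

unpadded, indices_q = unpad_input(hidden, attention_mask)  # shape (8, h_size)

# With the original q_len, the round-trip restores the padded layout:
attn_output = pad_input(unpadded, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
assert attn_output.shape == (bsz, q_len, h_size)

# Had pad_input produced only total_tokens rows (8 instead of bsz * q_len = 10),
# reshape(bsz, q_len, h_size) would need 80 elements but have 64 --
# the RuntimeError this PR fixes.
```

In this sketch, padding back to the unpadded length is effectively a no-op (the tensor keeps its `total_tokens` rows), which is the "no change" failure mode described in the title.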