Upload bert_padding.py with huggingface_hub
bert_padding.py +154 -0
bert_padding.py
ADDED
@@ -0,0 +1,154 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
# Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py


from typing import Tuple, cast

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
        """Get just the values of `input` which are at `indices`.

        Arguments:
            ctx: the autograd context object
            input: (b, ...) 2+ dimensional tensor
            indices: (num_idx) 1D tensor
        """
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]  # type: ignore
        second_dim = other_shape.numel()  # product of sizes of all but first dimension
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        return torch.gather(
            rearrange(input, 'b ... -> b (...)'),  # (b, ...) -> (b, second_dim)
            0,
            repeat(indices, 'z -> z d',
                   d=second_dim)  # (indices,) -> (indices, second_dim)
        ).reshape(-1, *other_shape)  # (num_idx, ...)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        indices, = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, 'b ... -> b (...)')
        grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
                                 device=grad_output.device,
                                 dtype=grad_output.dtype)
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0,
                            repeat(indices, 'z -> z d', d=grad_output.shape[1]),
                            grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):

    @staticmethod
    def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
                first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim,
                             *values.shape[1:],
                             device=values.device,
                             dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx,
                 grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        indices, = ctx.saved_tensors
        grad_values = grad_output[indices]
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


def unpad_input(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Remove padding from input sequences.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.

    Returns:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
                       (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    hidden_states = cast(
        torch.Tensor,
        index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
                         indices))
    return hidden_states, indices, cu_seqlens, max_seqlen_in_batch


def unpad_input_only(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Like unpad_input, but only return the unpadded first tensor.

    Saves a small amount of overhead.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.

    Returns:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
    """
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    return index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
                            indices)


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
              seqlen: int) -> torch.Tensor:
    """Add padding to sequences.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        batch: int batch_size
        seqlen: int max sequence length

    Returns:
        hidden_states: (batch, seqlen, ...)
    """
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, '(b s) ... -> b s ...', b=batch)
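
For context, the sketch below (not part of the uploaded file) shows how unpad_input and pad_input round-trip a padded batch: padding tokens are stripped into a flattened (total_nnz, hidden) tensor and later scattered back into their original (batch, seqlen, hidden) positions. It assumes this file is importable as bert_padding and that torch and einops are installed; the shapes and mask values are purely illustrative.

# Minimal usage sketch -- not part of the uploaded file. Assumes the module above
# is importable as `bert_padding`; shapes and mask values are illustrative only.
import torch

from bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
# 1 marks a real token, 0 marks padding; the second sequence has one padded position.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]], dtype=torch.int32)

# Drop padded positions: (batch, seqlen, hidden) -> (total_nnz, hidden).
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
assert unpadded.shape == (7, hidden)      # 7 real tokens across the batch
assert cu_seqlens.tolist() == [0, 4, 7]   # per-sequence offsets into `unpadded`
assert max_seqlen == 4

# Scatter the tokens back to (batch, seqlen, hidden); padded positions are zero-filled.
repadded = pad_input(unpadded, indices, batch, seqlen)
mask = attention_mask.bool()
assert torch.equal(repadded[mask], hidden_states[mask])
assert torch.all(repadded[~mask] == 0)

Here cu_seqlens marks where each sequence's tokens begin and end inside the flattened, unpadded tensor, which is the layout that variable-length attention kernels (such as those in flash-attention) consume, while indices is what pad_input needs to restore the original padded layout.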