import torch
import numpy as np
import networkx as nx
from typing import Optional, List
class Pooler:
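    """Pool per-token embeddings (b, L, d) into per-sequence embeddings (b, d).

    Each pooling type in `pooling_types` is applied in turn and the results are
    concatenated along the feature dimension, giving (b, len(pooling_types) * d).
    """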
def __init__(self, pooling_types: List[str]):
self.pooling_types = pooling_types
self.pooling_options = {
'mean': self.mean_pooling,
'max': self.max_pooling,
'norm': self.norm_pooling,
'median': self.median_pooling,
'std': self.std_pooling,
'var': self.var_pooling,
'cls': self.cls_pooling,
'parti': self._pool_parti,
}
def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
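        # attentions: (b, n_layers, L, L). Take the element-wise max over the
        # layer dimension to get one (L, L) attention matrix per sequence.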
maxed_attentions = torch.max(attentions, dim=1)[0]
return maxed_attentions
def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
# Run PageRank on the attention matrix converted to a graph.
# Raises exceptions if the graph doesn't match the token sequence or has no edges.
# Returns the PageRank scores for each token node.
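        # Note: `prune_type` is accepted but not used in this implementation.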
G = self._convert_to_graph(attention_matrix)
if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should equal the number of tokens in the sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
if G.number_of_edges() == 0:
raise Exception(f"You don't seem to have any attention edges left in the graph.")
return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)
def _convert_to_graph(self, matrix):
# Convert a matrix (e.g., attention scores) to a directed graph using networkx.
# Each element in the matrix represents a directed edge with a weight.
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
return G
def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
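        # Normalize the PageRank scores over the kept (non-padding) tokens so
        # the resulting weights sum to 1.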
# Remove keys where attention_mask is 0
if attention_mask is not None:
for k in list(dict_importance.keys()):
if attention_mask[k] == 0:
del dict_importance[k]
        # Optionally, the CLS (index 0) and EOS (index -1) entries could also
        # be dropped here before normalizing.
total = sum(dict_importance.values())
return np.array([v / total for _, v in dict_importance.items()])
    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        if attentions is None:
            raise ValueError("'parti' pooling requires the attention matrices.")
        if attention_mask is None:
            # Without an explicit mask, treat every token as valid.
            attention_mask = torch.ones(emb.shape[:2], device=emb.device)
        # emb is (b, L, d); maxed_attentions is (b, L, L) after the layer-wise max.
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).detach().cpu().numpy()
        emb_np = emb.detach().cpu().numpy()
        emb_pooled = []
        for e, a, mask in zip(emb_np, maxed_attentions, attention_mask):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            # Assumes padding sits at the end, so the first `num_tokens`
            # positions are exactly the non-masked ones.
            num_tokens = int(mask.sum().item())
            emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0))
        return torch.tensor(np.array(emb_pooled), dtype=emb.dtype, device=emb.device)
def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
if attention_mask is None:
return emb.mean(dim=1)
else:
attention_mask = attention_mask.unsqueeze(-1)
return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.max(dim=1).values
        else:
            # Fill padded positions with -inf so they can never win the max
            # (multiplying by the mask would leave zeros that beat negative values).
            mask = attention_mask.unsqueeze(-1).bool()
            return emb.masked_fill(~mask, float('-inf')).max(dim=1).values
def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
if attention_mask is None:
return emb.norm(dim=1, p=2)
else:
attention_mask = attention_mask.unsqueeze(-1)
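            # Zeroing padded positions is enough here: zeros add nothing to the L2 norm.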
return (emb * attention_mask).norm(dim=1, p=2)
    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        if attention_mask is None:
            return emb.median(dim=1).values
        else:
            # Fill padded positions with NaN and use nanmedian so the zeros
            # introduced by masking do not distort the median.
            mask = attention_mask.unsqueeze(-1).bool()
            return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values
def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
if attention_mask is None:
return emb.std(dim=1)
else:
# Compute variance correctly over non-masked positions, then take sqrt
var = self.var_pooling(emb, attention_mask, **kwargs)
return torch.sqrt(var)
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
if attention_mask is None:
return emb.var(dim=1)
else:
# Correctly compute variance over only non-masked positions
attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1)
# Compute mean over non-masked positions
mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
mean = mean.unsqueeze(1) # (b, 1, d)
# Compute squared differences from mean, only over non-masked positions
squared_diff = (emb - mean) ** 2 # (b, L, d)
# Sum squared differences over non-masked positions and divide by count
var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d)
return var
def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d)
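        # Assumes the first position holds the [CLS] / BOS token.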
return emb[:, 0, :]
def __call__(
self,
emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attentions: Optional[torch.Tensor] = None
    ):  # e.g. pooling_types=['mean', 'max'] -> (b, 2 * d)
final_emb = []
for pooling_type in self.pooling_types:
final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d)
return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
if __name__ == "__main__":
# py -m pooler
pooler = Pooler(pooling_types=['max', 'parti'])
batch_size = 8
seq_len = 64
hidden_size = 128
num_layers = 12
emb = torch.randn(batch_size, seq_len, hidden_size)
    # Use non-negative scores so the PageRank graph has valid edge weights.
    attentions = torch.rand(batch_size, num_layers, seq_len, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions)
print(y.shape)
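    # Expected: torch.Size([8, 256]); 'max' and 'parti' each contribute hidden_size dims.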