Commit · 46f1d8c
Parent(s): cd708ec

commit files to HF hub

Files changed:
- config.json (+3 -1)
- fasttext_fsc.py (+186 -46)
config.json CHANGED

@@ -19,7 +19,9 @@
   },
   "max_length": 128,
   "model_type": "fasttext_classification",
-  "ngram": 2,
+  "ngrams": [
+    2
+  ],
   "tokenizerI_class": "FastTextJpTokenizer",
   "tokenizer_class": "FastTextJpTokenizer",
   "torch_dtype": "float32",
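For reference, a minimal sketch (not part of the commit) of how the widened `ngram` argument produces the serialized `ngrams` list above; it assumes the repository's modules are importable from the working directory:

# An int is normalized to a one-element list; a list is kept as-is.
from fasttext_fsc import FastTextForSeuqenceClassificationConfig

config = FastTextForSeuqenceClassificationConfig(ngram=[2, 3])
print(config.ngrams)             # [2, 3]  (ngram=2 would give [2])
config.save_pretrained("model")  # config.json now contains "ngrams": [2, 3]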
fasttext_fsc.py CHANGED

@@ -11,82 +11,92 @@ class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
     model_type = "fasttext_classification"
 
     def __init__(self,
-                 ngram: int = 2,
+                 ngram: int | list[int] = 2,
                  tokenizer_class="FastTextJpTokenizer",
                  **kwargs):
         """Initialization.
 
         Args:
-            ngram (int, optional):
-                The n-gram size used to split the text.
+            ngram (int | list[int], optional):
+                The n-gram size(s) used to split the text.
             tokenizer_class (str, optional):
                 Unless tokenizer_class is specified, the model cannot be loaded from a pipeline.
                 It is recorded in config.json.
         """
-
+        if isinstance(ngram, int):
+            self.ngrams = [ngram]
+        elif isinstance(ngram, list):
+            self.ngrams = ngram
+        else:
+            raise TypeError(f"got unknown type {type(ngram)}")
         kwargs["tokenizer_class"] = tokenizer_class
         super().__init__(**kwargs)
 
 
-class FastTextForSeuqenceClassification(FastTextJpModel):
-    """Performs classification based on the fastText vectors.
-    """
+class NgramForSeuqenceClassification():
 
-    def __init__(self, …
+    def __init__(self):
+        ...
 
-    …
-    …
+    def __call__(self, sentence: TensorType["A", "vectors"],
+                 candidate_label: TensorType["B", "vectors"],
+                 ngram: int) -> TensorType[3]:
+        """Splits the text into n-grams and computes the cosine similarity.
 
-    …
-    …
+        Args:
+            sentence (TensorType["A", "vectors"]): text vectors
+            candidate_label (TensorType["B", "vectors"]): label vectors
+            ngram (int): n-gram size
 
         Returns:
-            …
+            TensorType[3]:
+                Similarity of the texts as [Entailment, Neutral, Contradiction].
         """
-        input_ids = inputs["input_ids"]
-        outputs = self.word_embeddings(input_ids)
 
-        logits = []
-        for idx in range(len(outputs)):
-            output = outputs[idx]
-            # token_type_ids == 0 is the sentence, 1 is the label.
-            token_type_ids = inputs["token_type_ids"][idx]
-            # attention_mask == 1 marks tokens that are not padding.
-            attention_mask = inputs["attention_mask"][idx]
+        sentence_ngrams = self.split_ngram(sentence, ngram)
 
-            sentence = output[torch.logical_and(token_type_ids == 0,
-                                                attention_mask == 1)]
-            candidate_label = output[torch.logical_and(token_type_ids == 1,
-                                                       attention_mask == 1)]
-            sentence_words = self.split_ngram(sentence, self.max_ngram)
-            candidate_label_mean = torch.mean(candidate_label,
-                                              dim=-2,
-                                              keepdim=True)
-            p = self.cosine_similarity(sentence_words, candidate_label_mean)
-            logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
-        logits = torch.FloatTensor(logits)
-        return SequenceClassifierOutput(
-            loss=None,
-            logits=logits,
-            hidden_states=None,
-            attentions=None,
-        )
+        candidate_label_mean = torch.mean(candidate_label, dim=0, keepdim=True)
+        p = self.cosine_similarity(sentence_ngrams, candidate_label_mean)
+        return torch.tensor([torch.log(p), -torch.inf, torch.log(1 - p)])
 
     def cosine_similarity(
-            self,
-            …
+            self, sentence_ngrams: TensorType["ngrams", "vectors"],
+            candidate_label_mean: TensorType[1, "vectors"]) -> TensorType[1]:
+        """Computes the cosine similarity.
+
+        Args:
+            sentence_ngrams (TensorType["ngrams", "vectors"]):
+                the n-gram pooled text vectors
+            candidate_label_mean (TensorType[1, "vectors"]):
+                the label vector
+
+        Returns:
+            TensorType[1]: the maximum cosine similarity over the n-grams
+        """
+
         res = torch.tensor(0.)
-        for i in range(len(…
-            sw = …
+        for i in range(len(sentence_ngrams)):
+            sw = sentence_ngrams[i]
             p = torch.nn.functional.cosine_similarity(sw,
-                                                      …
+                                                      candidate_label_mean[0],
                                                       dim=0)
             if p > res:
                 res = p
         return res
 
-    def split_ngram(self, sentences: TensorType["…
-                    n: int) -> TensorType["…
+    def split_ngram(self, sentences: TensorType["A", "vectors"],
+                    n: int) -> TensorType["ngrams", "vectors"]:
+        """Splits the text into n-grams.
+        Args:
+            sentences (TensorType["A", "vectors"]):
+                the target text
+            n (int):
+                n-gram size
+        Returns:
+            TensorType["ngrams", "vectors"]:
+                the text split into n-grams
+        """
+
         res = []
         if len(sentences) <= n:
             return torch.stack([torch.mean(sentences, dim=0, keepdim=False)])

@@ -96,6 +106,136 @@ class FastTextForSeuqenceClassification(FastTextJpModel):
         return torch.stack(res)
 
 
+class NgramsForSeuqenceClassification():
+
+    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
+        self.max_ngrams = config.ngrams
+        self.ngram_layer = NgramForSeuqenceClassification()
+
+    def __call__(self, sentence: TensorType["A", "vectors"],
+                 candidate_label: TensorType["B", "vectors"]) -> TensorType[3]:
+        """Computes the relatedness of A and B.
+        Args:
+            sentence (TensorType["A", "vectors"]):
+                the target text
+            candidate_label (TensorType["B", "vectors"]):
+                the label text
+
+        Returns:
+            TensorType[3]:
+                Similarity of the texts as [Entailment, Neutral, Contradiction].
+        """
+
+        res = torch.tensor([-torch.inf, -torch.inf, -torch.inf])
+        for ngram in self.max_ngrams:
+            logit = self.ngram_layer(sentence, candidate_label, ngram)
+            if logit[0] > res[0]:
+                res = logit
+        return res
+
+
+class BatchedNgramsForSeuqenceClassification():
+
+    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
+        self.ngrams_layer = NgramsForSeuqenceClassification(config)
+
+    def __call__(
+        self,
+        last_hidden_state: TensorType["batch", "A+B", "vectors"],
+        token_type_ids: TensorType["batch", "A+B"],
+        attention_mask: TensorType["batch", "A+B"],
+    ) -> TensorType["batch", 3]:
+        """Computes the relatedness of A and B.
+        Args:
+            last_hidden_state (TensorType["batch", "A+B", "vectors"]):
+                the embedding values.
+            token_type_ids (TensorType["batch", "A+B"]):
+                Segment id, 0 or 1; 1 for text B.
+            attention_mask (TensorType["batch", "A+B"]):
+                Identifies padding, 0 or 1; 1 for non-padding tokens.
+
+        Returns:
+            TensorType["batch", 3]:
+                Similarity of the texts as [Entailment, Neutral, Contradiction].
+        """
+
+        logits = []
+        embeddings = last_hidden_state
+        for idx in range(len(embeddings)):
+            vec = embeddings[idx]
+            # token_type_ids == 0 is the sentence, 1 is the label.
+            tti = token_type_ids[idx]
+            # attention_mask == 1 marks tokens that are not padding.
+            mask = attention_mask[idx]
+
+            sentence, candidate_label = self.split_sentence(
+                vec, tti, mask)
+            logit = self.ngrams_layer(sentence, candidate_label)
+            logits.append(logit)
+        logits = torch.stack(logits)
+        return logits
+
+    def split_sentence(
+        self, vec: TensorType["A+B", "vectors"],
+        token_type_ids: TensorType["A+B"], attention_mask: TensorType["A+B"]
+    ) -> tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
+        """Splits the cross-encoded input back into the two texts.
+
+        Args:
+            vec (TensorType["A+B", "vectors"]):
+                word vectors
+
+            token_type_ids (TensorType["A+B"]):
+                Segment id, 0 or 1; 1 for text B.
+
+            attention_mask (TensorType["A+B"]):
+                Identifies padding, 0 or 1; 1 for non-padding tokens.
+
+        Returns:
+            tuple[TensorType["A", "vectors"], TensorType["B", "vectors"]]:
+                The texts A and B, split apart.
+        """
+
+        sentence = vec[torch.logical_and(token_type_ids == 0,
+                                         attention_mask == 1)]
+        candidate_label = vec[torch.logical_and(token_type_ids == 1,
+                                                attention_mask == 1)]
+        return sentence, candidate_label
+
+
+class FastTextForSeuqenceClassification(FastTextJpModel):
+    """Performs classification based on the fastText vectors.
+    """
+
+    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
+        self.layer = BatchedNgramsForSeuqenceClassification(config)
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_ids: TensorType["batch", "A+B"] = None,
+        attention_mask: TensorType["batch", "A+B"] = None,
+        token_type_ids: TensorType["batch", "A+B"] = None
+    ) -> SequenceClassifierOutput:
+        """Performs classification against the candidate labels.
+
+        Returns:
+            SequenceClassifierOutput: the probability that a candidate label is correct
+        """
+        outputs = self.word_embeddings(input_ids)
+        logits = self.layer(last_hidden_state=outputs,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids)
+
+        return SequenceClassifierOutput(
+            loss=None,
+            logits=logits,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
 # Registration with AutoModel is required, but the way to do it seems to keep
 # changing and is not settled yet. (2022/11/6)
 # https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
 FastTextForSeuqenceClassificationConfig.register_for_auto_class()
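A standalone sketch (not part of the commit) of the scoring path NgramForSeuqenceClassification implements: mean-pool every n-gram window of word vectors, take the best cosine similarity against the mean label vector, and turn it into [Entailment, Neutral, Contradiction] logits. The branch of split_ngram for texts longer than n falls between the hunks, so the sliding-window pooling here is an assumption based on the visible stack/append pattern:

import torch

def ngram_max_similarity(sentence: torch.Tensor, label: torch.Tensor, n: int) -> torch.Tensor:
    # sentence: [A, vectors], label: [B, vectors]
    if len(sentence) <= n:
        ngrams = sentence.mean(dim=0, keepdim=True)          # one pooled vector
    else:
        ngrams = torch.stack([sentence[i:i + n].mean(dim=0)  # assumed sliding windows
                              for i in range(len(sentence) - n + 1)])
    label_mean = label.mean(dim=0, keepdim=True)             # [1, vectors]
    sims = torch.nn.functional.cosine_similarity(ngrams, label_mean, dim=1)
    return sims.max().clamp(min=0.0)                         # res starts at 0 in the source

p = ngram_max_similarity(torch.randn(7, 300), torch.randn(2, 300), n=2)
logits = torch.stack([p.log(), torch.tensor(-torch.inf), (1 - p).log()])
print(logits)  # [log p, -inf, log(1 - p)]

With several n-gram sizes configured, NgramsForSeuqenceClassification runs this once per size and keeps the logits with the highest entailment score.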
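Because forward emits NLI-style [Entailment, Neutral, Contradiction] logits and the config docstring stresses that tokenizer_class makes the model loadable from a pipeline, the intended entry point appears to be the zero-shot-classification pipeline. A hypothetical call, with a placeholder repo id:

from transformers import pipeline

classifier = pipeline("zero-shot-classification",
                      model="<owner>/<this-repo>",
                      trust_remote_code=True)
print(classifier("これはテストです。", candidate_labels=["天気", "スポーツ"]))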