LH-Tech-AI committed on
Commit
e40716c
·
verified ·
1 Parent(s): d80d74f

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +224 -0
train.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # Extractive Question Answering – From Scratch on SQuAD
3
+ # Kaggle T4 (16GB VRAM) | HF Transformers
4
+ # ============================================================
5
+
6
+ # ── Imports ─────────────────────────────────────────────────
7
+ import numpy as np
8
+ import collections
9
+ import evaluate
10
+ from datasets import load_dataset
11
+ from transformers import (
12
+ BertConfig,
13
+ BertForQuestionAnswering,
14
+ BertTokenizerFast,
15
+ DefaultDataCollator,
16
+ TrainingArguments,
17
+ Trainer,
18
+ )
19
+
20
+ # ── Config ───────────────────────────────────────────────────
21
# Checkpoint name used only to fetch the tokenizer (vocabulary); the model
# weights are never loaded from it.
MODEL_NAME = "bert-base-uncased"  # tokenizer only!
MAX_LENGTH = 384  # passed to the tokenizer as max_length
DOC_STRIDE = 128  # passed to the tokenizer as stride (window overlap)
BATCH_SIZE = 16  # per-device batch size for both training and evaluation
EPOCHS = 3
LR = 3e-4
OUTPUT_DIR = "distill"

# ── 1. Dataset ───────────────────────────────────────────────
raw = load_dataset("squad")

# ── 2. Tokenizer (pretrained vocab, NO pretrained weights) ─
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
34
+
35
+ # ── 3. Preprocessing ─────────────────────────────────────────
36
def preprocess_train(examples):
    """Tokenize a batch of examples and label each feature's answer span.

    Every question/context pair is tokenized with truncation applied to the
    context only; overflowing context is emitted as additional features that
    overlap by ``DOC_STRIDE`` tokens, so one example may yield several
    features.

    For each feature the token indices of the first answer are stored in
    ``start_positions`` / ``end_positions``. A feature is labelled with the
    index of its ``[CLS]`` token when the example has no answer, when the
    feature has no context tokens, or when the answer is not fully contained
    in the feature's context window.

    Args:
        examples: Batch with parallel ``"question"``, ``"context"`` and
            ``"answers"`` columns. Each ``answers`` entry provides
            ``"answer_start"`` and ``"text"`` lists.

    Returns:
        The tokenizer output with ``"start_positions"`` and
        ``"end_positions"`` added, and with ``"overflow_to_sample_mapping"``
        and ``"offset_mapping"`` removed.
    """
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = tokenized.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized.pop("offset_mapping")

    start_positions, end_positions = [], []

    for i, offsets in enumerate(offset_mapping):
        # Several features can originate from the same example.
        answers = examples["answers"][sample_map[i]]
        cls_index = tokenized["input_ids"][i].index(tokenizer.cls_token_id)

        if not answers["answer_start"]:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Character span of the first answer inside the context string.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last token index belonging to the context (sequence id 1).
        sequence_ids = tokenized.sequence_ids(i)
        token_start = next(
            (j for j, s in enumerate(sequence_ids) if s == 1), None
        )
        token_end = next(
            (j for j in range(len(sequence_ids) - 1, -1, -1)
             if sequence_ids[j] == 1),
            None,
        )

        # The answer must lie completely inside this window. A partially
        # covered answer would otherwise leave the scans below at
        # token_start - 1 / token_end + 1, i.e. on a non-context token.
        if (
            token_start is None
            or offsets[token_start][0] > start_char
            or offsets[token_end][1] < end_char
        ):
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Advance to the last token whose start offset is <= start_char.
        start_tok = token_start
        while start_tok <= token_end and offsets[start_tok][0] <= start_char:
            start_tok += 1
        start_positions.append(start_tok - 1)

        # Walk back to the first token whose end offset is >= end_char.
        end_tok = token_end
        while end_tok >= token_start and offsets[end_tok][1] >= end_char:
            end_tok -= 1
        end_positions.append(end_tok + 1)

    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized
88
+
89
+
90
def preprocess_validation(examples):
    """Tokenize a batch of validation examples for answer post-processing.

    Uses the same windowed tokenization as training, but keeps
    ``"offset_mapping"`` and records, for every feature, the id of the
    example it was cut from. Offsets of tokens that do not belong to the
    context (sequence id other than 1) are replaced by ``None`` so they can
    be skipped when candidate answers are decoded.

    Args:
        examples: Batch with parallel ``"question"``, ``"context"`` and
            ``"id"`` columns.

    Returns:
        The tokenizer output with ``"example_id"`` added,
        ``"offset_mapping"`` masked as described above, and
        ``"overflow_to_sample_mapping"`` removed.
    """
    encoded = tokenizer(
        examples["question"],
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    feature_to_sample = encoded.pop("overflow_to_sample_mapping")

    example_ids = []
    for feat_idx, sample_idx in enumerate(feature_to_sample):
        example_ids.append(examples["id"][sample_idx])

        # Keep character offsets only for context tokens.
        seq_ids = encoded.sequence_ids(feat_idx)
        encoded["offset_mapping"][feat_idx] = [
            span if seq_id == 1 else None
            for span, seq_id in zip(encoded["offset_mapping"][feat_idx], seq_ids)
        ]

    encoded["example_id"] = example_ids
    return encoded
113
+
114
+
115
# Tokenize both splits. The original columns are removed so that only the
# outputs of the preprocessing functions remain in the mapped datasets.
train_dataset = raw["train"].map(
    preprocess_train,
    batched=True,
    remove_columns=raw["train"].column_names,
)
val_dataset = raw["validation"].map(
    preprocess_validation,
    batched=True,
    remove_columns=raw["validation"].column_names,
)

# ── 4. Model FROM SCRATCH ─────────────────────────────────────
# The model is constructed directly from this config (no from_pretrained
# call); only the vocabulary size is taken from the tokenizer.
config = BertConfig(
    vocab_size=tokenizer.vocab_size,  # 30522
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536,
    max_position_embeddings=512,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
model = BertForQuestionAnswering(config)
print(f"Parameters: {model.num_parameters():,}")  # ~22M

# ── 5. Evaluation (Exact Match + F1) ─────────────────────────
metric = evaluate.load("squad")
142
+
143
def compute_metrics(p):
    """Decode span predictions into answer strings and score them.

    For every validation example, the ``n_best`` highest start and end
    logits of each of its features are paired up. Pairs that touch a
    non-context token, end before they start, or span more than
    ``max_answer_len`` tokens are discarded. The surviving pair with the
    highest ``start_logit + end_logit`` gives the predicted text; an example
    with no valid pair is predicted as the empty string.

    Args:
        p: Prediction object whose ``predictions`` attribute unpacks into
            ``(start_logits, end_logits)``, indexed in the same order as the
            rows of ``val_dataset``.

    Returns:
        The result of ``metric.compute`` on the decoded predictions and the
        reference answers of the raw validation split.
    """
    all_start_logits, all_end_logits = p.predictions

    n_best = 20
    max_answer_len = 30
    feature_example_ids = val_dataset["example_id"]
    feature_offsets = val_dataset["offset_mapping"]

    # Context text and gold answers, keyed by example id.
    contexts = {}
    references = {}
    for example in raw["validation"]:
        contexts[example["id"]] = example["context"]
        references[example["id"]] = example["answers"]

    # Group the feature rows by the example they were cut from.
    features_by_example = collections.defaultdict(list)
    for feature_index, example_id in enumerate(feature_example_ids):
        features_by_example[example_id].append(feature_index)

    predicted_answers = []
    for example_id, feature_indices in features_by_example.items():
        context = contexts[example_id]
        candidates = []

        for feature_index in feature_indices:
            offsets = feature_offsets[feature_index]
            start_scores = all_start_logits[feature_index]
            end_scores = all_end_logits[feature_index]
            # Indices of the n_best largest logits, best first.
            top_starts = np.argsort(start_scores)[::-1][:n_best].tolist()
            top_ends = np.argsort(end_scores)[::-1][:n_best].tolist()

            for start in top_starts:
                if offsets[start] is None:
                    continue
                for end in top_ends:
                    if offsets[end] is None:
                        continue
                    if end < start or end - start + 1 > max_answer_len:
                        continue
                    candidates.append({
                        "score": start_scores[start] + end_scores[end],
                        "text": context[offsets[start][0]: offsets[end][1]],
                    })

        if candidates:
            best = max(candidates, key=lambda cand: cand["score"])
        else:
            best = {"text": ""}
        predicted_answers.append(
            {"id": example_id, "prediction_text": best["text"]}
        )

    formatted_refs = [
        {"id": ref_id, "answers": ref_answers}
        for ref_id, ref_answers in references.items()
    ]
    return metric.compute(
        predictions=predicted_answers, references=formatted_refs
    )
186
+
187
+
188
+ # ── 6. Training ───────────────────────────────────────────────
189
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    logging_steps=100,
    fp16=True,
    report_to="none",
)

# `val_for_trainer` was referenced below but never defined, which raised a
# NameError before training could start. The Trainer only needs the model
# inputs, so the columns used solely for answer post-processing are dropped;
# `val_dataset` itself keeps them for compute_metrics.
val_for_trainer = val_dataset.remove_columns(["example_id", "offset_mapping"])

# NOTE(review): compute_metrics is None here and the validation features
# carry no label columns, so the periodic evaluation presumably reports no
# EM/F1 -- confirm that this is intended. EM/F1 are computed once after
# training in step 7.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_for_trainer,
    processing_class=tokenizer,
    data_collator=DefaultDataCollator(),
    compute_metrics=None,
)

trainer.train()

# ── 7. Final evaluation ────────────────────────────
print("--- Starting final evaluation ---")
# Row order of val_for_trainer matches val_dataset, which compute_metrics
# relies on when pairing logits with example ids and offsets.
predictions = trainer.predict(val_for_trainer)
final_metrics = compute_metrics(predictions)
print(f"Final results: {final_metrics}")

trainer.save_model(OUTPUT_DIR)
print("✅ DONE!")