alvinwatner
committed on
Commit
·
9525173
1
Parent(s):
7e9570d
Updating training metrics
Browse files - run_summarization_flax.py +26 -9
run_summarization_flax.py
CHANGED
|
@@ -589,8 +589,10 @@ def main():
|
|
| 589 |
desc="Running tokenizer on prediction dataset",
|
| 590 |
)
|
| 591 |
|
| 592 |
-
# Metric
|
| 593 |
-
|
|
|
|
|
|
|
| 594 |
|
| 595 |
def postprocess_text(preds, labels):
|
| 596 |
preds = [pred.strip() for pred in preds]
|
|
@@ -609,14 +611,29 @@ def main():
|
|
| 609 |
# Some simple post-processing
|
| 610 |
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
| 611 |
|
| 612 |
-
|
|
|
|
|
|
|
| 613 |
# Extract a few results from ROUGE
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
|
| 621 |
# Enable tensorboard only on the master node
|
| 622 |
has_tensorboard = is_tensorboard_available()
|
|
|
|
| 589 |
desc="Running tokenizer on prediction dataset",
|
| 590 |
)
|
| 591 |
|
| 592 |
+
# Metric
|
| 593 |
+
rouge_metric = load_metric("rouge")
|
| 594 |
+
bleu_metric = load_metric("bleu")
|
| 595 |
+
meteor_metric = load_metric("meteor")
|
| 596 |
|
| 597 |
def postprocess_text(preds, labels):
|
| 598 |
preds = [pred.strip() for pred in preds]
|
|
|
|
| 611 |
# Some simple post-processing
|
| 612 |
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
| 613 |
|
| 614 |
+
results = {}
|
| 615 |
+
rouge_scores = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer = True, \
|
| 616 |
+
rouge_types=['rougeL'])
|
| 617 |
# Extract a few results from ROUGE
|
| 618 |
+
rouge_scores = {key: value.mid.fmeasure * 100 for key, value in rouge_scores.items()}
|
| 619 |
+
rouge_scores = {k: round(v, 4) for k, v in rouge_scores.items()}
|
| 620 |
+
meteor_scores = meteor_metric.compute(predictions=decoded_preds, references=decoded_labels)
|
| 621 |
+
meteor_scores = {k: round(v, 4) for k, v in meteor_scores.items()}
|
| 622 |
+
|
| 623 |
+
# Compute bleu-1,2,3,4 scores
|
| 624 |
+
# Postprocess the predictions and references to compute bleu scores
|
| 625 |
+
tokenized_predictions = [decoded_preds[i].split() for i in range(len(decoded_preds))]
|
| 626 |
+
tokenized_labels = [[decoded_labels[i].split()] for i in range(len(decoded_labels))]
|
| 627 |
+
bleu_scores = {f'bleu-{i}' : \
|
| 628 |
+
bleu_metric.compute(predictions=tokenized_predictions, references=tokenized_labels, max_order=i)['bleu']\
|
| 629 |
+
for i in range(1,5)}
|
| 630 |
+
bleu_scores = {k: round(v, 4) for k, v in bleu_scores.items()}
|
| 631 |
+
|
| 632 |
+
results.update(bleu_scores)
|
| 633 |
+
results.update(rouge_scores)
|
| 634 |
+
results.update(meteor_scores)
|
| 635 |
+
|
| 636 |
+
return results
|
| 637 |
|
| 638 |
# Enable tensorboard only on the master node
|
| 639 |
has_tensorboard = is_tensorboard_available()
|