subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import string
from collections import Counter
from typing import List, Union
import torch
from torchmetrics import Metric
# Only TopKClassificationAccuracy is exported via `from ... import *`; the other
# metrics and `compute_topk_accuracy` defined below must be imported by name.
__all__ = ['TopKClassificationAccuracy']
class TopKClassificationAccuracy(Metric):
    """
    Computes top-k classification accuracy between logits and labels.

    Each call to ``update`` stores the per-k correct-prediction counts and the
    sample counts of the *current batch* in ``correct_counts_k`` and
    ``total_counts_k``. Previous values are overwritten, not accumulated. The
    states are registered with ``dist_reduce_fx='sum'``, so they are summed
    when synchronized between distributed workers. Accuracy for a given k is
    ``correct_count / total_count``.

    If used with a PyTorch Lightning LightningModule, collect the per-batch
    counts in ``validation_step`` and sum them at the end of the validation
    epoch to correctly compute validation accuracy.

    Example:
        def validation_step(self, batch, batch_idx):
            ...
            self._accuracy(logits=logits, labels=labels)
            correct_counts = self._accuracy.correct_counts_k
            total_counts = self._accuracy.total_counts_k
            self.val_outputs.append(
                {'val_loss': loss_value, 'val_correct_counts': correct_counts, 'val_total_counts': total_counts}
            )

        def on_validation_epoch_end(self):
            ...
            val_loss_mean = torch.stack([x['val_loss'] for x in self.val_outputs]).mean()
            # Sum the per-batch [K] count vectors into one [K] vector.
            correct_counts = torch.stack([x['val_correct_counts'] for x in self.val_outputs]).sum(dim=0)
            total_counts = torch.stack([x['val_total_counts'] for x in self.val_outputs]).sum(dim=0)
            topk_scores = compute_topk_accuracy(correct_counts, total_counts)

            tensorboard_log = {'val_loss': val_loss_mean}
            for top_k, score in zip(self._accuracy.top_k, topk_scores):
                tensorboard_log['val_epoch_top@{}'.format(top_k)] = score

            self.val_outputs.clear()  # free memory
            return {'log': tensorboard_log}

    Args:
        top_k: An int, or an iterable of ints, giving the k values to evaluate. Defaults to [1].
        dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.

    ``compute`` returns:
        If ``top_k == [1]``: a one-element list holding the tensor ``correct / total``.
        Otherwise: a list of Python floats, one accuracy per entry of ``top_k``.
    """

    full_state_update = True

    def __init__(self, top_k=None, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # The `top_k` property setter normalizes None / int / iterable into a list.
        self.top_k = top_k
        num_k = len(self.top_k)
        self.add_state("correct_counts_k", default=torch.zeros(num_k), dist_reduce_fx='sum', persistent=False)
        self.add_state("total_counts_k", default=torch.zeros(num_k), dist_reduce_fx='sum', persistent=False)

    @torch.no_grad()
    def top_k_predicted_labels(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the indices of the ``max(self.top_k)`` largest logits along dim 1, best first."""
        max_k = max(self.top_k)
        _, predictions = logits.topk(max_k, dim=1, largest=True, sorted=True)
        return predictions

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Store the counts for one batch, overwriting the previous batch's counts.

        Args:
            logits: Scores with classes along dim 1 (presumably [batch, num_classes] -- confirm with callers).
            labels: Target class indices, one per sample in the batch.
        """
        with torch.no_grad():
            # After the transpose, row i holds every sample's (i+1)-th best class.
            predictions = self.top_k_predicted_labels(logits).t()
            # Broadcast labels as a [1, batch] row. expand_as keeps the
            # original strict shape check against `predictions`.
            correct = predictions.eq(labels.view(1, -1)).expand_as(predictions)

            # A sample counts as correct for k if its label appears in the first k rows.
            # torch.stack keeps the counts on-device. Building the tensor from a
            # Python list of tensors would sync with the host once per element.
            correct_counts_k = torch.stack([correct[:k].reshape(-1).long().sum() for k in self.top_k])
            self.correct_counts_k = correct_counts_k.to(dtype=labels.dtype, device=labels.device)
            self.total_counts_k = torch.full(
                (len(self.top_k),), labels.shape[0], dtype=labels.dtype, device=labels.device
            )

    def compute(self):
        """
        Computes the top-k accuracy from the currently stored counts.

        Returns:
            A list of length `K` whose k-th entry is the accuracy for ``self.top_k[k]``.
            See the class docstring for the element type.

        Raises:
            ValueError: if the stored count vectors do not have one entry per k.
        """
        if not len(self.correct_counts_k) == len(self.top_k) == len(self.total_counts_k):
            raise ValueError("length of counts must match to topk length")

        if self.top_k == [1]:
            return [self.correct_counts_k.float() / self.total_counts_k]
        return compute_topk_accuracy(self.correct_counts_k, self.total_counts_k)

    @property
    def top_k(self) -> List[int]:
        """The list of k values evaluated by this metric."""
        return self._top_k

    @top_k.setter
    def top_k(self, value: List[int]):
        # Accept None (-> [1]), a single int, or any iterable of ints.
        if value is None:
            value = [1]
        if isinstance(value, int):
            value = [value]
        if not isinstance(value, list):
            value = list(value)
        self._top_k = value
def compute_topk_accuracy(correct_counts_k, total_counts_k):
    """
    Computes the top-k accuracy.

    Args:
        correct_counts_k: Tensor of shape [K] (or any length-K sequence of numbers), K being the
            number of top-k values. Entry i is the number of correct predictions for the i-th k.
        total_counts_k: Tensor of shape [K] (or any length-K sequence of numbers) holding the
            matching number of samples for each k.

    Returns:
        A list of length `K` of Python floats, where the i-th entry is
        ``correct_counts_k[i] / total_counts_k[i]``.

    Raises:
        ValueError: if the two inputs do not have the same length.
        ZeroDivisionError: if any total count is zero.
    """
    if len(correct_counts_k) != len(total_counts_k):
        raise ValueError(
            "correct_counts_k and total_counts_k must have the same length, "
            f"got {len(correct_counts_k)} and {len(total_counts_k)}"
        )
    # float() accepts both 0-d tensors (yielded when iterating a 1-D tensor) and plain numbers.
    return [float(correct) / float(total) for correct, total in zip(correct_counts_k, total_counts_k)]
class ExactStringPerCategoryMatchMetric(Metric):
    """
    Exact string-match accuracy, both overall and broken down by category.

    Args:
        categories: Iterable of category names to track. Each category gets its own
            ``{category}_correct`` and ``{category}_total`` counter states. Defaults to none.
        dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, categories=None, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Materialize the input exactly once, so generator/iterator inputs are not
        # exhausted before the per-category states are registered. Duplicates are
        # dropped in first-seen order, keeping the registration order deterministic
        # rather than dependent on set/hash ordering.
        ordered_categories = list(dict.fromkeys(categories if categories is not None else ()))
        self.categories = set(ordered_categories)

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        for category in ordered_categories:
            self.add_state(f"{category}_total", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state(f"{category}_correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: str, category: str = None):
        """
        Record one prediction.

        The overall counters are always updated. The per-category counters are
        updated only when ``category`` was passed to ``__init__``. Any other
        non-None category is logged as a warning and otherwise ignored.
        """
        is_match = pred == target
        if is_match:
            self.correct += 1
        self.total += 1

        if category is None:
            return
        if category not in self.categories:
            logging.warning(f'{category} is not in the pre-defined list')
            return

        total_name = f"{category}_total"
        setattr(self, total_name, getattr(self, total_name) + 1)
        if is_match:
            correct_name = f"{category}_correct"
            setattr(self, correct_name, getattr(self, correct_name) + 1)

    def compute(self):
        """
        Return a dict with the overall accuracy under ``'acc'``, the per-category
        accuracy under each category name, and the per-category sample counts
        under ``'{category}_total'``.
        """
        results = {'acc': self.correct.float() / self.total}
        for category in self.categories:
            results[category] = getattr(self, f"{category}_correct") / getattr(self, f"{category}_total")
        for category in self.categories:
            results[f"{category}_total"] = getattr(self, f"{category}_total")
        return results
class ExactStringMatchMetric(Metric):
    """
    Fraction of predictions that are exactly equal to their target string.

    Both counters are registered with a "sum" reduction for distributed synchronization.
    """

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        for counter_name in ("correct", "total"):
            self.add_state(counter_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: str):
        """Count one example, and one hit when ``pred`` equals ``target``."""
        self.total += 1
        if pred == target:
            self.correct += 1

    def compute(self):
        """Return hits / examples as a float tensor."""
        hits = self.correct.float()
        return hits / self.total
class TokenF1Score(Metric):
    """
    Average token-level F1 between predicted strings and reference strings.

    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    """

    # Built once instead of on every `normalize` call.
    _ARTICLES_PATTERN = re.compile(r"\b(a|an|the)\b")
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # `correct` is the running sum of per-example F1 scores; `total` counts examples.
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: Union[str, List[str]]):
        """
        Add one example.

        Args:
            pred: The predicted string.
            target: A reference string, or a list/tuple of reference strings. With
                several references, the best F1 among them is used.

        Raises:
            TypeError: if ``target`` is neither a str nor a list/tuple.
            ValueError: if ``target`` is an empty list/tuple.
        """
        if isinstance(target, str):
            score = self.f1_score(pred, target)
        elif isinstance(target, (list, tuple)):
            score = max(self.f1_score(pred, tgt) for tgt in target)
        else:
            # Previously such targets were silently counted as F1 = 0.
            raise TypeError(f"target must be a str or a list/tuple of str, got {type(target).__name__}")
        self.correct += score
        self.total += 1

    def compute(self):
        """Return the mean F1 over all examples seen."""
        return self.correct.float() / self.total

    def f1_score(self, prediction, ground_truth):
        """Token-overlap F1 between two strings after `normalize`; 0.0 when no tokens overlap."""
        prediction_tokens = self.normalize(prediction).split()
        ground_truth_tokens = self.normalize(ground_truth).split()
        # Multiset intersection: each shared token counts min(count_pred, count_gt) times.
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            # Also guards the divisions below when either token list is empty.
            return 0.0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def normalize(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""
        text = s.lower()
        text = text.translate(self._PUNCTUATION_TABLE)
        text = self._ARTICLES_PATTERN.sub(" ", text)
        return " ".join(text.split())