# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
import string
from collections import Counter
from typing import List, Optional, Union

import torch
from torchmetrics import Metric

__all__ = ['TopKClassificationAccuracy']

# Built once at import time so that TokenF1Score.normalize does not rebuild
# them for every prediction/reference pair.
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b")
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


class TopKClassificationAccuracy(Metric):
    """
    Top-k classification accuracy between logits and labels.

    For every ``k`` in ``top_k`` the metric keeps two numbers: the count of
    samples whose label is among the ``k`` highest-scoring classes
    (``correct_counts_k``) and the number of samples seen (``total_counts_k``).
    Both states are registered with ``dist_reduce_fx='sum'``.
    Accuracy = correct_count / total_count.

    NOTE: ``update`` *overwrites* both states with the counts of the batch it
    was given; it does not add to the previous values. To obtain an accuracy
    over several batches, collect the per-batch counts and sum them yourself,
    as in the example below.

    Example:
        def validation_step(self, batch, batch_idx):
            ...
            self._accuracy(logits=logits, labels=labels)
            self.val_outputs.append(
                {
                    'val_loss': loss_value,
                    'val_correct_counts': self._accuracy.correct_counts_k,
                    'val_total_counts': self._accuracy.total_counts_k,
                }
            )

        def on_validation_epoch_end(self):
            ...
            val_loss_mean = torch.stack([x['val_loss'] for x in self.val_outputs]).mean()
            correct_counts = torch.stack([x['val_correct_counts'] for x in self.val_outputs]).sum(dim=0)
            total_counts = torch.stack([x['val_total_counts'] for x in self.val_outputs]).sum(dim=0)

            topk_scores = compute_topk_accuracy(correct_counts, total_counts)

            tensorboard_log = {'val_loss': val_loss_mean}
            for top_k, score in zip(self._accuracy.top_k, topk_scores):
                tensorboard_log['val_epoch_top@{}'.format(top_k)] = score

            self.val_outputs.clear()  # free memory
            return {'log': tensorboard_log}

    Args:
        top_k: Optional int or iterable of ints. Defaults to [1].
        dist_sync_on_step: Forwarded unchanged to ``torchmetrics.Metric``.

    Returns:
        ``compute()`` returns a list with one accuracy per entry of ``top_k``
        (see ``compute`` for the exact element types).
    """

    full_state_update = True

    def __init__(self, top_k=None, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # The property setter turns None / int / any iterable into a list.
        self.top_k = top_k
        self.add_state(
            "correct_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False
        )
        self.add_state("total_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False)

    @torch.no_grad()
    def top_k_predicted_labels(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Return the indices of the ``max(top_k)`` largest logits along dim 1.

        Args:
            logits: Tensor whose dim 1 is ranked; presumably [batch, num_classes].

        Returns:
            Index tensor from ``logits.topk``, sorted from highest to lowest score.
        """
        max_k = max(self.top_k)
        _, predictions = logits.topk(max_k, dim=1, largest=True, sorted=True)
        return predictions

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Replace the stored counts with the counts of this batch.

        Args:
            logits: Scores passed to ``top_k_predicted_labels``.
            labels: Target class indices; flattened with ``view(1, -1)`` before
                the comparison. Its dtype and device are used for both states.
        """
        with torch.no_grad():
            # Transpose so that row i holds the i-th best prediction of every sample.
            predictions = self.top_k_predicted_labels(logits).t()
            correct = predictions.eq(labels.view(1, -1)).expand_as(predictions)

            # A sample is a top-k hit when its label appears in the first k rows.
            correct_counts_k = [correct[:k].reshape(-1).long().sum() for k in self.top_k]

            # stack/full keep everything on-device instead of converting the
            # 0-dim count tensors element by element.
            self.correct_counts_k = torch.stack(correct_counts_k).to(dtype=labels.dtype, device=labels.device)
            self.total_counts_k = torch.full(
                (len(self.top_k),), labels.shape[0], dtype=labels.dtype, device=labels.device
            )

    def compute(self):
        """
        Computes the top-k accuracy.

        Returns:
            A list of length `K`, such that k-th index corresponds to top-k accuracy
            over all distributed processes. When ``top_k == [1]`` the single element
            is a tensor (``correct_counts_k.float() / total_counts_k``); otherwise
            the elements are Python floats produced by ``compute_topk_accuracy``.

        Raises:
            ValueError: If the state lengths do not match ``len(top_k)``.
        """
        if not len(self.correct_counts_k) == len(self.top_k) == len(self.total_counts_k):
            raise ValueError("length of counts must match to topk length")

        if self.top_k == [1]:
            return [self.correct_counts_k.float() / self.total_counts_k]

        return compute_topk_accuracy(self.correct_counts_k, self.total_counts_k)

    @property
    def top_k(self) -> List[int]:
        """List of ``k`` values for which accuracy is tracked."""
        return self._top_k

    @top_k.setter
    def top_k(self, value: Optional[Union[int, List[int]]]):
        if value is None:
            value = [1]

        if isinstance(value, int):
            value = [value]

        if not isinstance(value, list):
            value = list(value)

        self._top_k = value


def compute_topk_accuracy(correct_counts_k, total_counts_k):
    """
    Computes the top-k accuracy

    Args:
        correct_counts_k: Tensor of shape [K], K being the top-k parameter.
        total_counts_k: Tensor of shape [K], and K being the top-k parameter.

    Returns:
        A list of length `K`, such that k-th index corresponds to top-k accuracy
        over all distributed processes.

    Raises:
        ZeroDivisionError: If any total count is zero.
    """
    top_k_scores = []
    for ki, correct in enumerate(correct_counts_k):
        # Index the totals explicitly so a too-short `total_counts_k` still fails loudly.
        total_count = total_counts_k[ki].item()
        top_k_scores.append(correct.item() / float(total_count))
    return top_k_scores


class ExactStringPerCategoryMatchMetric(Metric):
    """
    Exact string match accuracy, overall and per pre-defined category.

    Args:
        categories: Iterable of category names to track. Defaults to no categories.
        dist_sync_on_step: Forwarded unchanged to ``torchmetrics.Metric``.
    """

    def __init__(self, categories=None, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # `None` instead of a shared mutable `[]` default.
        categories = [] if categories is None else categories
        self.categories = set(categories)

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        # Register in the caller's order (duplicates dropped), not in set order,
        # so the registration order is deterministic.
        for category in dict.fromkeys(categories):
            self.add_state(f"{category}_total", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state(f"{category}_correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: str, category: Optional[str] = None):
        """
        Record one prediction.

        Args:
            pred: Predicted string.
            target: Reference string; a hit requires ``pred == target``.
            category: Optional category of the sample. Unknown categories are
                logged with a warning and only counted in the overall totals.
        """
        is_match = pred == target
        if is_match:
            self.correct += 1
        self.total += 1

        if category is None:
            return

        if category not in self.categories:
            logging.warning(f'{category} is not in the pre-defined list')
            return

        total_name = f"{category}_total"
        setattr(self, total_name, getattr(self, total_name) + 1)
        if is_match:
            correct_name = f"{category}_correct"
            setattr(self, correct_name, getattr(self, correct_name) + 1)

    def compute(self):
        """
        Returns:
            Dict with ``'acc'`` (overall accuracy), one ``<category>`` entry per
            category (its accuracy) and one ``<category>_total`` entry per
            category (its sample count).
        """
        results = {}
        results['acc'] = self.correct.float() / self.total
        for category in self.categories:
            results[category] = getattr(self, f"{category}_correct") / getattr(self, f"{category}_total")
        for category in self.categories:
            results[f"{category}_total"] = getattr(self, f"{category}_total")
        return results


class ExactStringMatchMetric(Metric):
    """Fraction of predictions that are exactly equal to their target string."""

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: str):
        """Count one sample; it is a hit when ``pred == target``."""
        if pred == target:
            self.correct += 1
        self.total += 1

    def compute(self):
        """Return ``correct / total`` as a float tensor."""
        return self.correct.float() / self.total


class TokenF1Score(Metric):
    """Taken from the official evaluation script for v1.1 of the SQuAD dataset"""

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: Union[str, List[str]]):
        """
        Add the token-level F1 of one prediction.

        Args:
            pred: Predicted string.
            target: A single reference string, or a list/tuple of references, in
                which case the best F1 over the references is used. An empty
                sequence of references contributes an F1 of 0.
        """
        if isinstance(target, str):
            self.correct += self.f1_score(pred, target)
        elif isinstance(target, (list, tuple)):
            # `default` keeps an empty reference list from raising ValueError.
            self.correct += max((self.f1_score(pred, tgt) for tgt in target), default=0.0)
        self.total += 1

    def compute(self):
        """Return the mean F1 over all samples seen, as a float tensor."""
        return self.correct.float() / self.total

    def f1_score(self, prediction, ground_truth):
        """
        Token-overlap F1 between two strings after ``normalize``.

        Returns:
            0.0 when the normalized strings share no token, otherwise the
            harmonic mean of token precision and recall.
        """
        prediction_tokens = self.normalize(prediction).split()
        ground_truth_tokens = self.normalize(ground_truth).split()
        # Multiset intersection: each shared token counts min(occurrences) times.
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def normalize(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""
        text = s.lower()
        text = text.translate(_PUNCTUATION_TABLE)
        text = _ARTICLES_RE.sub(" ", text)
        return " ".join(text.split())