# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
import string
from collections import Counter
from typing import List, Union

import torch
from torchmetrics import Metric

__all__ = ['TopKClassificationAccuracy']

class TopKClassificationAccuracy(Metric):
    """
    This metric computes the numerator and denominator of the overall accuracy between logits and labels.
    When doing distributed training/evaluation, the result of res=TopKClassificationAccuracy(logits, labels)
    calls will be all-reduced between all workers using SUM operations.
    The result holds two numbers, res=[correctly_predicted, total_samples], and
    accuracy = correctly_predicted / total_samples.

    If used with a PyTorch Lightning LightningModule, include correct_count and total_count in the
    validation_step results, then aggregate (sum) them at the end of the validation epoch to correctly
    compute the validation accuracy.

    Example:
        def validation_step(self, batch, batch_idx):
            ...
            correct_count, total_count = self._accuracy(logits, labels)
            output = {'val_loss': loss_value, 'val_correct_count': correct_count, 'val_total_count': total_count}
            self.val_outputs.append(output)
            return output

        def on_validation_epoch_end(self):
            ...
            val_loss_mean = torch.stack([x['val_loss'] for x in self.val_outputs]).mean()
            correct_counts = torch.stack([x['val_correct_count'] for x in self.val_outputs])
            total_counts = torch.stack([x['val_total_count'] for x in self.val_outputs])

            topk_scores = compute_topk_accuracy(correct_counts, total_counts)

            tensorboard_log = {'val_loss': val_loss_mean}
            for top_k, score in zip(self._accuracy.top_k, topk_scores):
                tensorboard_log['val_epoch_top@{}'.format(top_k)] = score

            self.val_outputs.clear()  # free memory
            return {'log': tensorboard_log}

    Args:
        top_k: Optional list of integers. Defaults to [1].

    Returns:
        res: a torch.Tensor object with two elements: [correct_count, total_count]. To correctly compute
            the average accuracy, compute acc = correct_count / total_count.
    """

    full_state_update = True

    def __init__(self, top_k=None, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        if top_k is None:
            top_k = [1]

        self.top_k = top_k
        self.add_state(
            "correct_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False
        )
        self.add_state("total_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False)

    def top_k_predicted_labels(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the indices of the max(top_k) highest-scoring classes for each sample."""
        max_k = max(self.top_k)
        _, predictions = logits.topk(max_k, dim=1, largest=True, sorted=True)
        return predictions

    def update(self, logits: torch.Tensor, labels: torch.Tensor) -> None:
        with torch.no_grad():
            predictions = self.top_k_predicted_labels(logits)
            predictions = predictions.t()
            correct = predictions.eq(labels.view(1, -1)).expand_as(predictions)

            correct_counts_k = []
            total_counts_k = []

            for k in self.top_k:
                # Number of samples whose label appears among the top-k predictions.
                correct_k = correct[:k].reshape(-1).long().sum()
                total_k = labels.shape[0]

                correct_counts_k.append(correct_k)
                total_counts_k.append(total_k)

            self.correct_counts_k = torch.tensor(correct_counts_k, dtype=labels.dtype, device=labels.device)
            self.total_counts_k = torch.tensor(total_counts_k, dtype=labels.dtype, device=labels.device)

    def compute(self):
        """
        Computes the top-k accuracy.

        Returns:
            A list of length `K`, such that the k-th index corresponds to the top-k accuracy
            over all distributed processes.
        """
        if not len(self.correct_counts_k) == len(self.top_k) == len(self.total_counts_k):
            raise ValueError("Length of the correct and total counts must match the length of top_k")

        if self.top_k == [1]:
            # Shortcut: a single top-1 score can be computed directly on the tensors.
            return [self.correct_counts_k.float() / self.total_counts_k]
        else:
            top_k_scores = compute_topk_accuracy(self.correct_counts_k, self.total_counts_k)
            return top_k_scores

    @property
    def top_k(self) -> List[int]:
        return self._top_k

    @top_k.setter
    def top_k(self, value: List[int]):
        if value is None:
            value = [1]

        if isinstance(value, int):
            value = [value]

        if not isinstance(value, list):
            value = list(value)

        self._top_k = value
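
# A minimal usage sketch for TopKClassificationAccuracy (not part of the original file; the tensor
# shapes are assumptions: logits of shape [batch, num_classes], integer labels of shape [batch]):
#
#     metric = TopKClassificationAccuracy(top_k=[1, 5])
#     metric.update(logits, labels)           # count correct predictions for this batch
#     top1_acc, top5_acc = metric.compute()   # list of per-k accuracies (see compute())
#     metric.reset()                          # clear the state before the next evaluation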

def compute_topk_accuracy(correct_counts_k, total_counts_k):
    """
    Computes the top-k accuracy.

    Args:
        correct_counts_k: Tensor of shape [K], K being the top-k parameter.
        total_counts_k: Tensor of shape [K], K being the top-k parameter.

    Returns:
        A list of length `K`, such that the k-th index corresponds to the top-k accuracy
        over all distributed processes.
    """
    top_k_scores = []

    for ki in range(len(correct_counts_k)):
        correct_count = correct_counts_k[ki].item()
        total_count = total_counts_k[ki].item()
        top_k_scores.append(correct_count / float(total_count))

    return top_k_scores
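
# Worked sketch for compute_topk_accuracy (the numbers are illustrative, not from the source):
#
#     correct_counts_k = torch.tensor([8, 9])     # e.g. correct counts for k=1 and k=5
#     total_counts_k = torch.tensor([10, 10])     # batch size repeated once per k
#     compute_topk_accuracy(correct_counts_k, total_counts_k)  # -> [0.8, 0.9]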

class ExactStringPerCategoryMatchMetric(Metric):
    """Computes exact string match accuracy, both overall and per category."""

    def __init__(self, categories=None, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        categories = categories or []
        self.categories = set(categories)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        for category in categories:
            self.add_state(f"{category}_total", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state(f"{category}_correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: str, category: str = None):
        if pred == target:
            self.correct += 1
        self.total += 1

        if category is None:
            return
        if category in self.categories:
            val = getattr(self, f"{category}_total")
            setattr(self, f"{category}_total", val + 1)
            if pred == target:
                val = getattr(self, f"{category}_correct")
                setattr(self, f"{category}_correct", val + 1)
        else:
            logging.warning(f'{category} is not in the pre-defined list')

    def compute(self):
        results = {}
        results['acc'] = self.correct.float() / self.total
        for category in self.categories:
            results[category] = getattr(self, f"{category}_correct") / getattr(self, f"{category}_total")
        for category in self.categories:
            results[f"{category}_total"] = getattr(self, f"{category}_total")
        return results
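
# A minimal usage sketch for ExactStringPerCategoryMatchMetric (the category names and strings
# are illustrative assumptions; actual result values are tensors, shown here as plain numbers):
#
#     metric = ExactStringPerCategoryMatchMetric(categories=['math', 'history'])
#     metric.update(pred="4", target="4", category="math")
#     metric.update(pred="1492", target="1912", category="history")
#     metric.compute()  # -> {'acc': 0.5, 'math': 1.0, 'history': 0.0, 'math_total': 1, 'history_total': 1}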

class ExactStringMatchMetric(Metric):
    """Computes exact string match accuracy between predictions and targets."""

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: str):
        if pred == target:
            self.correct += 1
        self.total += 1

    def compute(self):
        return self.correct.float() / self.total
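
# A minimal usage sketch for ExactStringMatchMetric (illustrative strings, not from the source):
#
#     metric = ExactStringMatchMetric()
#     metric.update(pred="paris", target="paris")
#     metric.update(pred="london", target="berlin")
#     metric.compute()  # -> tensor(0.5000)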

class TokenF1Score(Metric):
    """Taken from the official evaluation script for v1.1 of the SQuAD dataset."""

    def __init__(self, dist_sync_on_step=False, *args, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred: str, target: Union[str, List[str]]):
        if isinstance(target, str):
            self.correct += self.f1_score(pred, target)
        elif isinstance(target, list):
            # Score against each reference answer and keep the best match.
            self.correct += max([self.f1_score(pred, tgt) for tgt in target])
        self.total += 1

    def compute(self):
        return self.correct.float() / self.total

    def f1_score(self, prediction, ground_truth):
        prediction_tokens = self.normalize(prediction).split()
        ground_truth_tokens = self.normalize(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def normalize(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""

        def remove_articles(text):
            return re.sub(r"\b(a|an|the)\b", " ", text)

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))
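
# Worked sketch of the token-level F1 above (illustrative strings, not from the source):
#
#     pred   = "The cat sat on the mat."  -> normalize -> ["cat", "sat", "on", "mat"]
#     target = "a cat sat on a mat"       -> normalize -> ["cat", "sat", "on", "mat"]
#     common = 4, precision = 4/4, recall = 4/4, F1 = 1.0
#
#     metric = TokenF1Score()
#     metric.update(pred, [target, "the dog slept"])  # best F1 over the reference list
#     metric.compute()                                # -> tensor(1.)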