# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
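
"""Unit tests for NeMo common metrics: top-k classification accuracy,
perplexity, loss aggregation, and punctuation error rate."""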

import pytest
import torch

from nemo.collections.common.metrics.classification_accuracy import TopKClassificationAccuracy
from nemo.collections.common.metrics.punct_er import (
DatasetPunctuationErrorRate,
OccurancePunctuationErrorRate,
punctuation_error_rate,
)

from .loss_inputs import ALL_NUM_MEASUREMENTS_ARE_ZERO, NO_ZERO_NUM_MEASUREMENTS, SOME_NUM_MEASUREMENTS_ARE_ZERO
from .perplexity_inputs import NO_PROBS_NO_LOGITS, ONLY_LOGITS1, ONLY_LOGITS100, ONLY_PROBS, PROBS_AND_LOGITS
from .pl_utils import LossTester, PerplexityTester


class TestCommonMetrics:
    top_k_logits = torch.tensor(
        [[0.1, 0.3, 0.2, 0.0], [0.9, 0.6, 0.2, 0.3], [0.2, 0.1, 0.4, 0.3]],
    )  # row-wise argmax: 1, 0, 2

    @pytest.mark.unit
def test_top_1_accuracy(self):
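        # Row-wise argmax of top_k_logits is [1, 0, 2]; with labels [0, 0, 2],
        # only the last two samples match, so top-1 accuracy should be 2/3.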
labels = torch.tensor([0, 0, 2], dtype=torch.long)
accuracy = TopKClassificationAccuracy(top_k=None)
acc = accuracy(logits=self.top_k_logits, labels=labels)
assert accuracy.correct_counts_k.shape == torch.Size([1])
assert accuracy.total_counts_k.shape == torch.Size([1])
assert abs(acc[0] - 0.667) < 1e-3

    @pytest.mark.unit
def test_top_1_2_accuracy(self):
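        # Top-1 predictions [1, 0, 2] match none of the labels [0, 1, 0], so
        # top-1 accuracy is 0. For top-2, the per-row candidate sets are
        # {1, 2}, {0, 1}, and {2, 3}; only row 1's label is covered, giving 1/3.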
labels = torch.tensor([0, 1, 0], dtype=torch.long)
accuracy = TopKClassificationAccuracy(top_k=[1, 2])
top1_acc, top2_acc = accuracy(logits=self.top_k_logits, labels=labels)
assert accuracy.correct_counts_k.shape == torch.Size([2])
assert accuracy.total_counts_k.shape == torch.Size([2])
assert abs(top1_acc - 0.0) < 1e-3
assert abs(top2_acc - 0.333) < 1e-3

    @pytest.mark.unit
    def test_top_1_accuracy_distributed(self):
        # Simulate a 2-process DDP execution on a single process.
labels = torch.tensor([[0, 0, 2], [2, 0, 0]], dtype=torch.long)
accuracy = TopKClassificationAccuracy(top_k=None)
proc1_acc = accuracy(logits=self.top_k_logits, labels=labels[0])
correct1, total1 = accuracy.correct_counts_k, accuracy.total_counts_k
accuracy.reset()
proc2_acc = accuracy(logits=torch.flip(self.top_k_logits, dims=[1]), labels=labels[1]) # reverse logits
correct2, total2 = accuracy.correct_counts_k, accuracy.total_counts_k
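        # Stacking the per-process counts stands in for the all_gather a real
        # DDP run would perform before reducing the metric state.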
correct = torch.stack([correct1, correct2])
total = torch.stack([total1, total2])
assert correct.shape == torch.Size([2, 1])
assert total.shape == torch.Size([2, 1])
assert abs(proc1_acc[0] - 0.667) < 1e-3 # 2/3
assert abs(proc2_acc[0] - 0.333) < 1e-3 # 1/3
accuracy.reset()
accuracy.correct_counts_k = torch.tensor([correct.sum()])
accuracy.total_counts_k = torch.tensor([total.sum()])
acc_topk = accuracy.compute()
acc_top1 = acc_topk[0]
assert abs(acc_top1 - 0.5) < 1e-3 # 3/6

    @pytest.mark.unit
    def test_top_1_accuracy_distributed_uneven_batch(self):
        # Simulate a 2-process DDP execution where the second process receives a smaller batch.
accuracy = TopKClassificationAccuracy(top_k=None)
proc1_acc = accuracy(logits=self.top_k_logits, labels=torch.tensor([0, 0, 2]))
correct1, total1 = accuracy.correct_counts_k, accuracy.total_counts_k
        proc2_acc = accuracy(
            logits=torch.flip(self.top_k_logits, dims=[1])[:2, :],  # reverse logits, keep first 2 samples
            labels=torch.tensor([2, 0]),  # correspondingly fewer labels
        )
correct2, total2 = accuracy.correct_counts_k, accuracy.total_counts_k
correct = torch.stack([correct1, correct2])
total = torch.stack([total1, total2])
assert correct.shape == torch.Size([2, 1])
assert total.shape == torch.Size([2, 1])
assert abs(proc1_acc[0] - 0.667) < 1e-3 # 2/3
assert abs(proc2_acc[0] - 0.500) < 1e-3 # 1/2
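        # Merge the per-process counts: 2 + 1 = 3 correct out of 3 + 2 = 5 samples.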
accuracy.correct_counts_k = torch.tensor([correct.sum()])
accuracy.total_counts_k = torch.tensor([total.sum()])
acc_topk = accuracy.compute()
acc_top1 = acc_topk[0]
assert abs(acc_top1 - 0.6) < 1e-3 # 3/5
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize(
"probs, logits",
[
(ONLY_PROBS.probs, ONLY_PROBS.logits),
(ONLY_LOGITS1.probs, ONLY_LOGITS1.logits),
(ONLY_LOGITS100.probs, ONLY_LOGITS100.logits),
(PROBS_AND_LOGITS.probs, PROBS_AND_LOGITS.logits),
(NO_PROBS_NO_LOGITS.probs, NO_PROBS_NO_LOGITS.logits),
],
)
class TestPerplexity(PerplexityTester):
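    # "pleasefixme" is NeMo's marker for tests with known issues, typically
    # deselected from regular CI runs until the underlying problem is fixed.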
@pytest.mark.pleasefixme
def test_perplexity(self, ddp, dist_sync_on_step, probs, logits):
self.run_class_perplexity_test(
ddp=ddp,
probs=probs,
logits=logits,
dist_sync_on_step=dist_sync_on_step,
)
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("take_avg_loss", [True, False])
@pytest.mark.parametrize(
"loss_sum_or_avg, num_measurements",
[
(NO_ZERO_NUM_MEASUREMENTS.loss_sum_or_avg, NO_ZERO_NUM_MEASUREMENTS.num_measurements),
(SOME_NUM_MEASUREMENTS_ARE_ZERO.loss_sum_or_avg, SOME_NUM_MEASUREMENTS_ARE_ZERO.num_measurements),
(ALL_NUM_MEASUREMENTS_ARE_ZERO.loss_sum_or_avg, ALL_NUM_MEASUREMENTS_ARE_ZERO.num_measurements),
],
)
class TestLoss(LossTester):
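    # Presumably, take_avg_loss toggles whether loss_sum_or_avg is treated as
    # per-batch averages (weighted by num_measurements) or as raw sums; see
    # pl_utils.LossTester for the exact semantics.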
def test_loss(self, ddp, dist_sync_on_step, loss_sum_or_avg, num_measurements, take_avg_loss):
self.run_class_loss_test(
ddp=ddp,
loss_sum_or_avg=loss_sum_or_avg,
num_measurements=num_measurements,
dist_sync_on_step=dist_sync_on_step,
take_avg_loss=take_avg_loss,
)


class TestPunctuationErrorRate:
reference = "Hi, dear! Nice to see you. What's"
hypothesis = "Hi dear! Nice to see you! What's?"
punctuation_marks = [".", ",", "!", "?"]
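    # Expected alignment of reference vs. hypothesis punctuation: the ',' after
    # "Hi" is dropped (deletion), the '!' after "dear" is preserved (correct),
    # the '.' after "you" becomes '!' (substitution), and a '?' is appended
    # after "What's" (insertion). That is 4 operations in total, so each rate
    # below is 1/4 and punct_er = (deletions + insertions + substitutions) /
    # total = 3/4 = 0.75.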
operation_amounts = {
'.': {'Correct': 0, 'Deletions': 0, 'Insertions': 0, 'Substitutions': 1},
',': {'Correct': 0, 'Deletions': 1, 'Insertions': 0, 'Substitutions': 0},
'!': {'Correct': 1, 'Deletions': 0, 'Insertions': 0, 'Substitutions': 0},
'?': {'Correct': 0, 'Deletions': 0, 'Insertions': 1, 'Substitutions': 0},
}
substitution_amounts = {
'.': {'.': 0, ',': 0, '!': 1, '?': 0},
',': {'.': 0, ',': 0, '!': 0, '?': 0},
'!': {'.': 0, ',': 0, '!': 0, '?': 0},
'?': {'.': 0, ',': 0, '!': 0, '?': 0},
}
correct_rate = 0.25
deletions_rate = 0.25
insertions_rate = 0.25
substitutions_rate = 0.25
punct_er = 0.75
operation_rates = {
'.': {'Correct': 0.0, 'Deletions': 0.0, 'Insertions': 0.0, 'Substitutions': 1.0},
',': {'Correct': 0.0, 'Deletions': 1.0, 'Insertions': 0.0, 'Substitutions': 0.0},
'!': {'Correct': 1.0, 'Deletions': 0.0, 'Insertions': 0.0, 'Substitutions': 0.0},
'?': {'Correct': 0.0, 'Deletions': 0.0, 'Insertions': 1.0, 'Substitutions': 0.0},
}
substitution_rates = {
'.': {'.': 0.0, ',': 0.0, '!': 1.0, '?': 0.0},
',': {'.': 0.0, ',': 0.0, '!': 0.0, '?': 0.0},
'!': {'.': 0.0, ',': 0.0, '!': 0.0, '?': 0.0},
'?': {'.': 0.0, ',': 0.0, '!': 0.0, '?': 0.0},
}

    @pytest.mark.unit
def test_punctuation_error_rate(self):
assert punctuation_error_rate([self.reference], [self.hypothesis], self.punctuation_marks) == self.punct_er

    @pytest.mark.unit
def test_OccurancePunctuationErrorRate(self):
oper_obj = OccurancePunctuationErrorRate(self.punctuation_marks)
operation_amounts, substitution_amounts, punctuation_rates = oper_obj.compute(self.reference, self.hypothesis)
assert operation_amounts == self.operation_amounts
assert substitution_amounts == self.substitution_amounts
assert punctuation_rates.correct_rate == self.correct_rate
assert punctuation_rates.deletions_rate == self.deletions_rate
assert punctuation_rates.insertions_rate == self.insertions_rate
assert punctuation_rates.substitutions_rate == self.substitutions_rate
assert punctuation_rates.punct_er == self.punct_er
assert punctuation_rates.operation_rates == self.operation_rates
assert punctuation_rates.substitution_rates == self.substitution_rates

    @pytest.mark.unit
def test_DatasetPunctuationErrorRate(self):
dper_obj = DatasetPunctuationErrorRate([self.reference], [self.hypothesis], self.punctuation_marks)
dper_obj.compute()
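        # With a single reference/hypothesis pair, the dataset-level rates must
        # equal the per-occurrence rates checked above.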
assert dper_obj.correct_rate == self.correct_rate
assert dper_obj.deletions_rate == self.deletions_rate
assert dper_obj.insertions_rate == self.insertions_rate
assert dper_obj.substitutions_rate == self.substitutions_rate
assert dper_obj.punct_er == self.punct_er
assert dper_obj.operation_rates == self.operation_rates
assert dper_obj.substitution_rates == self.substitution_rates