Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import pytest | |
| from nemo.collections.speechlm2.parts.metrics import BLEU, WER | |
| def test_bleu(): | |
| metric = BLEU(verbose=False) | |
| metric.update( | |
| name="dataset_1", | |
| refs=["a b c d e f g h i j k l", "m n o p r s t u v"], | |
| hyps=["a b c d e f g h i j k l", "m n o p r s t u v"], | |
| ) | |
| metric.update( | |
| name="dataset_2", | |
| refs=["a b c"], | |
| hyps=["a b d"], | |
| ) | |
| ans = metric.compute() | |
| assert ans["txt_bleu_dataset_1"] == 100.0 | |
| assert ans["txt_bleu_dataset_2"] == 0.0 | |
| assert ans["txt_bleu"] == 50.0 # average across datasets | |
| def test_wer(): | |
| metric = WER(verbose=False) | |
| metric.update( | |
| name="dataset_1", | |
| refs=["a b c d e f g h i j k l", "m n o p r s t u v"], | |
| hyps=["a b c d e f g h i j k l", "m n o p r s t u v"], | |
| ) | |
| metric.update( | |
| name="dataset_2", | |
| refs=["a b c"], | |
| hyps=["a b d"], | |
| ) | |
| ans = metric.compute() | |
| assert ans["wer_dataset_1"] == 0.0 | |
| assert ans["wer_dataset_2"] == 1 / 3 | |
| assert ans["wer"] == 1 / 6 # average across datasets | |