From 55b90a90daeed9e25a1566e47c2abe523f7cdaab Mon Sep 17 00:00:00 2001
From: Rui Vieira
Date: Mon, 18 Dec 2023 17:06:35 +0000
Subject: [PATCH] Change WER signature (#190)

---
 src/trustyai/metrics/language.py       | 39 +++++++++++---------------
 tests/general/test_metrics_language.py |  8 +++---
 2 files changed, 20 insertions(+), 27 deletions(-)

diff --git a/src/trustyai/metrics/language.py b/src/trustyai/metrics/language.py
index d774402..b80891a 100644
--- a/src/trustyai/metrics/language.py
+++ b/src/trustyai/metrics/language.py
@@ -4,17 +4,16 @@
 # pylint: disable = import-error
 from typing import List, Optional, Union, Callable
 
-from org.kie.trustyai.metrics.language.wer import (
+from org.kie.trustyai.metrics.language.levenshtein import (
     WordErrorRate as _WordErrorRate,
-    WordErrorRateResult as _WordErrorRateResult,
+    ErrorRateResult as _ErrorRateResult,
 )
-
 from opennlp.tools.tokenize import Tokenizer
 
 
 @dataclass
-class TokenSequenceAlignmentCounters:
-    """Token Sequence Alignment Counters"""
+class LevenshteinCounters:
+    """LevenshteinCounters Counters"""
 
     substitutions: int
     insertions: int
@@ -23,25 +22,19 @@ class TokenSequenceAlignmentCounters:
 
 
 @dataclass
-class WordErrorRateResult:
+class ErrorRateResult:
     """Word Error Rate Result"""
 
-    wer: float
-    aligned_reference: str
-    aligned_input: str
-    alignment_counters: TokenSequenceAlignmentCounters
+    value: float
+    alignment_counters: LevenshteinCounters
 
     @staticmethod
-    def convert(wer_result: _WordErrorRateResult):
-        """Converts a Java WordErrorRateResult to a Python WordErrorRateResult"""
-        wer = wer_result.getWordErrorRate()
-        aligned_reference = wer_result.getAlignedReferenceString()
-        aligned_input = wer_result.getAlignedInputString()
-        alignment_counters = wer_result.getAlignmentCounters()
-        return WordErrorRateResult(
-            wer=wer,
-            aligned_reference=aligned_reference,
-            aligned_input=aligned_input,
+    def convert(result: _ErrorRateResult):
+        """Converts a Java ErrorRateResult to a Python ErrorRateResult"""
+        value = result.getValue()
+        alignment_counters = result.getAlignmentCounters()
+        return ErrorRateResult(
+            value=value,
             alignment_counters=alignment_counters,
         )
 
@@ -50,7 +43,7 @@ def word_error_rate(
     reference: str,
     hypothesis: str,
     tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None,
-) -> WordErrorRateResult:
+) -> ErrorRateResult:
     """Calculate Word Error Rate between reference and hypothesis strings"""
     if not tokenizer:
         _wer = _WordErrorRate()
@@ -60,9 +53,9 @@ def word_error_rate(
         tokenized_reference = tokenizer(reference)
         tokenized_hypothesis = tokenizer(hypothesis)
         _wer = _WordErrorRate()
-        return WordErrorRateResult.convert(
+        return ErrorRateResult.convert(
             _wer.calculate(tokenized_reference, tokenized_hypothesis)
         )
     else:
         raise ValueError("Unsupported tokenizer")
-    return WordErrorRateResult.convert(_wer.calculate(reference, hypothesis))
+    return ErrorRateResult.convert(_wer.calculate(reference, hypothesis))
diff --git a/tests/general/test_metrics_language.py b/tests/general/test_metrics_language.py
index 3c0b42b..75f2a4c 100644
--- a/tests/general/test_metrics_language.py
+++ b/tests/general/test_metrics_language.py
@@ -22,7 +22,7 @@ def test_default_tokenizer():
     """Test default tokenizer"""
     results = [4 / 7, 1 / 26, 1]
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis).wer
+        wer = word_error_rate(reference, hypothesis).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
 
@@ -36,7 +36,7 @@ def tokenizer(text: str) -> List[str]:
         return CommonsStringTokenizer(text).getTokenList()
 
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
+        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
 
@@ -47,7 +47,7 @@ def test_opennlp_tokenizer():
     results = [9 / 14., 3 / 78., 1.0]
     tokenizer = OpenNLPTokenizer()
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
+        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
 
@@ -61,6 +61,6 @@ def tokenizer(text: str) -> List[str]:
         return text.split(" ")
 
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
+        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
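
Reviewer note (not part of the commit): a minimal sketch of how callers consume the changed signature, assuming a TrustyAI Python environment with the JVM bindings initialized as in the test suite above; the sample strings and the lambda tokenizer below are illustrative only.

# Usage sketch for the signature introduced by this patch (illustrative only).
from trustyai.metrics.language import word_error_rate

reference = "the quick brown fox jumped over the lazy dog"
hypothesis = "the quick brown dog jumped over the lazy fox"

# word_error_rate() now returns an ErrorRateResult: the score moves from
# `.wer` to `.value`, and the aligned reference/input strings are no longer exposed.
result = word_error_rate(reference, hypothesis)
print(result.value)               # word error rate as a float
print(result.alignment_counters)  # Levenshtein alignment counters (substitutions, insertions, ...)

# A plain callable is still accepted as a custom tokenizer:
whitespace_result = word_error_rate(
    reference, hypothesis, tokenizer=lambda text: text.split(" ")
)
print(whitespace_result.value)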