Commit

Change WER signature (#190)
ruivieira authored Dec 18, 2023
1 parent e7ade1e commit 55b90a9
Showing 2 changed files with 20 additions and 27 deletions.
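In call-site terms, the commit renames the result type and collapses its payload to a single `value` plus alignment counters. A minimal before/after sketch, assuming the public import path `trustyai.metrics.language` (inferred from the file path below) and made-up example strings:

    from trustyai.metrics.language import word_error_rate

    # Before this commit: WordErrorRateResult exposed .wer plus aligned strings:
    #     res.wer, res.aligned_reference, res.aligned_input

    # After this commit: ErrorRateResult exposes a generic .value and counters.
    res = word_error_rate("it was a cold day", "it was a day")
    print(res.value)               # word error rate as a float
    print(res.alignment_counters)  # substitution/insertion/deletion counts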
39 changes: 16 additions & 23 deletions src/trustyai/metrics/language.py
@@ -4,17 +4,16 @@
 # pylint: disable = import-error
 from typing import List, Optional, Union, Callable
 
-from org.kie.trustyai.metrics.language.wer import (
+from org.kie.trustyai.metrics.language.levenshtein import (
     WordErrorRate as _WordErrorRate,
-    WordErrorRateResult as _WordErrorRateResult,
+    ErrorRateResult as _ErrorRateResult,
 )
 
 from opennlp.tools.tokenize import Tokenizer
 
 
 @dataclass
-class TokenSequenceAlignmentCounters:
-    """Token Sequence Alignment Counters"""
+class LevenshteinCounters:
+    """LevenshteinCounters Counters"""
 
     substitutions: int
     insertions: int
@@ -23,25 +22,19 @@ class TokenSequenceAlignmentCounters:
 
 
 @dataclass
-class WordErrorRateResult:
+class ErrorRateResult:
     """Word Error Rate Result"""
 
-    wer: float
-    aligned_reference: str
-    aligned_input: str
-    alignment_counters: TokenSequenceAlignmentCounters
+    value: float
+    alignment_counters: LevenshteinCounters
 
     @staticmethod
-    def convert(wer_result: _WordErrorRateResult):
-        """Converts a Java WordErrorRateResult to a Python WordErrorRateResult"""
-        wer = wer_result.getWordErrorRate()
-        aligned_reference = wer_result.getAlignedReferenceString()
-        aligned_input = wer_result.getAlignedInputString()
-        alignment_counters = wer_result.getAlignmentCounters()
-        return WordErrorRateResult(
-            wer=wer,
-            aligned_reference=aligned_reference,
-            aligned_input=aligned_input,
+    def convert(result: _ErrorRateResult):
+        """Converts a Java ErrorRateResult to a Python ErrorRateResult"""
+        value = result.getValue()
+        alignment_counters = result.getAlignmentCounters()
+        return ErrorRateResult(
+            value=value,
             alignment_counters=alignment_counters,
         )

@@ -50,7 +43,7 @@ def word_error_rate(
     reference: str,
     hypothesis: str,
     tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None,
-) -> WordErrorRateResult:
+) -> ErrorRateResult:
     """Calculate Word Error Rate between reference and hypothesis strings"""
     if not tokenizer:
         _wer = _WordErrorRate()
@@ -60,9 +53,9 @@
         tokenized_reference = tokenizer(reference)
         tokenized_hypothesis = tokenizer(hypothesis)
         _wer = _WordErrorRate()
-        return WordErrorRateResult.convert(
+        return ErrorRateResult.convert(
             _wer.calculate(tokenized_reference, tokenized_hypothesis)
         )
     else:
         raise ValueError("Unsupported tokenizer")
-    return WordErrorRateResult.convert(_wer.calculate(reference, hypothesis))
+    return ErrorRateResult.convert(_wer.calculate(reference, hypothesis))
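For the `Callable[[str], List[str]]` branch above, a plain Python function can stand in for an OpenNLP `Tokenizer`. A hedged usage sketch (the function name and example strings are hypothetical; the tests below exercise the same pattern):

    from typing import List
    from trustyai.metrics.language import word_error_rate

    def whitespace_tokenizer(text: str) -> List[str]:
        # Naive split; the callable only needs to map str -> List[str].
        return text.split(" ")

    res = word_error_rate(
        "the quick brown fox",
        "the quick red fox",
        tokenizer=whitespace_tokenizer,
    )
    print(res.value)  # 1 substitution over 4 reference tokens -> 0.25 under standard WER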
8 changes: 4 additions & 4 deletions tests/general/test_metrics_language.py
@@ -22,7 +22,7 @@ def test_default_tokenizer():
     """Test default tokenizer"""
     results = [4 / 7, 1 / 26, 1]
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis).wer
+        wer = word_error_rate(reference, hypothesis).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."

@@ -36,7 +36,7 @@ def tokenizer(text: str) -> List[str]:
         return CommonsStringTokenizer(text).getTokenList()
 
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
+        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."

@@ -47,7 +47,7 @@ def test_opennlp_tokenizer():
     results = [9 / 14., 3 / 78., 1.0]
     tokenizer = OpenNLPTokenizer()
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
+        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."

@@ -61,6 +61,6 @@ def tokenizer(text: str) -> List[str]:
         return text.split(" ")
 
     for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)):
-        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).wer
+        wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value
         assert math.isclose(wer, results[i], rel_tol=tolerance), \
             f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."

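The expected values above (e.g. 4/7) follow the standard definition WER = (S + I + D) / N, where S, I, and D are word-level substitutions, insertions, and deletions against a reference of N tokens. As a cross-check, here is a self-contained pure-Python reference implementation; this is an illustrative sketch of the metric, not the Java Levenshtein code these tests exercise:

    from typing import List

    def reference_wer(reference: List[str], hypothesis: List[str]) -> float:
        """WER = word-level Levenshtein distance / number of reference tokens."""
        rows, cols = len(reference) + 1, len(hypothesis) + 1
        dist = [[0] * cols for _ in range(rows)]
        for i in range(rows):
            dist[i][0] = i  # delete all remaining reference words
        for j in range(cols):
            dist[0][j] = j  # insert all remaining hypothesis words
        for i in range(1, rows):
            for j in range(1, cols):
                cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
                dist[i][j] = min(
                    dist[i - 1][j] + 1,         # deletion
                    dist[i][j - 1] + 1,         # insertion
                    dist[i - 1][j - 1] + cost,  # substitution or match
                )
        return dist[-1][-1] / len(reference)

    # One substitution plus one deletion against a 4-token reference -> 0.5
    assert reference_wer("a b c d".split(), "a x c".split()) == 2 / 4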