diff --git a/tensor2tensor/utils/bleu_hook.py b/tensor2tensor/utils/bleu_hook.py index 3ca5070a8..50caf09bf 100644 --- a/tensor2tensor/utils/bleu_hook.py +++ b/tensor2tensor/utils/bleu_hook.py @@ -153,7 +153,7 @@ def __init__(self): def _property_chars(prefix): return ''.join(six.unichr(x) for x in range(sys.maxunicode) if unicodedata.category(six.unichr(x)).startswith(prefix)) - punctuation = self._property_chars('P') + punctuation = _property_chars('P') self.nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])') self.punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])') self.symbol_re = re.compile('([' + _property_chars('S') + '])') @@ -183,9 +183,10 @@ def bleu_tokenize(string): Returns: a list of tokens """ - string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) - string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) - string = UnicodeRegex.symbol_re.sub(r' \1 ', string) + uregex = UnicodeRegex() + string = uregex.nondigit_punct_re.sub(r'\1 \2 ', string) + string = uregex.punct_nondigit_re.sub(r' \1 \2', string) + string = uregex.symbol_re.sub(r' \1 ', string) return string.split() diff --git a/tensor2tensor/utils/bleu_hook_test.py b/tensor2tensor/utils/bleu_hook_test.py index e4f3a18a9..b616aaf7c 100644 --- a/tensor2tensor/utils/bleu_hook_test.py +++ b/tensor2tensor/utils/bleu_hook_test.py @@ -57,5 +57,9 @@ def testComputeMultipleNgrams(self): actual_bleu = 0.3436 self.assertAllClose(bleu, actual_bleu, atol=1e-03) + def testBleuTokenize(self): + self.assertEqual(bleu_hook.bleu_tokenize(u'hi, “there”'), [u'hi', u',', u'“', u'there', u'”']) + + if __name__ == '__main__': tf.test.main()