From 3defba635d4bdfc0bba385c5e0dd1fdaed0c1003 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 22 Jan 2024 16:46:35 -0800 Subject: [PATCH] [tokenizer] Adds getters for HuggingfaceTokenizer (#2958) --- .../tokenizers/HuggingFaceTokenizer.java | 47 +++++++++++++++++++ .../tokenizers/HuggingFaceTokenizerTest.java | 6 +++ 2 files changed, 53 insertions(+) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index 6d3ab383dc2..ba4d61b79b1 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -449,6 +449,53 @@ public void enableBatch() { } } + /** + * Returns the truncation policy. + * + * @return the truncation policy + */ + public String getTruncation() { + return truncation.name(); + } + + /** + * Returns the padding policy. + * + * @return the padding policy + */ + public String getPadding() { + return padding.name(); + } + + /** + * Returns the max token length. + * + * @return the max token length + */ + public int getMaxLength() { + return maxLength; + } + + /** + * Returns the stride to use in overflow overlap when truncating sequences longer than the model + * supports. + * + * @return the stride to use in overflow overlap when truncating sequences longer than the model + * supports + */ + public int getStride() { + return stride; + } + + /** + * Returns the padToMultipleOf for padding. + * + * @return the padToMultipleOf for padding + */ + public int getPadToMultipleOf() { + return padToMultipleOf; + } + /** * Creates a builder to build a {@code HuggingFaceTokenizer}. * diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 94889660d23..0c548d51aec 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -41,6 +41,12 @@ public void testTokenizer() throws IOException { }; try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + Assert.assertEquals(tokenizer.getTruncation(), "DO_NOT_TRUNCATE"); + Assert.assertEquals(tokenizer.getPadding(), "DO_NOT_PAD"); + Assert.assertEquals(tokenizer.getMaxLength(), -1); + Assert.assertEquals(tokenizer.getStride(), 0); + Assert.assertEquals(tokenizer.getPadToMultipleOf(), 0); + List ret = tokenizer.tokenize(input); Assert.assertEquals(ret.toArray(Utils.EMPTY_ARRAY), expected); Encoding encoding = tokenizer.encode(input);