diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index 6d3ab383dc2..ba4d61b79b1 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -449,6 +449,53 @@ public void enableBatch() { } } + /** + * Returns the truncation policy. + * + * @return the truncation policy + */ + public String getTruncation() { + return truncation.name(); + } + + /** + * Returns the padding policy. + * + * @return the padding policy + */ + public String getPadding() { + return padding.name(); + } + + /** + * Returns the max token length. + * + * @return the max token length + */ + public int getMaxLength() { + return maxLength; + } + + /** + * Returns the stride to use in overflow overlap when truncating sequences longer than the model + * supports. + * + * @return the stride to use in overflow overlap when truncating sequences longer than the model + * supports + */ + public int getStride() { + return stride; + } + + /** + * Returns the padToMultipleOf for padding. + * + * @return the padToMultipleOf for padding + */ + public int getPadToMultipleOf() { + return padToMultipleOf; + } + /** * Creates a builder to build a {@code HuggingFaceTokenizer}. * diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 94889660d23..0c548d51aec 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -41,6 +41,12 @@ public void testTokenizer() throws IOException { }; try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) { + Assert.assertEquals(tokenizer.getTruncation(), "DO_NOT_TRUNCATE"); + Assert.assertEquals(tokenizer.getPadding(), "DO_NOT_PAD"); + Assert.assertEquals(tokenizer.getMaxLength(), -1); + Assert.assertEquals(tokenizer.getStride(), 0); + Assert.assertEquals(tokenizer.getPadToMultipleOf(), 0); + List ret = tokenizer.tokenize(input); Assert.assertEquals(ret.toArray(Utils.EMPTY_ARRAY), expected); Encoding encoding = tokenizer.encode(input);