diff --git a/examples/src/main/java/ai/djl/examples/inference/nlp/BertQaInference.java b/examples/src/main/java/ai/djl/examples/inference/nlp/BertQaInference.java index d02da5437fd..507bf258957 100644 --- a/examples/src/main/java/ai/djl/examples/inference/nlp/BertQaInference.java +++ b/examples/src/main/java/ai/djl/examples/inference/nlp/BertQaInference.java @@ -48,7 +48,7 @@ private BertQaInference() {} public static void main(String[] args) throws IOException, TranslateException, ModelException { String answer = BertQaInference.predict(); - logger.info("Answer: {}", answer); + logger.info("Output: {}", answer); } public static String predict() throws IOException, TranslateException, ModelException { @@ -69,6 +69,7 @@ public static String predict() throws IOException, TranslateException, ModelExce "djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2") .optEngine("PyTorch") .optTranslatorFactory(new QuestionAnsweringTranslatorFactory()) + .optArgument("detail", true) .optProgress(new ProgressBar()) .build(); diff --git a/examples/src/test/java/ai/djl/examples/inference/nlp/BertQaTest.java b/examples/src/test/java/ai/djl/examples/inference/nlp/BertQaTest.java index 21a70ff6756..1a464d8f789 100644 --- a/examples/src/test/java/ai/djl/examples/inference/nlp/BertQaTest.java +++ b/examples/src/test/java/ai/djl/examples/inference/nlp/BertQaTest.java @@ -15,6 +15,9 @@ import ai.djl.ModelException; import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonObject; import org.testng.Assert; import org.testng.annotations.Test; @@ -28,6 +31,8 @@ public void testBertQa() throws ModelException, TranslateException, IOException TestRequirements.linux(); String result = BertQaInference.predict(); - Assert.assertEquals(result, "december 2004"); + JsonObject json = JsonUtils.GSON.fromJson(result, JsonObject.class); + String answer = json.get("answer").getAsString(); + Assert.assertEquals(answer, "december 2004"); } } diff --git a/extensions/tokenizers/rust/src/lib.rs b/extensions/tokenizers/rust/src/lib.rs index 3d4cd1264ad..a7d1072c9f9 100644 --- a/extensions/tokenizers/rust/src/lib.rs +++ b/extensions/tokenizers/rust/src/lib.rs @@ -39,8 +39,8 @@ use tk::models::bpe::BPE; use tk::tokenizer::{EncodeInput, Encoding}; use tk::utils::padding::{PaddingParams, PaddingStrategy}; use tk::utils::truncation::{TruncationParams, TruncationStrategy}; -use tk::Tokenizer; use tk::Offsets; +use tk::Tokenizer; #[cfg(not(target_os = "android"))] use tk::FromPretrainedParameters; @@ -407,6 +407,31 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_ array } +#[no_mangle] +pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getSequenceIds< + 'local, +>( + env: JNIEnv<'local>, + _: JObject, + handle: jlong, +) -> JLongArray<'local> { + let encoding = cast_handle::(handle); + let sequence_ids = encoding.get_sequence_ids(); + let len = sequence_ids.len() as jsize; + let mut long_ids: Vec = Vec::new(); + for i in sequence_ids { + if let Some(sequence_id) = i { + long_ids.push(sequence_id as jlong) + } else { + long_ids.push(-1) + } + } + + let array = env.new_long_array(len).unwrap(); + env.set_long_array_region(&array, 0, &long_ids).unwrap(); + array +} + #[no_mangle] pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokens< 'local, diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java index 301490a2bb4..4877f53aa46 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/Encoding.java @@ -25,6 +25,7 @@ public class Encoding { private long[] typeIds; private String[] tokens; private long[] wordIds; + private long[] sequenceIds; private long[] attentionMask; private long[] specialTokenMask; private CharSpan[] charTokenSpans; @@ -36,6 +37,7 @@ protected Encoding( long[] typeIds, String[] tokens, long[] wordIds, + long[] sequenceIds, long[] attentionMask, long[] specialTokenMask, CharSpan[] charTokenSpans, @@ -45,6 +47,7 @@ protected Encoding( this.typeIds = typeIds; this.tokens = tokens; this.wordIds = wordIds; + this.sequenceIds = sequenceIds; this.attentionMask = attentionMask; this.specialTokenMask = specialTokenMask; this.charTokenSpans = charTokenSpans; @@ -109,6 +112,15 @@ public long[] getWordIds() { return wordIds; } + /** + * Returns the sequence ids. + * + * @return the sequence ids + */ + public long[] getSequenceIds() { + return sequenceIds; + } + /** * Returns the attention masks. * diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index cd6eaf289ca..c4c04c4df0b 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -623,6 +623,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) { long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding); String[] tokens = TokenizersLibrary.LIB.getTokens(encoding); long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding); + long[] sequenceIds = TokenizersLibrary.LIB.getSequenceIds(encoding); long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding); long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding); CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding); @@ -646,6 +647,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) { typeIds, tokens, wordIds, + sequenceIds, attentionMask, specialTokenMask, charSpans, diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java index 7ef895091ad..994694fe21d 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/jni/TokenizersLibrary.java @@ -50,6 +50,8 @@ public native long[] batchEncodePair( public native long[] getWordIds(long encoding); + public native long[] getSequenceIds(long encoding); + public native String[] getTokens(long encoding); public native long[] getAttentionMask(long encoding); diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/QuestionAnsweringTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/QuestionAnsweringTranslator.java index f0b0fac0208..1b3b4b098f3 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/QuestionAnsweringTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/QuestionAnsweringTranslator.java @@ -23,12 +23,14 @@ import ai.djl.translate.Batchifier; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; import ai.djl.util.PairList; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** The translator for Huggingface question answering model. */ public class QuestionAnsweringTranslator implements Translator { @@ -36,12 +38,17 @@ public class QuestionAnsweringTranslator implements Translator private HuggingFaceTokenizer tokenizer; private boolean includeTokenTypes; private Batchifier batchifier; + private boolean detail; QuestionAnsweringTranslator( - HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { + HuggingFaceTokenizer tokenizer, + boolean includeTokenTypes, + Batchifier batchifier, + boolean detail) { this.tokenizer = tokenizer; this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; + this.detail = detail; } /** {@inheritDoc} */ @@ -102,6 +109,27 @@ private String decode(NDList list, Encoding encoding) { startLogits = startLogits.duplicate(); endLogits = endLogits.duplicate(); } + if (detail) { + // exclude undesired sequences + long[] sequenceIds = encoding.getSequenceIds(); + List undesired = new ArrayList<>(); + for (int i = 0; i < sequenceIds.length; ++i) { + if (sequenceIds[i] == 0) { + undesired.add(i); + } + } + int[] idx = undesired.stream().mapToInt(Integer::intValue).toArray(); + NDIndex ndIndex = new NDIndex("{}", list.getManager().create(idx)); + startLogits.set(ndIndex, -100000f); + endLogits.set(ndIndex, -100000f); + + // normalize + startLogits = startLogits.sub(startLogits.max()).exp(); + startLogits = startLogits.div(startLogits.sum()); + endLogits = endLogits.sub(endLogits.max()).exp(); + endLogits = endLogits.div(endLogits.sum()); + } + // exclude , TODO: exclude impossible ids properly and handle max answer length startLogits.set(new NDIndex(0), -100000); endLogits.set(new NDIndex(0), -100000); @@ -111,12 +139,26 @@ private String decode(NDList list, Encoding encoding) { int tmp = startIdx; startIdx = endIdx; endIdx = tmp; + NDArray tmpArray = startLogits; + startLogits = endLogits; + endLogits = tmpArray; } long[] indices = encoding.getIds(); int len = endIdx - startIdx + 1; long[] ids = new long[len]; System.arraycopy(indices, startIdx, ids, 0, len); - return tokenizer.decode(ids).trim(); + String answer = tokenizer.decode(ids).trim(); + if (detail) { + float score = startLogits.getFloat(startIdx) * endLogits.getFloat(endIdx); + + Map dict = new ConcurrentHashMap<>(); + dict.put("score", score); + dict.put("start", startIdx); + dict.put("end", endIdx); + dict.put("answer", answer); + return JsonUtils.toJson(dict); + } + return answer; } /** @@ -149,6 +191,7 @@ public static final class Builder { private HuggingFaceTokenizer tokenizer; private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; + private boolean detail; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; @@ -176,6 +219,17 @@ public Builder optBatchifier(Batchifier batchifier) { return this; } + /** + * Sets if output detail for the {@link Translator}. + * + * @param detail true to output detail + * @return this builder + */ + public Builder optDetail(boolean detail) { + this.detail = detail; + return this; + } + /** * Configures the builder with the model arguments. * @@ -184,6 +238,7 @@ public Builder optBatchifier(Batchifier batchifier) { public void configure(Map arguments) { optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); + optDetail(ArgumentsUtil.booleanValue(arguments, "detail")); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -194,7 +249,8 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public QuestionAnsweringTranslator build() throws IOException { - return new QuestionAnsweringTranslator(tokenizer, includeTokenTypes, batchifier); + return new QuestionAnsweringTranslator( + tokenizer, includeTokenTypes, batchifier, detail); } } } diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 4e11fee74fd..44103f3a8c8 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -67,6 +67,7 @@ public void testTokenizer() throws IOException { long[] ids = {101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 1128, 100, 136, 102}; long[] typeIds = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; long[] wordIds = {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -1}; + long[] sequenceIds = {-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1}; long[] attentionMask = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; long[] specialTokenMask = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; @@ -74,6 +75,7 @@ public void testTokenizer() throws IOException { Assert.assertEquals(ids, encoding.getIds()); Assert.assertEquals(typeIds, encoding.getTypeIds()); Assert.assertEquals(wordIds, encoding.getWordIds()); + Assert.assertEquals(sequenceIds, encoding.getSequenceIds()); Assert.assertEquals(attentionMask, encoding.getAttentionMask()); Assert.assertEquals(specialTokenMask, encoding.getSpecialTokenMask()); @@ -104,6 +106,10 @@ public void testTokenizer() throws IOException { Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd()); } + encoding = tokenizer.encode(inputs[0], inputs[1]); + sequenceIds = new long[] {-1, 0, 0, 0, 0, 0, 0, -1, 1, 1, 1, 1, 1, -1}; + Assert.assertEquals(encoding.getSequenceIds(), sequenceIds); + Assert.assertThrows(() -> tokenizer.encode((String) null)); Assert.assertThrows(() -> tokenizer.encode(new String[] {null})); Assert.assertThrows(() -> tokenizer.encode(null, null));