[tokenizers] Return detail for QA inference #3555

Merged
merged 1 commit on Dec 17, 2024
@@ -48,7 +48,7 @@ private BertQaInference() {}

public static void main(String[] args) throws IOException, TranslateException, ModelException {
String answer = BertQaInference.predict();
logger.info("Answer: {}", answer);
logger.info("Output: {}", answer);
}

public static String predict() throws IOException, TranslateException, ModelException {
@@ -69,6 +69,7 @@ public static String predict() throws IOException, TranslateException, ModelException {
"djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2")
.optEngine("PyTorch")
.optTranslatorFactory(new QuestionAnsweringTranslatorFactory())
.optArgument("detail", true)
.optProgress(new ProgressBar())
.build();

@@ -15,6 +15,9 @@
import ai.djl.ModelException;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;

import com.google.gson.JsonObject;

import org.testng.Assert;
import org.testng.annotations.Test;
@@ -28,6 +31,8 @@ public void testBertQa() throws ModelException, TranslateException, IOException
TestRequirements.linux();

String result = BertQaInference.predict();
Assert.assertEquals(result, "december 2004");
JsonObject json = JsonUtils.GSON.fromJson(result, JsonObject.class);
String answer = json.get("answer").getAsString();
Assert.assertEquals(answer, "december 2004");
}
}
27 changes: 26 additions & 1 deletion extensions/tokenizers/rust/src/lib.rs
@@ -39,8 +39,8 @@ use tk::models::bpe::BPE;
use tk::tokenizer::{EncodeInput, Encoding};
use tk::utils::padding::{PaddingParams, PaddingStrategy};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::Tokenizer;
use tk::Offsets;
use tk::Tokenizer;

#[cfg(not(target_os = "android"))]
use tk::FromPretrainedParameters;
@@ -407,6 +407,31 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
array
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getSequenceIds<
'local,
>(
env: JNIEnv<'local>,
_: JObject,
handle: jlong,
) -> JLongArray<'local> {
let encoding = cast_handle::<Encoding>(handle);
let sequence_ids = encoding.get_sequence_ids();
let len = sequence_ids.len() as jsize;
let mut long_ids: Vec<jlong> = Vec::new();
for i in sequence_ids {
if let Some(sequence_id) = i {
long_ids.push(sequence_id as jlong)
} else {
long_ids.push(-1)
}
}

let array = env.new_long_array(len).unwrap();
env.set_long_array_region(&array, 0, &long_ids).unwrap();
array
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokens<
'local,
@@ -25,6 +25,7 @@ public class Encoding {
private long[] typeIds;
private String[] tokens;
private long[] wordIds;
private long[] sequenceIds;
private long[] attentionMask;
private long[] specialTokenMask;
private CharSpan[] charTokenSpans;
@@ -36,6 +37,7 @@ protected Encoding(
long[] typeIds,
String[] tokens,
long[] wordIds,
long[] sequenceIds,
long[] attentionMask,
long[] specialTokenMask,
CharSpan[] charTokenSpans,
@@ -45,6 +47,7 @@ protected Encoding(
this.typeIds = typeIds;
this.tokens = tokens;
this.wordIds = wordIds;
this.sequenceIds = sequenceIds;
this.attentionMask = attentionMask;
this.specialTokenMask = specialTokenMask;
this.charTokenSpans = charTokenSpans;
@@ -109,6 +112,15 @@ public long[] getWordIds() {
return wordIds;
}

/**
* Returns the sequence ids.
*
* @return the sequence ids
*/
public long[] getSequenceIds() {
return sequenceIds;
}

/**
* Returns the attention masks.
*
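A quick illustration of the new sequence ids (a standalone sketch, not part of this diff): the Rust binding above maps each token's sequence index to a long, using -1 for special tokens that belong to no sequence and 0/1 for the first and second inputs. Using the sentence-pair expectation from HuggingFaceTokenizerTest below, the hypothetical helper here picks out the positions of the second sequence — the same information the QA translator uses when it masks out first-sequence (question) tokens:

import java.util.ArrayList;
import java.util.List;

/** Hypothetical demo class, not part of the PR. */
public class SequenceIdDemo {

    public static void main(String[] args) {
        // Sequence ids for a sentence-pair encoding, copied from HuggingFaceTokenizerTest:
        // -1 = special token, 0 = first sequence, 1 = second sequence.
        long[] sequenceIds = {-1, 0, 0, 0, 0, 0, 0, -1, 1, 1, 1, 1, 1, -1};
        System.out.println(contextPositions(sequenceIds)); // prints [8, 9, 10, 11, 12]
    }

    /** Returns the token positions that belong to the second sequence (the context in a QA pair). */
    static List<Integer> contextPositions(long[] sequenceIds) {
        List<Integer> positions = new ArrayList<>();
        for (int i = 0; i < sequenceIds.length; ++i) {
            if (sequenceIds[i] == 1) {
                positions.add(i);
            }
        }
        return positions;
    }
}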
@@ -623,6 +623,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
String[] tokens = TokenizersLibrary.LIB.getTokens(encoding);
long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding);
long[] sequenceIds = TokenizersLibrary.LIB.getSequenceIds(encoding);
long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
@@ -646,6 +647,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
typeIds,
tokens,
wordIds,
sequenceIds,
attentionMask,
specialTokenMask,
charSpans,
@@ -50,6 +50,8 @@ public native long[] batchEncodePair(

public native long[] getWordIds(long encoding);

public native long[] getSequenceIds(long encoding);

public native String[] getTokens(long encoding);

public native long[] getAttentionMask(long encoding);
@@ -23,25 +23,32 @@
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** The translator for Huggingface question answering model. */
public class QuestionAnsweringTranslator implements Translator<QAInput, String> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private boolean detail;

QuestionAnsweringTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
Batchifier batchifier,
boolean detail) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
this.detail = detail;
}

/** {@inheritDoc} */
@@ -102,6 +109,27 @@ private String decode(NDList list, Encoding encoding) {
startLogits = startLogits.duplicate();
endLogits = endLogits.duplicate();
}
if (detail) {
// exclude undesired sequences
long[] sequenceIds = encoding.getSequenceIds();
List<Integer> undesired = new ArrayList<>();
for (int i = 0; i < sequenceIds.length; ++i) {
if (sequenceIds[i] == 0) {
undesired.add(i);
}
}
int[] idx = undesired.stream().mapToInt(Integer::intValue).toArray();
NDIndex ndIndex = new NDIndex("{}", list.getManager().create(idx));
startLogits.set(ndIndex, -100000f);
endLogits.set(ndIndex, -100000f);

// normalize
startLogits = startLogits.sub(startLogits.max()).exp();
startLogits = startLogits.div(startLogits.sum());
endLogits = endLogits.sub(endLogits.max()).exp();
endLogits = endLogits.div(endLogits.sum());
}

// exclude <CLS>, TODO: exclude impossible ids properly and handle max answer length
startLogits.set(new NDIndex(0), -100000);
endLogits.set(new NDIndex(0), -100000);
@@ -111,12 +139,26 @@ private String decode(NDList list, Encoding encoding) {
int tmp = startIdx;
startIdx = endIdx;
endIdx = tmp;
NDArray tmpArray = startLogits;
startLogits = endLogits;
endLogits = tmpArray;
}
long[] indices = encoding.getIds();
int len = endIdx - startIdx + 1;
long[] ids = new long[len];
System.arraycopy(indices, startIdx, ids, 0, len);
return tokenizer.decode(ids).trim();
String answer = tokenizer.decode(ids).trim();
if (detail) {
float score = startLogits.getFloat(startIdx) * endLogits.getFloat(endIdx);

Map<String, Object> dict = new ConcurrentHashMap<>();
dict.put("score", score);
dict.put("start", startIdx);
dict.put("end", endIdx);
dict.put("answer", answer);
return JsonUtils.toJson(dict);
}
return answer;
}

/**
@@ -149,6 +191,7 @@ public static final class Builder {
private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;
private boolean detail;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
@@ -176,6 +219,17 @@ public Builder optBatchifier(Batchifier batchifier) {
return this;
}

/**
* Sets whether the {@link Translator} should output detail.
*
* @param detail true to output detail
* @return this builder
*/
public Builder optDetail(boolean detail) {
this.detail = detail;
return this;
}

/**
* Configures the builder with the model arguments.
*
@@ -184,6 +238,7 @@ public Builder optBatchifier(Batchifier batchifier) {
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optDetail(ArgumentsUtil.booleanValue(arguments, "detail"));
optBatchifier(Batchifier.fromString(batchifierStr));
}

@@ -194,7 +249,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public QuestionAnsweringTranslator build() throws IOException {
return new QuestionAnsweringTranslator(tokenizer, includeTokenTypes, batchifier);
return new QuestionAnsweringTranslator(
tokenizer, includeTokenTypes, batchifier, detail);
}
}
}
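To tie the pieces together, here is a rough end-to-end sketch (not part of this diff) of consuming the new "detail" argument: the criteria mirror the BertQaInference example above, and the JSON parsing mirrors BertQaInferenceTest. The DetailQaDemo class name and the question/paragraph strings are placeholders, and the imports assume the usual DJL package layout.

import ai.djl.huggingface.translator.QuestionAnsweringTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.util.JsonUtils;

import com.google.gson.JsonObject;

public class DetailQaDemo {

    public static void main(String[] args) throws Exception {
        // Placeholder inputs; any question/context pair works.
        String question = "When was the channel launched?";
        String paragraph = "The channel was launched in december 2004 and ran for several years.";

        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .setTypes(QAInput.class, String.class)
                        .optModelUrls(
                                "djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2")
                        .optEngine("PyTorch")
                        .optTranslatorFactory(new QuestionAnsweringTranslatorFactory())
                        .optArgument("detail", true) // ask the translator for the JSON detail
                        .optProgress(new ProgressBar())
                        .build();

        try (ZooModel<QAInput, String> model = criteria.loadModel();
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            String json = predictor.predict(new QAInput(question, paragraph));
            // decode() serializes a map with "answer", "score", "start" and "end" keys.
            JsonObject detail = JsonUtils.GSON.fromJson(json, JsonObject.class);
            String answer = detail.get("answer").getAsString();
            float score = detail.get("score").getAsFloat();
            System.out.println(answer + " (score=" + score + ")");
        }
    }
}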
@@ -67,13 +67,15 @@ public void testTokenizer() throws IOException {
long[] ids = {101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 1128, 100, 136, 102};
long[] typeIds = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] wordIds = {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -1};
long[] sequenceIds = {-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1};
long[] attentionMask = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
long[] specialTokenMask = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1};

Assert.assertEquals(expected, encoding.getTokens());
Assert.assertEquals(ids, encoding.getIds());
Assert.assertEquals(typeIds, encoding.getTypeIds());
Assert.assertEquals(wordIds, encoding.getWordIds());
Assert.assertEquals(sequenceIds, encoding.getSequenceIds());
Assert.assertEquals(attentionMask, encoding.getAttentionMask());
Assert.assertEquals(specialTokenMask, encoding.getSpecialTokenMask());

@@ -104,6 +106,10 @@ public void testTokenizer() throws IOException {
Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd());
}

encoding = tokenizer.encode(inputs[0], inputs[1]);
sequenceIds = new long[] {-1, 0, 0, 0, 0, 0, 0, -1, 1, 1, 1, 1, 1, -1};
Assert.assertEquals(encoding.getSequenceIds(), sequenceIds);

Assert.assertThrows(() -> tokenizer.encode((String) null));
Assert.assertThrows(() -> tokenizer.encode(new String[] {null}));
Assert.assertThrows(() -> tokenizer.encode(null, null));