From 8fb72eeeba69414d8f8b6312055ea9d482bb0967 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 8 Sep 2024 15:48:25 -0700 Subject: [PATCH] [api] Optimized text embedding post processing performance --- .../CrossEncoderServingTranslator.java | 234 +++++++++++------- .../nlp/translator/QaServingTranslator.java | 10 +- .../TextClassificationServingTranslator.java | 10 +- .../TextEmbeddingServingTranslator.java | 12 +- .../TokenClassificationServingTranslator.java | 10 +- extensions/tokenizers/build.gradle.kts | 2 +- .../translator/CrossEncoderTranslator.java | 28 ++- .../translator/TextEmbeddingTranslator.java | 17 +- .../CrossEncoderTranslatorTest.java | 61 ++++- 9 files changed, 254 insertions(+), 130 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java index 6c92a3c3977a..fefd489aadfe 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java @@ -17,7 +17,6 @@ import ai.djl.ndarray.BytesSupplier; import ai.djl.ndarray.NDList; import ai.djl.translate.Batchifier; -import ai.djl.translate.NoBatchifyTranslator; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; @@ -34,7 +33,7 @@ import java.util.List; /** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */ -public class CrossEncoderServingTranslator implements NoBatchifyTranslator { +public class CrossEncoderServingTranslator implements Translator { private Translator translator; @@ -56,74 +55,13 @@ public void prepare(TranslatorContext ctx) throws Exception { /** {@inheritDoc} */ @Override public NDList processInput(TranslatorContext ctx, Input input) throws Exception { - PairList content = input.getContent(); - if (content.isEmpty()) { - throw new TranslateException("Input data is empty."); + ReRankingInput in = ReRankingInput.parseInput(input); + if (in.batch != null) { + ctx.setAttachment("batch", Boolean.TRUE); + return translator.batchProcessInput(ctx, in.batch); } - String contentType = input.getProperty("Content-Type", null); - if (contentType != null) { - int pos = contentType.indexOf(';'); - if (pos > 0) { - contentType = contentType.substring(0, pos); - } - } - StringPair pair = null; - if ("application/json".equals(contentType)) { - String json = input.getData().getAsString(); - try { - JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class); - if (element.isJsonArray()) { - ctx.setAttachment("batch", Boolean.TRUE); - JsonArray array = element.getAsJsonArray(); - int size = array.size(); - List inputs = new ArrayList<>(size); - for (int i = 0; i < size; ++i) { - JsonObject obj = array.get(i).getAsJsonObject(); - inputs.add(parseStringPair(obj)); - } - return translator.batchProcessInput(ctx, inputs); - } else if (element.isJsonObject()) { - JsonObject obj = element.getAsJsonObject(); - JsonElement query = obj.get("query"); - if (query != null) { - String key = query.getAsString(); - JsonArray texts = obj.get("texts").getAsJsonArray(); - int size = texts.size(); - List inputs = new ArrayList<>(size); - for (int i = 0; i < size; ++i) { - String value = texts.get(i).getAsString(); - inputs.add(new StringPair(key, value)); - } - ctx.setAttachment("batch", Boolean.TRUE); - return translator.batchProcessInput(ctx, inputs); - } else { - pair = parseStringPair(obj); - } - } else { - throw new TranslateException("Unexpected json type"); - } - } catch (JsonParseException e) { - throw new TranslateException("Input is not a valid json.", e); - } - } else { - String text = input.getAsString("text"); - String textPair = input.getAsString("text_pair"); - if (text != null && textPair != null) { - pair = new StringPair(text, textPair); - } - String key = input.getAsString("key"); - String value = input.getAsString("value"); - if (key != null && value != null) { - pair = new StringPair(key, value); - } - } - - if (pair == null) { - throw new TranslateException("Missing key or value in input."); - } - - NDList ret = translator.processInput(ctx, pair); + NDList ret = translator.processInput(ctx, in.pair); Batchifier batchifier = translator.getBatchifier(); if (batchifier != null) { NDList[] batch = {ret}; @@ -132,34 +70,160 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception return ret; } + /** {@inheritDoc} */ + @Override + @SuppressWarnings("PMD.SignatureDeclareThrowsException") + public NDList batchProcessInput(TranslatorContext ctx, List inputs) throws Exception { + int[] mapping = new int[inputs.size()]; + List prompts = new ArrayList<>(mapping.length); + for (int i = 0; i < mapping.length; ++i) { + ReRankingInput in = ReRankingInput.parseInput(inputs.get(i)); + if (in.batch != null) { + List batch = in.batch; + mapping[i] = batch.size(); + prompts.addAll(batch); + } else { + mapping[i] = -1; + prompts.add(in.pair); + } + } + ctx.setAttachment("mapping", mapping); + return translator.batchProcessInput(ctx, prompts); + } + /** {@inheritDoc} */ @Override public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { Output output = new Output(); output.addProperty("Content-Type", "application/json"); - if (ctx.getAttachment("batch") != null) { - output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); - } else { - Batchifier batchifier = translator.getBatchifier(); - if (batchifier != null) { - list = batchifier.unbatchify(list)[0]; - } + if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) { output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); + } else { + output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); } return output; } - private StringPair parseStringPair(JsonObject json) throws TranslateException { - JsonElement text = json.get("text"); - JsonElement textPair = json.get("text_pair"); - if (text != null && textPair != null) { - return new StringPair(text.getAsString(), textPair.getAsString()); + /** {@inheritDoc} */ + @Override + @SuppressWarnings("PMD.SignatureDeclareThrowsException") + public List batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception { + List outputs = translator.batchProcessOutput(ctx, list); + int[] mapping = (int[]) ctx.getAttachment("mapping"); + List ret = new ArrayList<>(mapping.length); + int index = 0; + for (int size : mapping) { + Output output = new Output(); + output.addProperty("Content-Type", "application/json"); + if (size == -1) { + // non-batching + output.add(BytesSupplier.wrapAsJson(outputs.get(index++))); + } else { + // client side batching + float[][] embeddings = new float[size][]; + for (int j = 0; j < size; ++j) { + embeddings[j] = outputs.get(index++); + } + output.add(BytesSupplier.wrapAsJson(embeddings)); + } + ret.add(output); + } + return ret; + } + + private static final class ReRankingInput { + + private StringPair pair; + private List batch; + + ReRankingInput(StringPair pair) { + this.pair = pair; + } + + ReRankingInput(List batch) { + this.batch = batch; } - JsonElement key = json.get("key"); - JsonElement value = json.get("value"); - if (key != null && value != null) { - return new StringPair(key.getAsString(), value.getAsString()); + + static ReRankingInput parseInput(Input input) throws TranslateException { + PairList content = input.getContent(); + if (content.isEmpty()) { + throw new TranslateException("Input data is empty."); + } + + String contentType = input.getProperty("Content-Type", null); + if (contentType != null) { + int pos = contentType.indexOf(';'); + if (pos > 0) { + contentType = contentType.substring(0, pos); + } + } + StringPair pair = null; + if ("application/json".equals(contentType)) { + String json = input.getData().getAsString(); + try { + JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class); + if (element.isJsonArray()) { + JsonArray array = element.getAsJsonArray(); + int size = array.size(); + List batch = new ArrayList<>(size); + for (int i = 0; i < size; ++i) { + JsonObject obj = array.get(i).getAsJsonObject(); + batch.add(parseStringPair(obj)); + } + return new ReRankingInput(batch); + } else if (element.isJsonObject()) { + JsonObject obj = element.getAsJsonObject(); + JsonElement query = obj.get("query"); + if (query != null) { + String key = query.getAsString(); + JsonArray texts = obj.get("texts").getAsJsonArray(); + int size = texts.size(); + List batch = new ArrayList<>(size); + for (int i = 0; i < size; ++i) { + String value = texts.get(i).getAsString(); + batch.add(new StringPair(key, value)); + } + return new ReRankingInput(batch); + } else { + pair = parseStringPair(obj); + } + } else { + throw new TranslateException("Unexpected json type"); + } + } catch (JsonParseException e) { + throw new TranslateException("Input is not a valid json.", e); + } + } else { + String text = input.getAsString("text"); + String textPair = input.getAsString("text_pair"); + if (text != null && textPair != null) { + pair = new StringPair(text, textPair); + } + String key = input.getAsString("key"); + String value = input.getAsString("value"); + if (key != null && value != null) { + pair = new StringPair(key, value); + } + } + + if (pair == null) { + throw new TranslateException("Missing key or value in input."); + } + return new ReRankingInput(pair); + } + + private static StringPair parseStringPair(JsonObject json) throws TranslateException { + JsonElement text = json.get("text"); + JsonElement textPair = json.get("text_pair"); + if (text != null && textPair != null) { + return new StringPair(text.getAsString(), textPair.getAsString()); + } + JsonElement key = json.get("key"); + JsonElement value = json.get("value"); + if (key != null && value != null) { + return new StringPair(key.getAsString(), value.getAsString()); + } + throw new TranslateException("Missing text or text_pair in json."); } - throw new TranslateException("Missing text or text_pair in json."); } } diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java index 45e3f59afbad..de472261ad2a 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/QaServingTranslator.java @@ -115,14 +115,10 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { Output output = new Output(); output.addProperty("Content-Type", "application/json"); - if (ctx.getAttachment("batch") != null) { - output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); + if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) { + output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); } else { - Batchifier batchifier = translator.getBatchifier(); - if (batchifier != null) { - list = batchifier.unbatchify(list)[0]; - } - output.add(translator.processOutput(ctx, list)); + output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); } return output; } diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java index f74116aeb291..ef55b59978f1 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextClassificationServingTranslator.java @@ -74,14 +74,10 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { Output output = new Output(); output.addProperty("Content-Type", "application/json"); - if (ctx.getAttachment("batch") != null) { - output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); + if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) { + output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); } else { - Batchifier batchifier = translator.getBatchifier(); - if (batchifier != null) { - list = batchifier.unbatchify(list)[0]; - } - output.add(translator.processOutput(ctx, list)); + output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); } return output; } diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java index 52dcc593ce96..861b4bade643 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TextEmbeddingServingTranslator.java @@ -93,21 +93,17 @@ public NDList batchProcessInput(TranslatorContext ctx, List inputs) throw public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { Output output = new Output(); output.addProperty("Content-Type", "application/json"); - if (ctx.getAttachment("batch") != null) { - output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); - } else { - Batchifier batchifier = translator.getBatchifier(); - if (batchifier != null) { - list = batchifier.unbatchify(list)[0]; - } + if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) { output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); + } else { + output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); } return output; } /** {@inheritDoc} */ @Override - @SuppressWarnings({"PMD.SignatureDeclareThrowsException", "unchecked"}) + @SuppressWarnings("PMD.SignatureDeclareThrowsException") public List batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception { List outputs = translator.batchProcessOutput(ctx, list); int[] mapping = (int[]) ctx.getAttachment("mapping"); diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java index a6ab6e8af20d..2f73c4d67aef 100644 --- a/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java +++ b/api/src/main/java/ai/djl/modality/nlp/translator/TokenClassificationServingTranslator.java @@ -72,14 +72,10 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { Output output = new Output(); output.addProperty("Content-Type", "application/json"); - if (ctx.getAttachment("batch") != null) { - output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); - } else { - Batchifier batchifier = translator.getBatchifier(); - if (batchifier != null) { - list = batchifier.unbatchify(list)[0]; - } + if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) { output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list))); + } else { + output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list))); } return output; } diff --git a/extensions/tokenizers/build.gradle.kts b/extensions/tokenizers/build.gradle.kts index 30fd0e4dcab8..c4437da53013 100644 --- a/extensions/tokenizers/build.gradle.kts +++ b/extensions/tokenizers/build.gradle.kts @@ -89,7 +89,7 @@ tasks { downloadPath gzipInto file } - if ("text_embedding" != task) + if (task !in arrayOf("text_embedding", "text_classification")) continue file = prefix / task / "ai.djl.huggingface.rust.json" diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java index 61663519a358..ac5e4ed552ef 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/CrossEncoderTranslator.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -88,14 +89,29 @@ public float[] processOutput(TranslatorContext ctx, NDList list) { /** {@inheritDoc} */ @Override public List batchProcessOutput(TranslatorContext ctx, NDList list) { - NDList[] batches = batchifier.unbatchify(list); - List ret = new ArrayList<>(batches.length); - for (NDList batch : batches) { - NDArray result = batch.get(0); - if (sigmoid) { + if (sigmoid) { + NDList[] batches = batchifier.unbatchify(list); + List ret = new ArrayList<>(batches.length); + for (NDList batch : batches) { + NDArray result = batch.get(0); result = result.getNDArrayInternal().sigmoid(); + ret.add(result.toFloatArray()); } - ret.add(result.toFloatArray()); + return ret; + } + NDArray array = list.get(0); + int batchSize = Math.toIntExact(array.size(0)); + float[] buf = list.get(0).toFloatArray(); + if (batchSize == 1) { + return Collections.singletonList(buf); + } + + int length = buf.length / batchSize; + List ret = new ArrayList<>(batchSize); + for (int i = 0; i < batchSize; ++i) { + float[] f = new float[length]; + System.arraycopy(buf, i * length, f, 0, length); + ret.add(f); } return ret; } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java index 8a952056155d..0d91734d5be7 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java @@ -30,6 +30,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -155,14 +156,20 @@ public float[] processOutput(TranslatorContext ctx, NDList list) { /** {@inheritDoc} */ @Override public List batchProcessOutput(TranslatorContext ctx, NDList list) { - int batchSize = Math.toIntExact(list.head().size(0)); NDArray attentionMask = (NDArray) ctx.getAttachment("attentionMask"); NDArray output = processEmbedding(list, attentionMask); + int batchSize = Math.toIntExact(output.size(0)); + float[] buf = output.toFloatArray(); + if (batchSize == 1) { + return Collections.singletonList(buf); + } + + int length = buf.length / batchSize; List ret = new ArrayList<>(batchSize); - NDList splitList = output.split(batchSize); - for (int i = 0; i < batchSize; i++) { - NDArray array = splitList.get(i); - ret.add(array.toFloatArray()); + for (int i = 0; i < batchSize; ++i) { + float[] f = new float[length]; + System.arraycopy(buf, i * length, f, 0, length); + ret.add(f); } return ret; } diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java index b59e840e7114..4f0c64a2b294 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -37,6 +37,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,7 +45,6 @@ public class CrossEncoderTranslatorTest { @Test - @SuppressWarnings("unchecked") public void testCrossEncoderTranslator() throws ModelException, IOException, TranslateException { String text1 = "Sentence 1"; @@ -146,21 +146,21 @@ public void testCrossEncoderTranslator() input.addProperty("Content-Type", "application/json"); input.add("data", "{\"query\": \"" + text1 + "\", \"texts\": [\"" + text2 + "\"]}"); res = predictor.predict(input); - buf = ((List) res.getData().getAsObject()).get(0); + buf = ((float[][]) res.getData().getAsObject())[0]; Assert.assertEquals(buf[0], 0.32455865, 0.0001); input = new Input(); input.addProperty("Content-Type", "application/json"); input.add("data", "{\"query\": \"" + text1 + "\", \"texts\": [\"" + text2 + "\"]}"); res = predictor.predict(input); - buf = ((List) res.getData().getAsObject()).get(0); + buf = ((float[][]) res.getData().getAsObject())[0]; Assert.assertEquals(buf[0], 0.32455865, 0.0001); input = new Input(); input.addProperty("Content-Type", "application/json"); input.add("data", "[{\"text\": \"" + text1 + "\", \"text_pair\": \"" + text2 + "\"}]"); res = predictor.predict(input); - buf = ((List) res.getData().getAsObject()).get(0); + buf = ((float[][]) res.getData().getAsObject())[0]; Assert.assertEquals(buf[0], 0.32455865, 0.0001); Assert.assertThrows(TranslateException.class, () -> predictor.predict(new Input())); @@ -213,4 +213,57 @@ public void testCrossEncoderTranslator() () -> factory.newInstance(String.class, Integer.class, model, arguments)); } } + + @Test + public void testCrossEncoderTranslatorServingBatch() + throws ModelException, IOException, TranslateException { + String text1 = "Sentence 1"; + String text2 = "Sentence 2"; + Block block = + new LambdaBlock( + a -> { + NDManager manager = a.getManager(); + NDArray array = + manager.create( + new float[][] {{-0.7329f}, {-0.7329f}, {-0.7329f}}); + return new NDList(array); + }, + "model"); + Path modelDir = Paths.get("build/model"); + Files.createDirectories(modelDir); + + Criteria criteria = + Criteria.builder() + .setTypes(Input.class, Output.class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optArgument("reranking", true) + .optArgument("sigmoid", false) + .optOption("hasParameter", "false") + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + Input input1 = new Input(); + input1.add("text", text1); + input1.add("text_pair", text2); + + Input input2 = new Input(); + input2.addProperty("Content-Type", "application/json; charset=utf-8"); + input2.add( + "data", + "{\"query\": \"query\", \"texts\": [\"" + text1 + "\", \"" + text2 + "\"]}"); + List batchInput = Arrays.asList(input1, input2); + + List batchOutput = predictor.batchPredict(batchInput); + Assert.assertEquals(batchOutput.size(), 2); + float[] ret1 = (float[]) batchOutput.get(0).getData().getAsObject(); + float[][] ret2 = (float[][]) batchOutput.get(1).getData().getAsObject(); + Assert.assertEquals(ret1[0], -0.7329f, 0.0001); + Assert.assertEquals(ret2[1][0], -0.7329f, 0.0001); + } + } }