Skip to content

Commit

Permalink
[api] Optimized text embedding post processing performance
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Sep 8, 2024
1 parent 61ffdbb commit 8fb72ee
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
Expand All @@ -34,7 +33,7 @@
import java.util.List;

/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
public class CrossEncoderServingTranslator implements NoBatchifyTranslator<Input, Output> {
public class CrossEncoderServingTranslator implements Translator<Input, Output> {

private Translator<StringPair, float[]> translator;

Expand All @@ -56,74 +55,13 @@ public void prepare(TranslatorContext ctx) throws Exception {
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
PairList<String, BytesSupplier> content = input.getContent();
if (content.isEmpty()) {
throw new TranslateException("Input data is empty.");
ReRankingInput in = ReRankingInput.parseInput(input);
if (in.batch != null) {
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, in.batch);
}

String contentType = input.getProperty("Content-Type", null);
if (contentType != null) {
int pos = contentType.indexOf(';');
if (pos > 0) {
contentType = contentType.substring(0, pos);
}
}
StringPair pair = null;
if ("application/json".equals(contentType)) {
String json = input.getData().getAsString();
try {
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
if (element.isJsonArray()) {
ctx.setAttachment("batch", Boolean.TRUE);
JsonArray array = element.getAsJsonArray();
int size = array.size();
List<StringPair> inputs = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
JsonObject obj = array.get(i).getAsJsonObject();
inputs.add(parseStringPair(obj));
}
return translator.batchProcessInput(ctx, inputs);
} else if (element.isJsonObject()) {
JsonObject obj = element.getAsJsonObject();
JsonElement query = obj.get("query");
if (query != null) {
String key = query.getAsString();
JsonArray texts = obj.get("texts").getAsJsonArray();
int size = texts.size();
List<StringPair> inputs = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
String value = texts.get(i).getAsString();
inputs.add(new StringPair(key, value));
}
ctx.setAttachment("batch", Boolean.TRUE);
return translator.batchProcessInput(ctx, inputs);
} else {
pair = parseStringPair(obj);
}
} else {
throw new TranslateException("Unexpected json type");
}
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
} else {
String text = input.getAsString("text");
String textPair = input.getAsString("text_pair");
if (text != null && textPair != null) {
pair = new StringPair(text, textPair);
}
String key = input.getAsString("key");
String value = input.getAsString("value");
if (key != null && value != null) {
pair = new StringPair(key, value);
}
}

if (pair == null) {
throw new TranslateException("Missing key or value in input.");
}

NDList ret = translator.processInput(ctx, pair);
NDList ret = translator.processInput(ctx, in.pair);
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
NDList[] batch = {ret};
Expand All @@ -132,34 +70,160 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
return ret;
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) throws Exception {
int[] mapping = new int[inputs.size()];
List<StringPair> prompts = new ArrayList<>(mapping.length);
for (int i = 0; i < mapping.length; ++i) {
ReRankingInput in = ReRankingInput.parseInput(inputs.get(i));
if (in.batch != null) {
List<StringPair> batch = in.batch;
mapping[i] = batch.size();
prompts.addAll(batch);
} else {
mapping[i] = -1;
prompts.add(in.pair);
}
}
ctx.setAttachment("mapping", mapping);
return translator.batchProcessInput(ctx, prompts);
}

/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) {
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
} else {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
}
return output;
}

private StringPair parseStringPair(JsonObject json) throws TranslateException {
JsonElement text = json.get("text");
JsonElement textPair = json.get("text_pair");
if (text != null && textPair != null) {
return new StringPair(text.getAsString(), textPair.getAsString());
/** {@inheritDoc} */
@Override
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception {
List<float[]> outputs = translator.batchProcessOutput(ctx, list);
int[] mapping = (int[]) ctx.getAttachment("mapping");
List<Output> ret = new ArrayList<>(mapping.length);
int index = 0;
for (int size : mapping) {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (size == -1) {
// non-batching
output.add(BytesSupplier.wrapAsJson(outputs.get(index++)));
} else {
// client side batching
float[][] embeddings = new float[size][];
for (int j = 0; j < size; ++j) {
embeddings[j] = outputs.get(index++);
}
output.add(BytesSupplier.wrapAsJson(embeddings));
}
ret.add(output);
}
return ret;
}

private static final class ReRankingInput {

private StringPair pair;
private List<StringPair> batch;

ReRankingInput(StringPair pair) {
this.pair = pair;
}

ReRankingInput(List<StringPair> batch) {
this.batch = batch;
}
JsonElement key = json.get("key");
JsonElement value = json.get("value");
if (key != null && value != null) {
return new StringPair(key.getAsString(), value.getAsString());

static ReRankingInput parseInput(Input input) throws TranslateException {
PairList<String, BytesSupplier> content = input.getContent();
if (content.isEmpty()) {
throw new TranslateException("Input data is empty.");
}

String contentType = input.getProperty("Content-Type", null);
if (contentType != null) {
int pos = contentType.indexOf(';');
if (pos > 0) {
contentType = contentType.substring(0, pos);
}
}
StringPair pair = null;
if ("application/json".equals(contentType)) {
String json = input.getData().getAsString();
try {
JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
if (element.isJsonArray()) {
JsonArray array = element.getAsJsonArray();
int size = array.size();
List<StringPair> batch = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
JsonObject obj = array.get(i).getAsJsonObject();
batch.add(parseStringPair(obj));
}
return new ReRankingInput(batch);
} else if (element.isJsonObject()) {
JsonObject obj = element.getAsJsonObject();
JsonElement query = obj.get("query");
if (query != null) {
String key = query.getAsString();
JsonArray texts = obj.get("texts").getAsJsonArray();
int size = texts.size();
List<StringPair> batch = new ArrayList<>(size);
for (int i = 0; i < size; ++i) {
String value = texts.get(i).getAsString();
batch.add(new StringPair(key, value));
}
return new ReRankingInput(batch);
} else {
pair = parseStringPair(obj);
}
} else {
throw new TranslateException("Unexpected json type");
}
} catch (JsonParseException e) {
throw new TranslateException("Input is not a valid json.", e);
}
} else {
String text = input.getAsString("text");
String textPair = input.getAsString("text_pair");
if (text != null && textPair != null) {
pair = new StringPair(text, textPair);
}
String key = input.getAsString("key");
String value = input.getAsString("value");
if (key != null && value != null) {
pair = new StringPair(key, value);
}
}

if (pair == null) {
throw new TranslateException("Missing key or value in input.");
}
return new ReRankingInput(pair);
}

private static StringPair parseStringPair(JsonObject json) throws TranslateException {
JsonElement text = json.get("text");
JsonElement textPair = json.get("text_pair");
if (text != null && textPair != null) {
return new StringPair(text.getAsString(), textPair.getAsString());
}
JsonElement key = json.get("key");
JsonElement value = json.get("value");
if (key != null && value != null) {
return new StringPair(key.getAsString(), value.getAsString());
}
throw new TranslateException("Missing text or text_pair in json.");
}
throw new TranslateException("Missing text or text_pair in json.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,10 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) {
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
output.add(translator.processOutput(ctx, list));
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
}
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,10 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) {
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
output.add(translator.processOutput(ctx, list));
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
}
return output;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,17 @@ public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) throw
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) {
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
} else {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
}
return output;
}

/** {@inheritDoc} */
@Override
@SuppressWarnings({"PMD.SignatureDeclareThrowsException", "unchecked"})
@SuppressWarnings("PMD.SignatureDeclareThrowsException")
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) throws Exception {
List<float[]> outputs = translator.batchProcessOutput(ctx, list);
int[] mapping = (int[]) ctx.getAttachment("mapping");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,10 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
if (ctx.getAttachment("batch") != null) {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
} else {
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
list = batchifier.unbatchify(list)[0];
}
if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) {
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
} else {
output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
}
return output;
}
Expand Down
2 changes: 1 addition & 1 deletion extensions/tokenizers/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ tasks {
downloadPath gzipInto file
}

if ("text_embedding" != task)
if (task !in arrayOf("text_embedding", "text_classification"))
continue

file = prefix / task / "ai.djl.huggingface.rust.json"
Expand Down
Loading

0 comments on commit 8fb72ee

Please sign in to comment.