diff --git a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java
index 6c92a3c3977a..fefd489aadfe 100644
--- a/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java
+++ b/api/src/main/java/ai/djl/modality/nlp/translator/CrossEncoderServingTranslator.java
@@ -17,7 +17,6 @@
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
-import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
@@ -34,7 +33,7 @@
import java.util.List;
/** A {@link Translator} that can handle generic cross encoder {@link Input} and {@link Output}. */
-public class CrossEncoderServingTranslator implements NoBatchifyTranslator {
+public class CrossEncoderServingTranslator implements Translator {
private Translator translator;
@@ -56,74 +55,13 @@ public void prepare(TranslatorContext ctx) throws Exception {
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
- PairList content = input.getContent();
- if (content.isEmpty()) {
- throw new TranslateException("Input data is empty.");
+ ReRankingInput in = ReRankingInput.parseInput(input);
+ if (in.batch != null) {
+ ctx.setAttachment("batch", Boolean.TRUE);
+ return translator.batchProcessInput(ctx, in.batch);
}
- String contentType = input.getProperty("Content-Type", null);
- if (contentType != null) {
- int pos = contentType.indexOf(';');
- if (pos > 0) {
- contentType = contentType.substring(0, pos);
- }
- }
- StringPair pair = null;
- if ("application/json".equals(contentType)) {
- String json = input.getData().getAsString();
- try {
- JsonElement element = JsonUtils.GSON.fromJson(json, JsonElement.class);
- if (element.isJsonArray()) {
- ctx.setAttachment("batch", Boolean.TRUE);
- JsonArray array = element.getAsJsonArray();
- int size = array.size();
- List inputs = new ArrayList<>(size);
- for (int i = 0; i < size; ++i) {
- JsonObject obj = array.get(i).getAsJsonObject();
- inputs.add(parseStringPair(obj));
- }
- return translator.batchProcessInput(ctx, inputs);
- } else if (element.isJsonObject()) {
- JsonObject obj = element.getAsJsonObject();
- JsonElement query = obj.get("query");
- if (query != null) {
- String key = query.getAsString();
- JsonArray texts = obj.get("texts").getAsJsonArray();
- int size = texts.size();
- List inputs = new ArrayList<>(size);
- for (int i = 0; i < size; ++i) {
- String value = texts.get(i).getAsString();
- inputs.add(new StringPair(key, value));
- }
- ctx.setAttachment("batch", Boolean.TRUE);
- return translator.batchProcessInput(ctx, inputs);
- } else {
- pair = parseStringPair(obj);
- }
- } else {
- throw new TranslateException("Unexpected json type");
- }
- } catch (JsonParseException e) {
- throw new TranslateException("Input is not a valid json.", e);
- }
- } else {
- String text = input.getAsString("text");
- String textPair = input.getAsString("text_pair");
- if (text != null && textPair != null) {
- pair = new StringPair(text, textPair);
- }
- String key = input.getAsString("key");
- String value = input.getAsString("value");
- if (key != null && value != null) {
- pair = new StringPair(key, value);
- }
- }
-
- if (pair == null) {
- throw new TranslateException("Missing key or value in input.");
- }
-
- NDList ret = translator.processInput(ctx, pair);
+ NDList ret = translator.processInput(ctx, in.pair);
Batchifier batchifier = translator.getBatchifier();
if (batchifier != null) {
NDList[] batch = {ret};
@@ -132,34 +70,160 @@ public NDList processInput(TranslatorContext ctx, Input input) throws Exception
return ret;
}
+ /** {@inheritDoc} */
+ @Override
+ @SuppressWarnings("PMD.SignatureDeclareThrowsException")
+ public NDList batchProcessInput(TranslatorContext ctx, List inputs) throws Exception {
+ int[] mapping = new int[inputs.size()];
+ List prompts = new ArrayList<>(mapping.length);
+ for (int i = 0; i < mapping.length; ++i) {
+ ReRankingInput in = ReRankingInput.parseInput(inputs.get(i));
+ if (in.batch != null) {
+ List batch = in.batch;
+ mapping[i] = batch.size();
+ prompts.addAll(batch);
+ } else {
+ mapping[i] = -1;
+ prompts.add(in.pair);
+ }
+ }
+ ctx.setAttachment("mapping", mapping);
+ return translator.batchProcessInput(ctx, prompts);
+ }
+
/** {@inheritDoc} */
@Override
public Output processOutput(TranslatorContext ctx, NDList list) throws Exception {
Output output = new Output();
output.addProperty("Content-Type", "application/json");
- if (ctx.getAttachment("batch") != null) {
- output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
- } else {
- Batchifier batchifier = translator.getBatchifier();
- if (batchifier != null) {
- list = batchifier.unbatchify(list)[0];
- }
+ if (ctx.getAttachment("batch") == null && translator.getBatchifier() == null) {
output.add(BytesSupplier.wrapAsJson(translator.processOutput(ctx, list)));
+ } else {
+ output.add(BytesSupplier.wrapAsJson(translator.batchProcessOutput(ctx, list)));
}
return output;
}
- private StringPair parseStringPair(JsonObject json) throws TranslateException {
- JsonElement text = json.get("text");
- JsonElement textPair = json.get("text_pair");
- if (text != null && textPair != null) {
- return new StringPair(text.getAsString(), textPair.getAsString());
+ /** {@inheritDoc} */
+ @Override
+ @SuppressWarnings("PMD.SignatureDeclareThrowsException")
+ public List