diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java index 751019c760f..089d071c65b 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java @@ -39,6 +39,8 @@ import ai.djl.translate.TranslatorContext; import ai.djl.util.JsonUtils; +import com.google.gson.annotations.SerializedName; + import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -340,7 +342,7 @@ float[][] getLabels() { public static Sam2Input fromJson(String input) throws IOException { Prompt prompt = JsonUtils.GSON.fromJson(input, Prompt.class); if (prompt.image == null) { - throw new IllegalArgumentException("Missing image value"); + throw new IllegalArgumentException("Missing image_url value"); } if (prompt.prompt == null || prompt.prompt.length == 0) { throw new IllegalArgumentException("Missing prompt value"); @@ -477,7 +479,10 @@ public void setLabel(int label) { } private static final class Prompt { + + @SerializedName("image_url") String image; + Location[] prompt; boolean visualize; diff --git a/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java b/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java index a742e92d49d..41789604902 100644 --- a/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java +++ b/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java @@ -30,7 +30,7 @@ public void test() throws IOException { Path file = Paths.get("../examples/src/test/resources/kitten.jpg"); Image img = ImageFactory.getInstance().fromFile(file); String json = - "{\"image\": \"" + "{\"image_url\": \"" + file.toUri().toURL() + "\",\n" + "\"visualize\": true,\n"