From f9031c76c154d39200055d8093f2d381e0e8a42a Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 2 Oct 2024 12:59:26 -0700 Subject: [PATCH] [api] Standardizes CV output format --- .../java/ai/djl/modality/Classifications.java | 54 ++++----------- .../djl/modality/cv/output/CategoryMask.java | 41 +++++------- .../modality/cv/output/DetectedObjects.java | 20 +----- .../ai/djl/modality/cv/output/Joints.java | 26 ++----- .../ai/djl/modality/cv/output/Landmark.java | 12 ++++ .../java/ai/djl/modality/cv/output/Mask.java | 15 +++++ .../java/ai/djl/modality/cv/output/Point.java | 9 +++ .../ai/djl/modality/cv/output/Rectangle.java | 31 +++++++-- .../java/ai/djl/util/JsonSerializable.java | 41 +++++++++++- api/src/main/java/ai/djl/util/JsonUtils.java | 17 ++++- .../modality/cv/output/CategoryMaskTest.java | 35 ++++++++++ .../cv/output/DetectedObjectsTest.java | 67 +++++++++++++++++++ 12 files changed, 255 insertions(+), 113 deletions(-) create mode 100644 api/src/test/java/ai/djl/modality/cv/output/CategoryMaskTest.java create mode 100644 api/src/test/java/ai/djl/modality/cv/output/DetectedObjectsTest.java diff --git a/api/src/main/java/ai/djl/modality/Classifications.java b/api/src/main/java/ai/djl/modality/Classifications.java index 070c0372a7a..cc4c89d63a0 100644 --- a/api/src/main/java/ai/djl/modality/Classifications.java +++ b/api/src/main/java/ai/djl/modality/Classifications.java @@ -18,14 +18,8 @@ import ai.djl.util.JsonSerializable; import ai.djl.util.JsonUtils; -import com.google.gson.Gson; import com.google.gson.JsonElement; -import com.google.gson.JsonSerializationContext; -import com.google.gson.JsonSerializer; -import java.lang.reflect.Type; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -42,11 +36,6 @@ public class Classifications implements JsonSerializable, Ensembleable classNames; @@ -210,31 +199,25 @@ public T get(String className) { /** 
{@inheritDoc} */ @Override - public String toJson() { - return GSON.toJson(this) + '\n'; - } - - /** {@inheritDoc} */ - @Override - public String getAsString() { - return toJson(); - } - - /** {@inheritDoc} */ - @Override - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8)); + public JsonElement serialize() { + return JsonUtils.GSON.toJsonTree(topK()); } /** {@inheritDoc} */ @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append('[').append(System.lineSeparator()); - for (Classification item : topK(topK)) { - sb.append('\t').append(item).append(System.lineSeparator()); + sb.append("[\n"); + List list = topK(); + int index = 0; + for (Classification item : list) { + sb.append('\t').append(item); + if (++index < list.size()) { + sb.append(','); + } + sb.append('\n'); } - sb.append(']'); + sb.append("]\n"); return sb.toString(); } @@ -306,7 +289,7 @@ public double getProbability() { @Override public String toString() { StringBuilder sb = new StringBuilder(100); - sb.append("{\"class\": \"").append(className).append("\", \"probability\": "); + sb.append("{\"className\": \"").append(className).append("\", \"probability\": "); if (probability < 0.00001) { sb.append(String.format("%.1e", probability)); } else { @@ -317,15 +300,4 @@ public String toString() { return sb.toString(); } } - - /** A customized Gson serializer to serialize the {@code Classifications} object. 
*/ - public static final class ClassificationsSerializer implements JsonSerializer { - - /** {@inheritDoc} */ - @Override - public JsonElement serialize(Classifications src, Type type, JsonSerializationContext ctx) { - List list = src.topK(); - return ctx.serialize(list); - } - } } diff --git a/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java b/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java index c7d5414da28..c0aea00b8f2 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java +++ b/api/src/main/java/ai/djl/modality/cv/output/CategoryMask.java @@ -18,15 +18,11 @@ import ai.djl.util.JsonUtils; import ai.djl.util.RandomUtils; -import com.google.gson.Gson; import com.google.gson.JsonElement; -import com.google.gson.JsonSerializationContext; -import com.google.gson.JsonSerializer; +import com.google.gson.JsonObject; -import java.lang.reflect.Type; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.stream.Collectors; /** * A class representing the segmentation result of an image in an {@link @@ -38,12 +34,7 @@ public class CategoryMask implements JsonSerializable { private static final int COLOR_BLACK = 0xFF000000; - private static final Gson GSON = - JsonUtils.builder() - .registerTypeAdapter(CategoryMask.class, new SegmentationSerializer()) - .create(); - - private transient List classes; + private List classes; private int[][] mask; /** @@ -77,14 +68,22 @@ public int[][] getMask() { /** {@inheritDoc} */ @Override - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8)); + public JsonElement serialize() { + JsonObject ret = new JsonObject(); + ret.add("classes", JsonUtils.GSON.toJsonTree(classes)); + ret.add("mask", JsonUtils.GSON.toJsonTree(mask)); + return ret; } /** {@inheritDoc} */ @Override - public String toJson() { - return GSON.toJson(this) + '\n'; + public String toString() { + StringBuilder sb = new 
StringBuilder(4096); + String list = classes.stream().map(s -> '"' + s + '"').collect(Collectors.joining(", ")); + sb.append("{\n\t\"classes\": [").append(list).append("],\n\t\"mask\": "); + sb.append(JsonUtils.GSON_COMPACT.toJson(mask)); + sb.append("\n}"); + return sb.toString(); } /** @@ -195,14 +194,4 @@ private int[] generateColors(int background, int opacity) { } return colors; } - - /** A customized Gson serializer to serialize the {@code Segmentation} object. */ - public static final class SegmentationSerializer implements JsonSerializer { - - /** {@inheritDoc} */ - @Override - public JsonElement serialize(CategoryMask src, Type type, JsonSerializationContext ctx) { - return ctx.serialize(src.getMask()); - } - } } diff --git a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java index 9d58575af59..a46967fc468 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java +++ b/api/src/main/java/ai/djl/modality/cv/output/DetectedObjects.java @@ -13,9 +13,6 @@ package ai.djl.modality.cv.output; import ai.djl.modality.Classifications; -import ai.djl.util.JsonUtils; - -import com.google.gson.Gson; import java.util.List; @@ -27,11 +24,6 @@ public class DetectedObjects extends Classifications { private static final long serialVersionUID = 1L; - private static final Gson GSON = - JsonUtils.builder() - .registerTypeAdapter(DetectedObjects.class, new ClassificationsSerializer()) - .create(); - @SuppressWarnings("serial") private List boundingBoxes; @@ -69,12 +61,6 @@ public int getNumberOfObjects() { return boundingBoxes.size(); } - /** {@inheritDoc} */ - @Override - public String toJson() { - return GSON.toJson(this) + '\n'; - } - /** A {@code DetectedObject} represents a single potential detected Object for an image. 
*/ public static final class DetectedObject extends Classification { @@ -106,15 +92,15 @@ public BoundingBox getBoundingBox() { public String toString() { double probability = getProbability(); StringBuilder sb = new StringBuilder(200); - sb.append("{\"class\": \"").append(getClassName()).append("\", \"probability\": "); + sb.append("{\"className\": \"").append(getClassName()).append("\", \"probability\": "); if (probability < 0.00001) { sb.append(String.format("%.1e", probability)); } else { probability = (int) (probability * 100000) / 100000f; sb.append(String.format("%.5f", probability)); } - if (getBoundingBox() != null) { - sb.append(", \"bounds\": ").append(getBoundingBox()); + if (boundingBox != null) { + sb.append(", \"boundingBox\": ").append(boundingBox); } sb.append('}'); return sb.toString(); diff --git a/api/src/main/java/ai/djl/modality/cv/output/Joints.java b/api/src/main/java/ai/djl/modality/cv/output/Joints.java index 30fa7ce7ebc..ad9d916893e 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Joints.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Joints.java @@ -12,6 +12,8 @@ */ package ai.djl.modality.cv.output; +import ai.djl.util.JsonUtils; + import java.io.Serializable; import java.util.List; @@ -48,19 +50,7 @@ public List getJoints() { /** {@inheritDoc} */ @Override public String toString() { - StringBuilder sb = new StringBuilder(4000); - sb.append("\n[\n\t"); - boolean first = true; - for (Joint joint : joints) { - if (first) { - first = false; - } else { - sb.append(",\n\t"); - } - sb.append(joint); - } - sb.append("\n]"); - return sb.toString(); + return JsonUtils.GSON_PRETTY.toJson(this) + "\n"; } /** @@ -69,7 +59,9 @@ public String toString() { * @see Joints */ public static class Joint extends Point { + private static final long serialVersionUID = 1L; + private double confidence; /** @@ -92,13 +84,5 @@ public Joint(double x, double y, double confidence) { public double getConfidence() { return confidence; } - - /** 
{@inheritDoc} */ - @Override - public String toString() { - return String.format( - "{\"Joint\": {\"x\"=%.3f, \"y\"=%.3f}, \"confidence\": %.4f}", - getX(), getY(), getConfidence()); - } } } diff --git a/api/src/main/java/ai/djl/modality/cv/output/Landmark.java b/api/src/main/java/ai/djl/modality/cv/output/Landmark.java index 215b91f078e..7f7237cd0b7 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Landmark.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Landmark.java @@ -12,6 +12,10 @@ */ package ai.djl.modality.cv.output; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonObject; + import java.util.List; /** {@code Landmark} is the container that stores the key points for landmark on a single face. */ @@ -41,4 +45,12 @@ public Landmark(double x, double y, double width, double height, List poi public Iterable getPath() { return points; } + + /** {@inheritDoc} */ + @Override + public JsonObject serialize() { + JsonObject ret = super.serialize(); + ret.add("landmarks", JsonUtils.GSON.toJsonTree(points)); + return ret; + } } diff --git a/api/src/main/java/ai/djl/modality/cv/output/Mask.java b/api/src/main/java/ai/djl/modality/cv/output/Mask.java index 622b807d1b0..350d9a4440e 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Mask.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Mask.java @@ -14,6 +14,10 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.types.Shape; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; /** * A mask with a probability for each pixel within a bounding rectangle. 
@@ -79,6 +83,17 @@ public boolean isFullImageMask() { return fullImageMask; } + /** {@inheritDoc} */ + @Override + public JsonObject serialize() { + JsonObject ret = super.serialize(); + if (fullImageMask) { + ret.add("fullImageMask", new JsonPrimitive(true)); + } + ret.add("mask", JsonUtils.GSON.toJsonTree(probDist)); + return ret; + } + /** * Converts the mask tensor to a mask array. * diff --git a/api/src/main/java/ai/djl/modality/cv/output/Point.java b/api/src/main/java/ai/djl/modality/cv/output/Point.java index 23be0bdf827..c096eab6480 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Point.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Point.java @@ -12,6 +12,8 @@ */ package ai.djl.modality.cv.output; +import ai.djl.util.JsonUtils; + import java.io.Serializable; /** @@ -20,6 +22,7 @@ public class Point implements Serializable { private static final long serialVersionUID = 1L; + private double x; private double y; @@ -52,4 +55,10 @@ public double getX() { public double getY() { return y; } + + /** {@inheritDoc} */ + @Override + public String toString() { + return JsonUtils.GSON_COMPACT.toJson(this); + } } diff --git a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java index b7b16127a00..4ec1f504c2f 100644 --- a/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java +++ b/api/src/main/java/ai/djl/modality/cv/output/Rectangle.java @@ -12,6 +12,11 @@ */ package ai.djl.modality.cv.output; +import ai.djl.util.JsonSerializable; +import ai.djl.util.JsonUtils; + +import com.google.gson.JsonObject; + import java.util.ArrayList; import java.util.List; import java.util.PriorityQueue; @@ -25,7 +30,7 @@ * if you have an image width of 400 pixels and the rectangle starts at 100 pixels, you would use * .25. 
*/ -public class Rectangle implements BoundingBox { +public class Rectangle implements BoundingBox, JsonSerializable { private static final long serialVersionUID = 1L; @@ -145,13 +150,29 @@ public double getHeight() { return height; } + /** + * Returns the upper left and bottom right coordinates. + * + * @return the upper left and bottom right coordinates + */ + public double[] getCoordinates() { + Point upLeft = corners.get(0); + Point bottomRight = corners.get(2); + return new double[] {upLeft.getX(), upLeft.getY(), bottomRight.getX(), bottomRight.getY()}; + } + + /** {@inheritDoc} */ + @Override + public JsonObject serialize() { + JsonObject ret = new JsonObject(); + ret.add("rect", JsonUtils.GSON.toJsonTree(getCoordinates())); + return ret; + } + /** {@inheritDoc} */ @Override public String toString() { - double x = getX(); - double y = getY(); - return String.format( - "{\"x\"=%.3f, \"y\"=%.3f, \"width\"=%.3f, \"height\"=%.3f}", x, y, width, height); + return toJson(); } /** diff --git a/api/src/main/java/ai/djl/util/JsonSerializable.java b/api/src/main/java/ai/djl/util/JsonSerializable.java index 26a0c038e5e..b998e94e51a 100644 --- a/api/src/main/java/ai/djl/util/JsonSerializable.java +++ b/api/src/main/java/ai/djl/util/JsonSerializable.java @@ -14,7 +14,14 @@ import ai.djl.ndarray.BytesSupplier; +import com.google.gson.JsonElement; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; + import java.io.Serializable; +import java.lang.reflect.Type; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; /** * A class implements {@code JsonSerializable} indicates it can be serialized into a json string. 
@@ -26,5 +33,37 @@ public interface JsonSerializable extends Serializable, BytesSupplier { * * @return a json string */ - String toJson(); + default String toJson() { + return JsonUtils.GSON_COMPACT.toJson(serialize()); + } + + /** {@inheritDoc} */ + @Override + default String getAsString() { + return toJson(); + } + + /** {@inheritDoc} */ + @Override + default ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(toJson().getBytes(StandardCharsets.UTF_8)); + } + + /** + * Serializes the object to the {@code JsonElement}. + * + * @return the {@code JsonElement} + */ + JsonElement serialize(); + + /** A customized Gson serializer to serialize the {@code JsonSerializable} object. */ + final class Serializer implements JsonSerializer { + + /** {@inheritDoc} */ + @Override + public JsonElement serialize( + JsonSerializable src, Type type, JsonSerializationContext ctx) { + return src.serialize(); + } + } } diff --git a/api/src/main/java/ai/djl/util/JsonUtils.java b/api/src/main/java/ai/djl/util/JsonUtils.java index b691737b386..fdfde1a8c94 100644 --- a/api/src/main/java/ai/djl/util/JsonUtils.java +++ b/api/src/main/java/ai/djl/util/JsonUtils.java @@ -26,7 +26,8 @@ public interface JsonUtils { boolean PRETTY_PRINT = Boolean.parseBoolean(Utils.getEnvOrSystemProperty("DJL_PRETTY_PRINT")); Gson GSON = builder().create(); - Gson GSON_PRETTY = builder().setPrettyPrinting().create(); + Gson GSON_COMPACT = builder(false).create(); + Gson GSON_PRETTY = builder(true).create(); Type LIST_TYPE = new TypeToken>() {}.getType(); /** @@ -35,10 +36,22 @@ public interface JsonUtils { * @return a custom {@code GsonBuilder} instance. */ static GsonBuilder builder() { + return builder(PRETTY_PRINT); + } + + /** + * Returns a custom {@code GsonBuilder} instance. + * + * @param prettyPrint true for pretty print + * @return a custom {@code GsonBuilder} instance. 
+ */ + static GsonBuilder builder(boolean prettyPrint) { GsonBuilder builder = new GsonBuilder() .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") .serializeSpecialFloatingPointValues() + .registerTypeHierarchyAdapter( + JsonSerializable.class, new JsonSerializable.Serializer()) .registerTypeAdapter( Double.class, (JsonSerializer) @@ -49,7 +62,7 @@ static GsonBuilder builder() { } return new JsonPrimitive(src); }); - if (PRETTY_PRINT) { + if (prettyPrint) { builder.setPrettyPrinting(); } return builder; diff --git a/api/src/test/java/ai/djl/modality/cv/output/CategoryMaskTest.java b/api/src/test/java/ai/djl/modality/cv/output/CategoryMaskTest.java new file mode 100644 index 00000000000..4126d0b22ba --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/output/CategoryMaskTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.modality.cv.output; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +public class CategoryMaskTest { + + private static final Logger logger = LoggerFactory.getLogger(CategoryMaskTest.class); + + @Test + public void test() { + List classes = Arrays.asList("cat", "dog"); + CategoryMask mask = new CategoryMask(classes, new int[][] {{1}, {2}}); + + logger.info("CategoryMask: {}", mask); + Assert.assertEquals(mask.toJson(), "{\"classes\":[\"cat\",\"dog\"],\"mask\":[[1],[2]]}"); + } +} diff --git a/api/src/test/java/ai/djl/modality/cv/output/DetectedObjectsTest.java b/api/src/test/java/ai/djl/modality/cv/output/DetectedObjectsTest.java new file mode 100644 index 00000000000..e91d51e29e4 --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/output/DetectedObjectsTest.java @@ -0,0 +1,67 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.modality.cv.output; + +import ai.djl.modality.Classifications; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +public class DetectedObjectsTest { + + private static final Logger logger = LoggerFactory.getLogger(DetectedObjectsTest.class); + + @Test + public void test() { + List classes = Arrays.asList("cat", "dog"); + List probabilities = Arrays.asList(0.1d, 0.2d); + List boxes = + Arrays.asList(new Rectangle(1, 2, 3, 4), new Rectangle(3, 4, 5, 6)); + Classifications classifications = new Classifications(classes, probabilities); + logger.info("Classifications: {}", classifications); + + Assert.assertEquals( + classifications.toJson(), + "[{\"className\":\"dog\",\"probability\":0.2},{\"className\":\"cat\",\"probability\":0.1}]"); + + DetectedObjects detection = new DetectedObjects(classes, probabilities, boxes); + + logger.info("DetectedObjects: {}", detection); + Assert.assertEquals( + detection.toJson(), + "[{\"boundingBox\":{\"rect\":[3,4,8,10]},\"className\":\"dog\",\"probability\":0.2},{\"boundingBox\":{\"rect\":[1,2,4,6]},\"className\":\"cat\",\"probability\":0.1}]"); + + List points = Arrays.asList(new Point(1, 2), new Point(3, 4)); + boxes = Arrays.asList(new Landmark(1, 2, 3, 4, points), new Landmark(3, 4, 5, 6, points)); + detection = new DetectedObjects(classes, probabilities, boxes); + + logger.info("Landmarks: {}", detection); + Assert.assertEquals( + detection.toJson(), + "[{\"boundingBox\":{\"rect\":[3,4,8,10],\"landmarks\":[{\"x\":1,\"y\":2},{\"x\":3,\"y\":4}]},\"className\":\"dog\",\"probability\":0.2},{\"boundingBox\":{\"rect\":[1,2,4,6],\"landmarks\":[{\"x\":1,\"y\":2},{\"x\":3,\"y\":4}]},\"className\":\"cat\",\"probability\":0.1}]"); + + float[][] masks = {{1, 2, 3}, {4, 5, 6}}; + boxes = Arrays.asList(new Mask(1, 2, 3, 4, masks), new Mask(3, 4, 5, 6, masks, true)); + detection = new 
DetectedObjects(classes, probabilities, boxes); + + logger.info("Masks: {}", detection); + Assert.assertEquals( + detection.toJson(), + "[{\"boundingBox\":{\"rect\":[3,4,8,10],\"fullImageMask\":true,\"mask\":[[1.0,2.0,3.0],[4.0,5.0,6.0]]},\"className\":\"dog\",\"probability\":0.2},{\"boundingBox\":{\"rect\":[1,2,4,6],\"mask\":[[1.0,2.0,3.0],[4.0,5.0,6.0]]},\"className\":\"cat\",\"probability\":0.1}]"); + } +}