From 2fd104aa81eb4e991d886801fd0ccba22fb760d7 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Sun, 29 Sep 2024 12:42:13 -0700 Subject: [PATCH] [api] Adds sam2 model to onnxruntime model zoo --- .../djl/modality/cv/BufferedImageFactory.java | 12 + .../main/java/ai/djl/modality/cv/Image.java | 10 + .../cv/translator/Sam2ServingTranslator.java | 78 +++++++ .../cv/translator/Sam2Translator.java | 205 ++++++++++++++--- .../cv/translator/Sam2TranslatorFactory.java | 6 + .../modality/cv/translator/Sam2InputTest.java | 48 ++++ .../translator/Sam2TranslatorFactoryTest.java | 50 ++++ .../ai/djl/onnxruntime/engine/OrtEngine.java | 8 +- .../ai/djl/onnxruntime/zoo/OrtModelZoo.java | 4 + examples/docs/segment_anything_2.md | 213 ++++++++++++------ .../inference/cv/SegmentAnything2.java | 15 +- .../main/java/ai/djl/opencv/OpenCVImage.java | 16 ++ .../ai/djl/opencv/OpenCVImageFactoryTest.java | 1 - 13 files changed, 559 insertions(+), 107 deletions(-) create mode 100644 api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java create mode 100644 api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java create mode 100644 api/src/test/java/ai/djl/modality/cv/translator/Sam2TranslatorFactoryTest.java diff --git a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java index 13a5b17189e7..bc314faf933a 100644 --- a/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java @@ -342,6 +342,18 @@ public void drawBoundingBoxes(DetectedObjects detections, float opacity) { g.dispose(); } + /** {@inheritDoc} */ + @Override + public void drawRectangle(Rectangle rect, int rgb, int thickness) { + Graphics2D g = (Graphics2D) image.getGraphics(); + g.setStroke(new BasicStroke(thickness)); + g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); + g.setPaint(new Color(rgb)); + int x = (int) rect.getX(); + int y = (int) rect.getY(); + g.drawRect(x, y, (int) rect.getWidth(), (int) rect.getHeight()); + } + /** {@inheritDoc} */ @Override public void drawMarks(List points, int radius) { diff --git a/api/src/main/java/ai/djl/modality/cv/Image.java b/api/src/main/java/ai/djl/modality/cv/Image.java index 89e78eb28630..07fe5f018091 100644 --- a/api/src/main/java/ai/djl/modality/cv/Image.java +++ b/api/src/main/java/ai/djl/modality/cv/Image.java @@ -16,6 +16,7 @@ import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Joints; import ai.djl.modality.cv.output.Point; +import ai.djl.modality.cv.output.Rectangle; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; @@ -137,6 +138,15 @@ default void drawBoundingBoxes(DetectedObjects detections) { */ void drawBoundingBoxes(DetectedObjects detections, float opacity); + /** + * Draws a rectangle on the image. + * + * @param rect the rectangle to draw + * @param rgb the color + * @param thickness the thickness + */ + void drawRectangle(Rectangle rect, int rgb, int thickness); + /** * Draws a mark on the image. * diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java new file mode 100644 index 000000000000..909c3d01f535 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2ServingTranslator.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; +import ai.djl.ndarray.BytesSupplier; +import ai.djl.ndarray.NDList; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import java.io.IOException; + +/** A {@link Translator} that can serve SAM2 model. */ +public class Sam2ServingTranslator implements Translator { + + private Sam2Translator translator; + + /** + * Constructs a new {@code Sam2ServingTranslator} instance. + * + * @param translator a {@code Sam2Translator} + */ + public Sam2ServingTranslator(Sam2Translator translator) { + this.translator = translator; + } + + /** {@inheritDoc} */ + @Override + public Batchifier getBatchifier() { + return translator.getBatchifier(); + } + + /** {@inheritDoc} */ + @Override + public Output processOutput(TranslatorContext ctx, NDList list) throws Exception { + Output output = new Output(); + output.addProperty("Content-Type", "application/json"); + DetectedObjects obj = translator.processOutput(ctx, list); + output.add(BytesSupplier.wrapAsJson(obj)); + return output; + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Input input) throws Exception { + BytesSupplier data = input.getData(); + try { + if (data == null) { + throw new TranslateException("Input data is empty."); + } + Sam2Input sam2 = Sam2Input.fromJson(data.getAsString()); + return translator.processInput(ctx, sam2); + } catch (IOException e) { + throw new TranslateException("Input is not an Image data type", e); + } + } + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws Exception { + translator.prepare(ctx); + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java index d67dc9e4b6dc..0c05c569f6db 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java @@ -21,6 +21,7 @@ import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Mask; import ai.djl.modality.cv.output.Point; +import ai.djl.modality.cv.output.Rectangle; import ai.djl.modality.cv.transform.Normalize; import ai.djl.modality.cv.transform.Resize; import ai.djl.modality.cv.transform.ToTensor; @@ -36,11 +37,13 @@ import ai.djl.translate.Pipeline; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import ai.djl.util.JsonUtils; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -96,14 +99,12 @@ public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Except ctx.setAttachment("width", width); ctx.setAttachment("height", height); - List points = input.getPoints(); - int numPoints = points.size(); float[] buf = input.toLocationArray(width, height); NDManager manager = ctx.getNDManager(); NDArray array = image.toNDArray(manager, Image.Flag.COLOR); array = pipeline.transform(new NDList(array)).get(0).expandDims(0); - NDArray locations = manager.create(buf, new Shape(1, numPoints, 2)); + NDArray locations = manager.create(buf, new Shape(1, buf.length / 2, 2)); NDArray labels = manager.create(input.getLabels()); if (predictor == null) { @@ -125,7 +126,7 @@ public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Except /** {@inheritDoc} */ @Override - public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws Exception { + public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { NDArray logits = list.get(0); NDArray scores = list.get(1).squeeze(0); long best = scores.argMax().getLong(); @@ -202,8 +203,8 @@ public Sam2Translator build() { public static final class Sam2Input { private Image image; - private List points; - private List labels; + private Point[] points; + private int[] labels; /** * Constructs a {@code Sam2Input} instance. @@ -212,7 +213,7 @@ public static final class Sam2Input { * @param points the locations on the image * @param labels the labels for the locations (0: background, 1: foreground) */ - public Sam2Input(Image image, List points, List labels) { + public Sam2Input(Image image, Point[] points, int[] labels) { this.image = image; this.points = points; this.labels = labels; @@ -233,11 +234,34 @@ public Image getImage() { * @return the locations */ public List getPoints() { - return points; + List list = new ArrayList<>(); + for (int i = 0; i < labels.length; ++i) { + if (labels[i] < 2) { + list.add(points[i]); + } + } + return list; + } + + /** + * Returns the box. + * + * @return the box + */ + public List getBoxes() { + List list = new ArrayList<>(); + for (int i = 0; i < labels.length; ++i) { + if (labels[i] == 2) { + double width = points[i + 1].getX() - points[i].getX(); + double height = points[i + 1].getY() - points[i].getY(); + list.add(new Rectangle(points[i], width, height)); + } + } + return list; } float[] toLocationArray(int width, int height) { - float[] ret = new float[points.size() * 2]; + float[] ret = new float[points.length * 2]; int i = 0; for (Point point : points) { ret[i++] = (float) point.getX() / width * 1024; @@ -247,43 +271,156 @@ float[] toLocationArray(int width, int height) { } float[][] getLabels() { - float[][] buf = new float[1][labels.size()]; - for (int i = 0; i < labels.size(); ++i) { - buf[0][i] = labels.get(i); + float[][] buf = new float[1][labels.length]; + for (int i = 0; i < labels.length; ++i) { + buf[0][i] = labels[i]; } return buf; } /** - * Creates a new {@code Sam2Input} instance with the image and a location. + * Constructs a {@code Sam2Input} instance from json string. * - * @param url the image url - * @param x the X of the location - * @param y the Y of the location - * @return a new {@code Sam2Input} instance - * @throws IOException if failed to read image + * @param input the json input + * @return a {@code Sam2Input} instance + * @throws IOException if failed to load the image */ - public static Sam2Input newInstance(String url, int x, int y) throws IOException { - Image image = ImageFactory.getInstance().fromUrl(url); - List points = Collections.singletonList(new Point(x, y)); - List labels = Collections.singletonList(1); - return new Sam2Input(image, points, labels); + public static Sam2Input fromJson(String input) throws IOException { + Prompt prompt = JsonUtils.GSON.fromJson(input, Prompt.class); + if (prompt.image == null) { + throw new IllegalArgumentException("Missing url value"); + } + if (prompt.prompt == null || prompt.prompt.length == 0) { + throw new IllegalArgumentException("Missing prompt value"); + } + Image image = ImageFactory.getInstance().fromUrl(prompt.image); + Builder builder = builder(image); + for (Location location : prompt.prompt) { + int[] data = location.data; + if ("point".equals(location.type)) { + builder.addPoint(data[0], data[1], location.label); + } else if ("rectangle".equals(location.type)) { + builder.addBox(data[0], data[1], data[2], data[3]); + } + } + return builder.build(); } /** - * Creates a new {@code Sam2Input} instance with the image and a location. + * Creates a builder to build a {@code Sam2Input} with the image. * - * @param path the image file path - * @param x the X of the location - * @param y the Y of the location - * @return a new {@code Sam2Input} instance - * @throws IOException if failed to read image + * @param image the image + * @return a new builder */ - public static Sam2Input newInstance(Path path, int x, int y) throws IOException { - Image image = ImageFactory.getInstance().fromFile(path); - List points = Collections.singletonList(new Point(x, y)); - List labels = Collections.singletonList(1); - return new Sam2Input(image, points, labels); + public static Builder builder(Image image) { + return new Builder(image); + } + + /** The builder for {@code Sam2Input}. */ + public static final class Builder { + + private Image image; + private List points; + private List labels; + + Builder(Image image) { + this.image = image; + points = new ArrayList<>(); + labels = new ArrayList<>(); + } + + /** + * Adds a point to the {@code Sam2Input}. + * + * @param x the X coordinate + * @param y the Y coordinate + * @return the builder + */ + public Builder addPoint(int x, int y) { + return addPoint(x, y, 1); + } + + /** + * Adds a point to the {@code Sam2Input}. + * + * @param x the X coordinate + * @param y the Y coordinate + * @param label the label of the point, 0 for background, 1 for foreground + * @return the builder + */ + public Builder addPoint(int x, int y, int label) { + return addPoint(new Point(x, y), label); + } + + /** + * Adds a point to the {@code Sam2Input}. + * + * @param point the point on image + * @param label the label of the point, 0 for background, 1 for foreground + * @return the builder + */ + public Builder addPoint(Point point, int label) { + points.add(point); + labels.add(label); + return this; + } + + /** + * Adds a box area to the {@code Sam2Input}. + * + * @param x the left coordinate + * @param y the top coordinate + * @param right the right coordinate + * @param bottom the bottom coordinate + * @return the builder + */ + public Builder addBox(int x, int y, int right, int bottom) { + addPoint(new Point(x, y), 2); + addPoint(new Point(right, bottom), 3); + return this; + } + + /** + * Builds the {@code Sam2Input}. + * + * @return the new {@code Sam2Input} + */ + public Sam2Input build() { + Point[] location = points.toArray(new Point[0]); + int[] array = labels.stream().mapToInt(Integer::intValue).toArray(); + return new Sam2Input(image, location, array); + } + } + + private static final class Location { + String type; + int[] data; + int label; + + public void setType(String type) { + this.type = type; + } + + public void setData(int[] data) { + this.data = data; + } + + public void setLabel(int label) { + this.label = label; + } + } + + private static final class Prompt { + String image; + Location[] prompt; + + public void setImage(String image) { + this.image = image; + } + + public void setPrompt(Location[] prompt) { + this.prompt = prompt; + } } } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java index 299b4b19b18d..e33a5617f116 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java @@ -13,6 +13,8 @@ package ai.djl.modality.cv.translator; import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; import ai.djl.translate.Translator; @@ -34,6 +36,7 @@ public class Sam2TranslatorFactory implements TranslatorFactory, Serializable { static { SUPPORTED_TYPES.add(new Pair<>(Sam2Input.class, DetectedObjects.class)); + SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); } /** {@inheritDoc} */ @@ -43,6 +46,9 @@ public Translator newInstance( Class input, Class output, Model model, Map arguments) { if (input == Sam2Input.class && output == DetectedObjects.class) { return (Translator) Sam2Translator.builder(arguments).build(); + } else if (input == Input.class && output == Output.class) { + Sam2Translator translator = Sam2Translator.builder(arguments).build(); + return (Translator) new Sam2ServingTranslator(translator); } throw new IllegalArgumentException("Unsupported input/output types."); } diff --git a/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java b/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java new file mode 100644 index 000000000000..9f9811bdaac8 --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/translator/Sam2InputTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; + +public class Sam2InputTest { + + @Test + public void test() throws IOException { + Path file = Paths.get("../examples/src/test/resources/kitten.jpg"); + Image img = ImageFactory.getInstance().fromFile(file); + String json = + "{\"image\": \"" + + file.toUri().toURL() + + "\",\n" + + "\"prompt\": [\n" + + " {\"type\": \"point\", \"data\": [575, 750], \"label\": 0},\n" + + " {\"type\": \"rectangle\", \"data\": [425, 600, 700, 875]}\n" + + "]}"; + Sam2Input input = Sam2Input.fromJson(json); + Assert.assertEquals(input.getPoints().size(), 1); + Assert.assertEquals(input.getBoxes().size(), 1); + + input = Sam2Input.builder(img).addPoint(0, 1).addBox(0, 0, 1, 1).build(); + Assert.assertEquals(input.getPoints().size(), 1); + Assert.assertEquals(input.getBoxes().size(), 1); + } +} diff --git a/api/src/test/java/ai/djl/modality/cv/translator/Sam2TranslatorFactoryTest.java b/api/src/test/java/ai/djl/modality/cv/translator/Sam2TranslatorFactoryTest.java new file mode 100644 index 000000000000..40ead0c0bf90 --- /dev/null +++ b/api/src/test/java/ai/djl/modality/cv/translator/Sam2TranslatorFactoryTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; +import ai.djl.translate.Translator; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.HashMap; +import java.util.Map; + +public class Sam2TranslatorFactoryTest { + + @Test + public void testNewInstance() { + Sam2TranslatorFactory factory = new Sam2TranslatorFactory(); + Assert.assertEquals(factory.getSupportedTypes().size(), 2); + Map arguments = new HashMap<>(); + try (Model model = Model.newInstance("test")) { + Translator translator1 = + factory.newInstance(Sam2Input.class, DetectedObjects.class, model, arguments); + Assert.assertTrue(translator1 instanceof Sam2Translator); + + Translator translator5 = + factory.newInstance(Input.class, Output.class, model, arguments); + Assert.assertTrue(translator5 instanceof Sam2ServingTranslator); + + Assert.assertThrows( + IllegalArgumentException.class, + () -> factory.newInstance(Image.class, Output.class, model, arguments)); + } + } +} diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java index c3c2ea0eea0b..381e70027dd4 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngine.java @@ -77,7 +77,13 @@ OrtEnvironment getEnv() { @Override public Engine getAlternativeEngine() { if (!initialized && !Boolean.getBoolean("ai.djl.onnx.disable_alternative")) { - Engine engine = Engine.getInstance(); + Engine engine; + if (Engine.hasEngine("PyTorch")) { + // workaround MXNet engine issue on CI + engine = Engine.getEngine("PyTorch"); + } else { + engine = Engine.getInstance(); + } if (engine.getRank() < getRank()) { // alternativeEngine should not have the same rank as OnnxRuntime alternativeEngine = engine; diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java index 8d49d8931ee6..db5975f6d47e 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/zoo/OrtModelZoo.java @@ -31,6 +31,10 @@ public class OrtModelZoo extends ModelZoo { OrtModelZoo() { addModel(REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet", "0.0.1")); addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1")); + addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-base-plus", "0.0.1")); + addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.1")); + addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-small", "0.0.1")); + addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolo5s", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); addModel(REPOSITORY.model(CV.POSE_ESTIMATION, GROUP_ID, "yolov8n-pose", "0.0.1")); diff --git a/examples/docs/segment_anything_2.md b/examples/docs/segment_anything_2.md index 4ffc11b9eab2..768a666b87f4 100644 --- a/examples/docs/segment_anything_2.md +++ b/examples/docs/segment_anything_2.md @@ -63,8 +63,9 @@ pip install sam2 transformers ### trace the model ```python +import os import sys -from typing import Tuple +from typing import Any import torch from sam2.modeling.sam2_base import SAM2Base @@ -72,77 +73,129 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor from torch import nn -class Sam2Wrapper(nn.Module): +class SAM2ImageEncoder(nn.Module): - def __init__( - self, - sam_model: SAM2Base, - ) -> None: + def __init__(self, sam_model: SAM2Base) -> None: super().__init__() self.model = sam_model + self.image_encoder = sam_model.image_encoder + self.no_mem_embed = sam_model.no_mem_embed - # Spatial dim for backbone feature maps - self._bb_feat_sizes = [ - (256, 256), - (128, 128), - (64, 64), - ] + def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]: + backbone_out = self.image_encoder(x) + backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1]) - def extract_features( - self, - input_image: torch.Tensor, - ) -> (torch.Tensor, torch.Tensor, torch.Tensor): - backbone_out = self.model.forward_image(input_image) - _, vision_feats, _, _ = self.model._prepare_backbone_features( - backbone_out) - # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos - if self.model.directly_add_no_mem_embed: - vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feature_maps = backbone_out["backbone_fpn"][-self.model. + num_feature_levels:] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model. + num_feature_levels:] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_feats[-1] = vision_feats[-1] + self.no_mem_embed feats = [ - feat.permute(1, 2, - 0).view(1, -1, *feat_size) for feat, feat_size in zip( - vision_feats[::-1], self._bb_feat_sizes[::-1]) + feat.permute(1, 2, 0).reshape(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1]) ][::-1] - return feats[-1], feats[0], feats[1] + return feats[0], feats[1], feats[2] - def forward( - self, - input_image: torch.Tensor, - point_coords: torch.Tensor, - point_labels: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - image_embed, feature_1, feature_2 = self.extract_features(input_image) - return self.predict(point_coords, point_labels, image_embed, feature_1, - feature_2) - def predict( +class SAM2ImageDecoder(nn.Module): + + def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None: + super().__init__() + self.mask_decoder = sam_model.sam_mask_decoder + self.prompt_encoder = sam_model.sam_prompt_encoder + self.model = sam_model + self.img_size = sam_model.image_size + self.multimask_output = multimask_output + self.sparse_embedding = None + + @torch.no_grad() + def forward( self, + image_embed: torch.Tensor, + high_res_feats_0: torch.Tensor, + high_res_feats_1: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor, - image_embed: torch.Tensor, - feats_1: torch.Tensor, - feats_2: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - concat_points = (point_coords, point_labels) - - sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( - points=concat_points, - boxes=None, - masks=None, - ) - - low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( - image_embeddings=image_embed[0].unsqueeze(0), - image_pe=self.model.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=True, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + self.sparse_embedding = sparse_embedding + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + high_res_feats = [high_res_feats_0, high_res_feats_1] + image_embed = image_embed + + masks, iou_predictions, _, _ = self.mask_decoder.predict_masks( + image_embeddings=image_embed, + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, repeat_image=False, - high_res_features=[feats_1, feats_2], + high_res_features=high_res_feats, ) - return low_res_masks, iou_predictions + + if self.multimask_output: + masks = masks[:, 1:, :, :] + iou_predictions = iou_predictions[:, 1:] + else: + masks, iou_pred = ( + self.mask_decoder._dynamic_multimask_via_stability( + masks, iou_predictions)) + + masks = torch.clamp(masks, -32.0, 32.0) + + return masks, iou_predictions + + def _embed_points(self, point_coords: torch.Tensor, + point_labels: torch.Tensor) -> torch.Tensor: + + point_coords = point_coords + 0.5 + + padding_point = torch.zeros((point_coords.shape[0], 1, 2), + device=point_coords.device) + padding_label = -torch.ones( + (point_labels.shape[0], 1), device=point_labels.device) + point_coords = torch.cat([point_coords, padding_point], dim=1) + point_labels = torch.cat([point_labels, padding_label], dim=1) + + point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size + point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size + + point_embedding = self.prompt_encoder.pe_layer._pe_encoding( + point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = (point_embedding + + self.prompt_encoder.not_a_point_embed.weight * + (point_labels == -1)) + + for i in range(self.prompt_encoder.num_point_embeddings): + point_embedding = (point_embedding + + self.prompt_encoder.point_embeddings[i].weight * + (point_labels == i)) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, + has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling( + input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding def trace_model(model_id: str): @@ -151,19 +204,47 @@ def trace_model(model_id: str): else: device = torch.device("cpu") + model_name = f"{model_id[9:]}" + os.makedirs(model_name) + predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device) - model = Sam2Wrapper(predictor.model) + encoder = SAM2ImageEncoder(predictor.model) + decoder = SAM2ImageDecoder(predictor.model, True) input_image = torch.ones(1, 3, 1024, 1024).to(device) - input_point = torch.ones(1, 1, 2).to(device) - input_labels = torch.ones(1, 1, dtype=torch.int32, device=device) - - converted = torch.jit.trace_module( - model, { - "extract_features": input_image, - "forward": (input_image, input_point, input_labels) - }) - torch.jit.save(converted, f"{model_id[9:]}.pt") + high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image) + + converted = torch.jit.trace(encoder, input_image) + torch.jit.save(converted, f"model_name/encoder.pt") + + # trace decoder model + embed_size = ( + predictor.model.image_size // predictor.model.backbone_stride, + predictor.model.image_size // predictor.model.backbone_stride, + ) + mask_input_size = [4 * x for x in embed_size] + + point_coords = torch.randint(low=0, + high=1024, + size=(1, 5, 2), + dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float) + mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float) + has_mask_input = torch.tensor([1], dtype=torch.float) + + converted = torch.jit.trace( + decoder, (image_embed, high_res_feats_0, high_res_feats_1, + point_coords, point_labels, mask_input, has_mask_input)) + torch.jit.save(converted, f"model_name/model_name.pt") + + # save serving.properties + serving_file = os.path.join(model_name, "serving.properties") + with open(serving_file, "w") as f: + f.write( + f"engine=PyTorch\n" + f"option.modelName={model_name}\n" + f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n" + f"encoder=encoder.pt") if __name__ == '__main__': diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java b/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java index 85afcb8789fd..ff7c0e5ff36e 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java @@ -12,11 +12,12 @@ */ package ai.djl.examples.inference.cv; -import ai.djl.Device; import ai.djl.ModelException; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.output.Rectangle; import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; import ai.djl.modality.cv.translator.Sam2TranslatorFactory; import ai.djl.repository.zoo.Criteria; @@ -46,14 +47,15 @@ public static void main(String[] args) throws IOException, ModelException, Trans public static DetectedObjects predict() throws IOException, ModelException, TranslateException { String url = "https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg"; - Sam2Input input = Sam2Input.newInstance(url, 500, 375); + Image image = ImageFactory.getInstance().fromUrl(url); + Sam2Input input = + Sam2Input.builder(image).addPoint(575, 750).addBox(425, 600, 700, 875).build(); Criteria criteria = Criteria.builder() .setTypes(Sam2Input.class, DetectedObjects.class) - .optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny") - .optEngine("PyTorch") - .optDevice(Device.cpu()) // use sam2-hiera-tiny-gpu for GPU + .optModelUrls("djl://ai.djl.onnxruntime/sam2-hiera-tiny") + // .optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny") // for PyTorch .optTranslatorFactory(new Sam2TranslatorFactory()) .optProgress(new ProgressBar()) .build(); @@ -73,6 +75,9 @@ private static void showMask(Sam2Input input, DetectedObjects detection) throws Image img = input.getImage(); img.drawBoundingBoxes(detection, 0.8f); img.drawMarks(input.getPoints()); + for (Rectangle rect : input.getBoxes()) { + img.drawRectangle(rect, 0xff0000, 6); + } Path imagePath = outputDir.resolve("sam2.png"); img.save(Files.newOutputStream(imagePath), "png"); diff --git a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java index 6651a8981679..04362ce46971 100644 --- a/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java +++ b/extensions/opencv/src/main/java/ai/djl/opencv/OpenCVImage.java @@ -200,6 +200,22 @@ public void drawBoundingBoxes(DetectedObjects detections, float opacity) { } } + /** {@inheritDoc} */ + @Override + public void drawRectangle(Rectangle rectangle, int rgb, int stroke) { + Rect rect = + new Rect( + (int) rectangle.getX(), + (int) rectangle.getY(), + (int) rectangle.getWidth(), + (int) rectangle.getHeight()); + int r = (rgb & 0xff0000) >> 16; + int g = (rgb & 0x00ff00) >> 8; + int b = rgb & 0x0000ff; + Scalar color = new Scalar(b, g, r); + Imgproc.rectangle(image, rect.tl(), rect.br(), color, stroke); + } + /** {@inheritDoc} */ @Override public void drawMarks(List points, int radius) { diff --git a/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java b/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java index a9f41d087d06..64be713891a3 100644 --- a/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java +++ b/extensions/opencv/src/test/java/ai/djl/opencv/OpenCVImageFactoryTest.java @@ -47,7 +47,6 @@ public class OpenCVImageFactoryTest { @BeforeClass public void setUp() { TestRequirements.notWindows(); // failed on Windows ServerCore container - TestRequirements.notArm(); } @Test