Add visualization methods to all message types #141

Merged
merged 23 commits from feat/detection_visualization_node into main on Dec 10, 2024

Changes from 6 commits

Commits (23)
3a337e6  added visualizationMessage to some messages (aljazkonec1, Nov 18, 2024)
a953c9c  added visualizationMessage to some messages (aljazkonec1, Nov 18, 2024)
4d3a295  remove print (aljazkonec1, Nov 19, 2024)
012e165  Merge branch 'feat/detection_visualization_node' of https://github.co… (aljazkonec1, Nov 19, 2024)
a185ab8  remove print (aljazkonec1, Nov 19, 2024)
e5d057b  added annotations to all message types (aljazkonec1, Nov 26, 2024)
00fa3c1  timestamps and removed duplicated code (aljazkonec1, Nov 27, 2024)
e63f66d  classification updates (aljazkonec1, Nov 27, 2024)
23e6c55  added dai.ImgTransformations attributes for messages. (aljazkonec1, Nov 28, 2024)
905f522  rebase (aljazkonec1, Nov 18, 2024)
850cedf  remove print (aljazkonec1, Nov 19, 2024)
43be54c  rebase (aljazkonec1, Nov 18, 2024)
73ebfcf  remove print (aljazkonec1, Nov 19, 2024)
b4ec282  added annotations to all message types (aljazkonec1, Nov 28, 2024)
f152ba0  Merge branch 'feat/detection_visualization_node' of https://github.co… (aljazkonec1, Nov 28, 2024)
4432afd  fixed shearing of frames. (aljazkonec1, Nov 29, 2024)
581fbf8  Merge branch 'main' into feat/detection_visualization_node (aljazkonec1, Nov 29, 2024)
9cf6790  added points back (aljazkonec1, Nov 29, 2024)
13a9ee0  Merge branch 'feat/detection_visualization_node' of https://github.co… (aljazkonec1, Nov 29, 2024)
605cea2  Updated annotations (aljazkonec1, Dec 9, 2024)
73e0651  pre-commit error (aljazkonec1, Dec 9, 2024)
d037f3c  Merge branch 'main' into feat/detection_visualization_node (aljazkonec1, Dec 9, 2024)
3f4ff72  pre-commit (aljazkonec1, Dec 9, 2024)
6 changes: 6 additions & 0 deletions depthai_nodes/ml/helpers/constants.py
@@ -0,0 +1,6 @@
import depthai as dai

OUTLINE_COLOR = dai.Color(1.0, 0.5, 0.5, 1.0)
TEXT_COLOR = dai.Color(0.5, 0.5, 1.0, 1.0)
BACKGROUND_COLOR = dai.Color(1.0, 1.0, 0.5, 1.0)
KEYPOINT_COLOR = dai.Color(1.0, 0.35, 0.367, 1.0)
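
For orientation, a minimal sketch of how these shared colors are meant to be consumed by the visualization methods added below; the normalized 0-1 RGBA channel order is an assumption based on how dai.Color is constructed here.

```python
import depthai as dai

from depthai_nodes.ml.helpers.constants import BACKGROUND_COLOR, TEXT_COLOR

# Build a single text annotation using the shared palette (r, g, b, a in [0, 1], assumed).
text = dai.TextAnnotation()
text.text = "example label"
text.position = dai.Point2f(0.05, 0.05)  # normalized image coordinates (assumed)
text.fontSize = 15
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
```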
30 changes: 30 additions & 0 deletions depthai_nodes/ml/messages/classification.py
@@ -4,6 +4,11 @@
import numpy as np
from numpy.typing import NDArray

from depthai_nodes.ml.helpers.constants import (
BACKGROUND_COLOR,
TEXT_COLOR,
)


class Classifications(dai.Buffer):
"""Classification class for storing the classes and their respective scores.
@@ -91,3 +96,28 @@ def top_score(self) -> float:
@rtype: float
"""
return self._scores[0]

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns default visualization message for classification.

The message adds the top five classes and their scores to the right side of the
image.
"""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

sorting_indexes = np.argsort(self._scores)[::-1]
sorting_indexes = sorting_indexes[:5]

for row, i in enumerate(sorting_indexes):
text = dai.TextAnnotation()
text.position = dai.Point2f(1.05, 0.1 + row * 0.1)
text.text = f"{self._classes[i]} {self._scores[i]:.2f}"
text.fontSize = 15
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
annotation.texts.append(text)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
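
A hedged usage sketch of the new method; the import path and the classes/scores setters are assumed from the rest of the library and are not part of this diff.

```python
import numpy as np

from depthai_nodes.ml.messages import Classifications  # import path assumed

msg = Classifications()
msg.classes = ["cat", "dog", "bird"]                      # hypothetical labels
msg.scores = np.array([0.7, 0.2, 0.1], dtype=np.float32)  # assumed setter signature

# dai.ImgAnnotations with up to five text overlays placed just right of the frame.
annotations = msg.getVisualizationMessage()
```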
30 changes: 29 additions & 1 deletion depthai_nodes/ml/messages/clusters.py
@@ -1,6 +1,8 @@
from typing import List

import cv2
import depthai as dai
import numpy as np


class Cluster(dai.Buffer):
@@ -18,7 +20,7 @@ def __init__(self):
"""Initializes the Cluster object."""
super().__init__()
self._label: int = None
self.points: List[dai.Point2f] = []
self._points: List[dai.Point2f] = []

@property
def label(self) -> int:
@@ -103,3 +105,29 @@ def clusters(self, value: List[Cluster]):
if not all(isinstance(cluster, Cluster) for cluster in value):
raise ValueError("Clusters must be a list of Cluster objects.")
self._clusters = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Creates a default visualization message for clusters and colors each one
separately."""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

num_clusters = len(self.clusters)
color_mask = np.array(range(0, 255, 255 // num_clusters), dtype=np.uint8)
color_mask = cv2.applyColorMap(color_mask, cv2.COLORMAP_RAINBOW)
color_mask = color_mask / 255
color_mask = color_mask.reshape(-1, 3)

for i, cluster in enumerate(self.clusters):
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.LINE_LOOP
pointsAnnotation.points = dai.VectorPoint2f(cluster.points)
r, g, b = color_mask[i]
color = dai.Color(r, g, b)
pointsAnnotation.outlineColor = color
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
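
A standalone sketch of the per-cluster color computation above, here for three clusters; note that cv2.applyColorMap returns channels in BGR order.

```python
import cv2
import numpy as np

num_clusters = 3
# Evenly spaced gray levels, one per cluster, mapped through the rainbow colormap.
gray = np.array(range(0, 255, 255 // num_clusters), dtype=np.uint8)  # [0, 85, 170]
colors = cv2.applyColorMap(gray, cv2.COLORMAP_RAINBOW) / 255.0
colors = colors.reshape(-1, 3)  # one BGR triplet in [0, 1] per cluster
print(colors)
```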
11 changes: 11 additions & 0 deletions depthai_nodes/ml/messages/creators/image.py
@@ -51,3 +51,14 @@ def create_image_message(
imgFrame.setType(img_frame_type)

return imgFrame


def getVisualizationMessage(self) -> dai.ImgFrame:
img_frame = dai.ImgFrame()
mask = self._map.copy()
if np.any(mask < 1):
mask = mask * 255
mask = mask.astype(np.uint8)

colored_mask = cv2.applyColorMap(mask, cv2.COLORMAP_PLASMA)
return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
55 changes: 54 additions & 1 deletion depthai_nodes/ml/messages/img_detections.py
@@ -4,6 +4,13 @@
import numpy as np
from numpy.typing import NDArray

from depthai_nodes.ml.helpers.constants import (
BACKGROUND_COLOR,
KEYPOINT_COLOR,
OUTLINE_COLOR,
TEXT_COLOR,
)

from .keypoints import Keypoint
from .segmentation import SegmentationMask

@@ -196,7 +203,7 @@ def detections(self, value: List[ImgDetectionExtended]):
self._detections = value

@property
def masks(self) -> NDArray[np.int8]:
def masks(self) -> NDArray[np.int16]:
"""Returns the segmentation masks stored in a single numpy array.

@return: Segmentation masks.
@@ -226,3 +233,49 @@ def masks(self, value: NDArray[np.int16]):
masks_msg = SegmentationMask()
masks_msg.mask = value
self._masks = masks_msg

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns default visualization message for the detections.

The message adds the bounding boxes, labels and keypoints to the image
annotations.
"""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

for detection in self.detections:
detection: ImgDetectionExtended = detection
rotated_rect = detection.rotated_rect
points = rotated_rect.getPoints()
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.LINE_STRIP
pointsAnnotation.points = dai.VectorPoint2f(points)
pointsAnnotation.outlineColor = OUTLINE_COLOR
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

text = dai.TextAnnotation()
text.position = points[0]
text.text = f"{detection.label_name} {int(detection.confidence * 100)}%"
text.fontSize = 50.5
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
annotation.texts.append(text)

if len(detection.keypoints) > 0:
keypoints = [
dai.Point2f(keypoint.x, keypoint.y)
for keypoint in detection.keypoints
]
keypointAnnotation = dai.PointsAnnotation()
keypointAnnotation.type = (
dai.PointsAnnotationType.LINE_STRIP
) # change to POINTS when fixed
keypointAnnotation.points = dai.VectorPoint2f(keypoints)
keypointAnnotation.outlineColor = KEYPOINT_COLOR
keypointAnnotation.thickness = 2.0
annotation.points.append(keypointAnnotation)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
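
A host-side sketch showing one way the returned dai.ImgAnnotations could be rasterized with OpenCV; it only reads fields the method above populates and assumes point and text positions are normalized to the frame size. The drawing helper itself is illustrative, not part of the library.

```python
import cv2
import numpy as np


def draw_annotations(frame: np.ndarray, img_annotations) -> np.ndarray:
    """Draw the polylines and labels from a dai.ImgAnnotations message onto a BGR frame."""
    h, w = frame.shape[:2]
    for annotation in img_annotations.annotations:
        for points_annotation in annotation.points:
            # Scale normalized points (assumed) to pixel coordinates.
            pts = np.array(
                [[int(p.x * w), int(p.y * h)] for p in points_annotation.points],
                dtype=np.int32,
            )
            cv2.polylines(frame, [pts], isClosed=False, color=(255, 128, 128), thickness=2)
        for text in annotation.texts:
            org = (int(text.position.x * w), int(text.position.y * h))
            cv2.putText(frame, text.text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (128, 128, 255), 1)
    return frame
```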
23 changes: 23 additions & 0 deletions depthai_nodes/ml/messages/keypoints.py
@@ -2,6 +2,8 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import KEYPOINT_COLOR


class Keypoint(dai.Buffer):
"""Keypoint class for storing a keypoint.
@@ -157,3 +159,24 @@ def keypoints(self, value: List[Keypoint]):
if not all(isinstance(item, Keypoint) for item in value):
raise ValueError("keypoints must be a list of Keypoint objects.")
self._keypoints = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Creates a default visualization message for the keypoints."""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

keypoints = [dai.Point2f(keypoint.x, keypoint.y) for keypoint in self.keypoints]

pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = (
dai.PointsAnnotationType.LINE_LOOP
) # change when points get adjusted
pointsAnnotation.points = dai.VectorPoint2f(keypoints)
pointsAnnotation.outlineColor = KEYPOINT_COLOR
pointsAnnotation.fillColor = KEYPOINT_COLOR
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
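
A hedged usage sketch; constructing Keypoint objects through x/y setters and the import path are assumptions, not part of this diff.

```python
from depthai_nodes.ml.messages import Keypoint, Keypoints  # import path assumed

kp = Keypoint()
kp.x, kp.y = 0.4, 0.6  # normalized coordinates (assumed setters)

msg = Keypoints()
msg.keypoints = [kp]
# Currently drawn as a LINE_LOOP; the in-code comment notes POINTS will be used once fixed.
annotations = msg.getVisualizationMessage()
```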
24 changes: 24 additions & 0 deletions depthai_nodes/ml/messages/lines.py
@@ -2,6 +2,8 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import OUTLINE_COLOR


class Line(dai.Buffer):
"""Line class for storing a line.
@@ -130,3 +132,25 @@ def lines(self, value: List[Line]):
if not all(isinstance(item, Line) for item in value):
raise ValueError("Lines must be a list of Line objects.")
self._lines = value

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns default visualization message for lines.

The message adds lines to the image.
"""
img_annotation = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

for line in self.lines:
pointsAnnotation = dai.PointsAnnotation()
pointsAnnotation.type = dai.PointsAnnotationType.LINE_LOOP
pointsAnnotation.points = dai.VectorPoint2f(
[line.start_point, line.end_point]
)
pointsAnnotation.outlineColor = OUTLINE_COLOR
pointsAnnotation.thickness = 2.0
annotation.points.append(pointsAnnotation)

img_annotation.annotations.append(annotation)
img_annotation.setTimestamp(self.getTimestamp())
return img_annotation
13 changes: 13 additions & 0 deletions depthai_nodes/ml/messages/map.py
@@ -1,3 +1,4 @@
import cv2
import depthai as dai
import numpy as np
from numpy.typing import NDArray
@@ -71,3 +72,15 @@ def height(self) -> int:
@rtype: int
"""
return self._height

def getVisualizationMessage(self) -> dai.ImgFrame:
"""Returns default visualization message for 2D maps in the form of a
colormapped image."""
img_frame = dai.ImgFrame()
mask = self._map.copy()
if np.any(mask < 1):
mask = mask * 255
mask = mask.astype(np.uint8)

colored_mask = cv2.applyColorMap(mask, cv2.COLORMAP_PLASMA)
return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
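
A toy sketch of the scaling rule above: maps that contain values below 1 are treated as probabilities and stretched to 0-255 before the colormap is applied.

```python
import cv2
import numpy as np

prob_map = np.array([[0.0, 0.25], [0.5, 1.0]], dtype=np.float32)

mask = prob_map.copy()
if np.any(mask < 1):          # any value below 1 -> interpret the map as probabilities
    mask = mask * 255
mask = mask.astype(np.uint8)  # [[0, 63], [127, 255]]

colored = cv2.applyColorMap(mask, cv2.COLORMAP_PLASMA)  # BGR heatmap, shape (2, 2, 3)
```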
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/prediction.py
@@ -2,6 +2,11 @@

import depthai as dai

from depthai_nodes.ml.helpers.constants import (
BACKGROUND_COLOR,
TEXT_COLOR,
)


class Prediction(dai.Buffer):
"""Prediction class for storing a prediction.
@@ -89,3 +94,24 @@ def prediction(self) -> float:
@rtype: float
"""
return self._predictions[0].prediction

def getVisualizationMessage(self) -> dai.ImgAnnotations:
"""Returns the visualization message for the predictions.

The message adds text representing the predictions to the right of the image.
"""
img_annotations = dai.ImgAnnotations()
annotation = dai.ImgAnnotation()

for i, prediction in enumerate(self.predictions):
text = dai.TextAnnotation()
text.position = dai.Point2f(1.05, 0.1 + i * 0.1)
text.text = f"{prediction.prediction:.2f}"
text.fontSize = 15
text.textColor = TEXT_COLOR
text.backgroundColor = BACKGROUND_COLOR
annotation.texts.append(text)

img_annotations.annotations.append(annotation)
img_annotations.setTimestamp(self.getTimestamp())
return img_annotations
26 changes: 26 additions & 0 deletions depthai_nodes/ml/messages/segmentation.py
@@ -1,3 +1,4 @@
import cv2
import depthai as dai
import numpy as np
from numpy.typing import NDArray
@@ -48,3 +49,28 @@ def mask(self, value: NDArray[np.int16]):
if np.any((value < -1)):
raise ValueError("Mask must be an array of integers larger or equal to -1.")
self._mask = value

def getVisualizationMessage(self) -> dai.ImgFrame:
"""Returns the default visualization message for segmentation masks."""
img_frame = dai.ImgFrame()
mask = self._mask.copy()

unique_values = np.unique(mask[mask >= 0])
scaled_mask = np.zeros_like(mask, dtype=np.uint8)

if unique_values.size == 0:
return img_frame.setCvFrame(scaled_mask, dai.ImgFrame.Type.BGR888i)

min_val, max_val = unique_values.min(), unique_values.max()

if min_val == max_val:
scaled_mask = np.ones_like(mask, dtype=np.uint8) * 255
else:
scaled_mask = ((mask - min_val) / (max_val - min_val) * 255).astype(
np.uint8
)
scaled_mask[mask == -1] = 0
colored_mask = cv2.applyColorMap(scaled_mask, cv2.COLORMAP_RAINBOW)
colored_mask[mask == -1] = [0, 0, 0]

return img_frame.setCvFrame(colored_mask, dai.ImgFrame.Type.BGR888i)
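
A worked toy example mirroring the scaling steps above on a 2x2 mask, so the -1 background handling is easy to trace; only NumPy/OpenCV calls already used in the diff appear here.

```python
import cv2
import numpy as np

mask = np.array([[-1, 0], [1, 2]], dtype=np.int16)

unique_values = np.unique(mask[mask >= 0])  # [0, 1, 2]
min_val, max_val = unique_values.min(), unique_values.max()

scaled = ((mask - min_val) / (max_val - min_val) * 255).astype(np.uint8)
scaled[mask == -1] = 0                      # [[0, 0], [127, 255]]

colored = cv2.applyColorMap(scaled, cv2.COLORMAP_RAINBOW)
colored[mask == -1] = [0, 0, 0]             # background pixels stay black
```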