Skip to content

Commit

Permalink
Add parsers for HRNet and AgeGender models. (#16)
Browse files Browse the repository at this point in the history
* feat: add support for age_gender model

* feat: add support for HRNet model

* fix: formatting and structure

* fix: AgeGenderParser formatting and convert age to years

* fix: HRNetParser formatting, remove comments, add normalization

* fix: add timestamps to outgoing messages

* Pre-commit fix.

* Add Classifications msg to AgeGender.

* Docstrings fix.

---------

Co-authored-by: kkeroo <[email protected]>
  • Loading branch information
jkbmrz and kkeroo authored Aug 28, 2024
1 parent d6dc43d commit 1043366
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 0 deletions.
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .img_detections import ImgDetectionsWithKeypoints, ImgDetectionWithKeypoints
from .keypoints import HandKeypoints, Keypoints
from .lines import Line, Lines
from .misc import AgeGender

__all__ = [
"ImgDetectionWithKeypoints",
Expand All @@ -11,4 +12,5 @@
"Line",
"Lines",
"Classifications",
"AgeGender",
]
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .detection import create_detection_message, create_line_detection_message
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .misc import create_age_gender_message
from .segmentation import create_segmentation_message
from .thermal import create_thermal_message
from .tracked_features import create_tracked_features_message
Expand All @@ -18,4 +19,5 @@
"create_keypoints_message",
"create_thermal_message",
"create_classification_message",
"create_age_gender_message",
]
39 changes: 39 additions & 0 deletions depthai_nodes/ml/messages/creators/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import List

from ...messages import AgeGender, Classifications


def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender:
"""Create a DepthAI message for the age and gender probability.
@param age: Detected person age (must be multiplied by 100 to get years).
@type age: float
@param gender_prob: Detected person gender probability [female, male].
@type gender_prob: List[float]
@return: AgeGender message containing the predicted person's age and Classifications
message containing the classes and probabilities of the predicted gender.
@rtype: AgeGender
@raise ValueError: If age is not a float.
@raise ValueError: If gender_prob is not a list.
@raise ValueError: If each item in gender_prob is not a float.
"""

if not isinstance(age, float):
raise ValueError(f"age should be float, got {type(age)}.")

if not isinstance(gender_prob, List):
raise ValueError(f"gender_prob should be list, got {type(gender_prob)}.")
for item in gender_prob:
if not isinstance(item, float):
raise ValueError(
f"gender_prob list values must be of type float, instead got {type(item)}."
)

age_gender_message = AgeGender()
age_gender_message.age = age
gender = Classifications()
gender.classes = ["female", "male"]
gender.scores = gender_prob
age_gender_message.gender = gender

return age_gender_message
34 changes: 34 additions & 0 deletions depthai_nodes/ml/messages/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import depthai as dai

from ..messages import Classifications


class AgeGender(dai.Buffer):
def __init__(self):
super().__init__()
self._age: float = None
self._gender = Classifications()

@property
def age(self) -> float:
return self._age

@age.setter
def age(self, value: float):
if not isinstance(value, float):
raise TypeError(
f"start_point must be of type float, instead got {type(value)}."
)
self._age = value

@property
def gender(self) -> Classifications:
return self._gender

@gender.setter
def gender(self, value: Classifications):
if not isinstance(value, Classifications):
raise TypeError(
f"gender must be of type Classifications, instead got {type(value)}."
)
self._gender = value
4 changes: 4 additions & 0 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .age_gender import AgeGenderParser
from .classification import ClassificationParser
from .hrnet import HRNetParser
from .image_output import ImageOutputParser
from .keypoints import KeypointParser
from .mediapipe_hand_landmarker import MPHandLandmarkParser
Expand Down Expand Up @@ -26,4 +28,6 @@
"XFeatParser",
"ThermalImageParser",
"ClassificationParser",
"AgeGenderParser",
"HRNetParser",
]
43 changes: 43 additions & 0 deletions depthai_nodes/ml/parsers/age_gender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import depthai as dai

from ..messages.creators import create_age_gender_message


class AgeGenderParser(dai.node.ThreadedHostNode):
"""Parser class for parsing the output of the Age-Gender regression model.
Attributes
----------
input : Node.Input
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node.
out : Node.Output
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved.
Output Message/s
----------------
**Type**: AgeGender
**Description**: Message containing the detected person age and Classfications object for storing information about the detected person's gender.
"""

def __init__(self):
"""Initializes the AgeGenderParser node."""
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

age = output.getTensor("age_conv3", dequantize=True).item()
age *= 100 # convert to years
prob = output.getTensor("prob", dequantize=True).flatten().tolist()

age_gender_message = create_age_gender_message(age=age, gender_prob=prob)
age_gender_message.setTimestamp(output.getTimestamp())

self.out.send(age_gender_message)
80 changes: 80 additions & 0 deletions depthai_nodes/ml/parsers/hrnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import depthai as dai
import numpy as np

from ..messages.creators import create_keypoints_message


class HRNetParser(dai.node.ThreadedHostNode):
"""Parser class for parsing the output of the HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation.
Attributes
----------
input : Node.Input
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node.
out : Node.Output
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved.
score_threshold : float
Confidence score threshold for detected keypoints.
Output Message/s
----------------
**Type**: Keypoints
**Description**: Keypoints message containing detected body keypoints.
"""

def __init__(self, score_threshold=0.5):
"""Initializes the HRNetParser node.
@param score_threshold: Confidence score threshold for detected keypoints.
@type score_threshold: float
"""
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
self.out = dai.Node.Output(self)

self.score_threshold = score_threshold

def setScoreThreshold(self, threshold):
"""Sets the confidence score threshold for the detected body keypoints.
@param threshold: Confidence score threshold for detected keypoints.
@type threshold: float
"""
self.score_threshold = threshold

def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

heatmaps = output.getTensor("heatmaps", dequantize=True)

if len(heatmaps.shape) == 4:
heatmaps = heatmaps[0]
if heatmaps.shape[2] == 16: # HW_ instead of _HW
heatmaps = heatmaps.transpose(2, 0, 1)
_, map_h, map_w = heatmaps.shape

scores = np.array([np.max(heatmap) for heatmap in heatmaps])
keypoints = np.array(
[
np.unravel_index(heatmap.argmax(), heatmap.shape)
for heatmap in heatmaps
]
)
keypoints = keypoints.astype(np.float32)
keypoints = keypoints[:, ::-1] / np.array(
[map_w, map_h]
) # normalize keypoints to [0, 1]

keypoints_message = create_keypoints_message(
keypoints=keypoints,
scores=scores,
confidence_threshold=self.score_threshold,
)
keypoints_message.setTimestamp(output.getTimestamp())

self.out.send(keypoints_message)

0 comments on commit 1043366

Please sign in to comment.