-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add parsers for HRNet and AgeGender models. (#16)
* feat: add support for age_gender model * feat: add support for HRNet model * fix: formatting and structure * fix: AgeGenderParser formatting and convert age to years * fix: HRNetParser formatting, remove comments, add normalization * fix: add timestamps to outgoing messages * Pre-commit fix. * Add Classifications msg to AgeGender. * Docstrings fix. --------- Co-authored-by: kkeroo <[email protected]>
- Loading branch information
Showing
7 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import List | ||
|
||
from ...messages import AgeGender, Classifications | ||
|
||
|
||
def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender: | ||
"""Create a DepthAI message for the age and gender probability. | ||
@param age: Detected person age (must be multiplied by 100 to get years). | ||
@type age: float | ||
@param gender_prob: Detected person gender probability [female, male]. | ||
@type gender_prob: List[float] | ||
@return: AgeGender message containing the predicted person's age and Classifications | ||
message containing the classes and probabilities of the predicted gender. | ||
@rtype: AgeGender | ||
@raise ValueError: If age is not a float. | ||
@raise ValueError: If gender_prob is not a list. | ||
@raise ValueError: If each item in gender_prob is not a float. | ||
""" | ||
|
||
if not isinstance(age, float): | ||
raise ValueError(f"age should be float, got {type(age)}.") | ||
|
||
if not isinstance(gender_prob, List): | ||
raise ValueError(f"gender_prob should be list, got {type(gender_prob)}.") | ||
for item in gender_prob: | ||
if not isinstance(item, float): | ||
raise ValueError( | ||
f"gender_prob list values must be of type float, instead got {type(item)}." | ||
) | ||
|
||
age_gender_message = AgeGender() | ||
age_gender_message.age = age | ||
gender = Classifications() | ||
gender.classes = ["female", "male"] | ||
gender.scores = gender_prob | ||
age_gender_message.gender = gender | ||
|
||
return age_gender_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import depthai as dai | ||
|
||
from ..messages import Classifications | ||
|
||
|
||
class AgeGender(dai.Buffer): | ||
def __init__(self): | ||
super().__init__() | ||
self._age: float = None | ||
self._gender = Classifications() | ||
|
||
@property | ||
def age(self) -> float: | ||
return self._age | ||
|
||
@age.setter | ||
def age(self, value: float): | ||
if not isinstance(value, float): | ||
raise TypeError( | ||
f"start_point must be of type float, instead got {type(value)}." | ||
) | ||
self._age = value | ||
|
||
@property | ||
def gender(self) -> Classifications: | ||
return self._gender | ||
|
||
@gender.setter | ||
def gender(self, value: Classifications): | ||
if not isinstance(value, Classifications): | ||
raise TypeError( | ||
f"gender must be of type Classifications, instead got {type(value)}." | ||
) | ||
self._gender = value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import depthai as dai | ||
|
||
from ..messages.creators import create_age_gender_message | ||
|
||
|
||
class AgeGenderParser(dai.node.ThreadedHostNode): | ||
"""Parser class for parsing the output of the Age-Gender regression model. | ||
Attributes | ||
---------- | ||
input : Node.Input | ||
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. | ||
out : Node.Output | ||
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. | ||
Output Message/s | ||
---------------- | ||
**Type**: AgeGender | ||
**Description**: Message containing the detected person age and Classfications object for storing information about the detected person's gender. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initializes the AgeGenderParser node.""" | ||
dai.node.ThreadedHostNode.__init__(self) | ||
self.input = dai.Node.Input(self) | ||
self.out = dai.Node.Output(self) | ||
|
||
def run(self): | ||
while self.isRunning(): | ||
try: | ||
output: dai.NNData = self.input.get() | ||
except dai.MessageQueue.QueueException: | ||
break # Pipeline was stopped | ||
|
||
age = output.getTensor("age_conv3", dequantize=True).item() | ||
age *= 100 # convert to years | ||
prob = output.getTensor("prob", dequantize=True).flatten().tolist() | ||
|
||
age_gender_message = create_age_gender_message(age=age, gender_prob=prob) | ||
age_gender_message.setTimestamp(output.getTimestamp()) | ||
|
||
self.out.send(age_gender_message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import depthai as dai | ||
import numpy as np | ||
|
||
from ..messages.creators import create_keypoints_message | ||
|
||
|
||
class HRNetParser(dai.node.ThreadedHostNode): | ||
"""Parser class for parsing the output of the HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation. | ||
Attributes | ||
---------- | ||
input : Node.Input | ||
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. | ||
out : Node.Output | ||
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. | ||
score_threshold : float | ||
Confidence score threshold for detected keypoints. | ||
Output Message/s | ||
---------------- | ||
**Type**: Keypoints | ||
**Description**: Keypoints message containing detected body keypoints. | ||
""" | ||
|
||
def __init__(self, score_threshold=0.5): | ||
"""Initializes the HRNetParser node. | ||
@param score_threshold: Confidence score threshold for detected keypoints. | ||
@type score_threshold: float | ||
""" | ||
dai.node.ThreadedHostNode.__init__(self) | ||
self.input = dai.Node.Input(self) | ||
self.out = dai.Node.Output(self) | ||
|
||
self.score_threshold = score_threshold | ||
|
||
def setScoreThreshold(self, threshold): | ||
"""Sets the confidence score threshold for the detected body keypoints. | ||
@param threshold: Confidence score threshold for detected keypoints. | ||
@type threshold: float | ||
""" | ||
self.score_threshold = threshold | ||
|
||
def run(self): | ||
while self.isRunning(): | ||
try: | ||
output: dai.NNData = self.input.get() | ||
except dai.MessageQueue.QueueException: | ||
break # Pipeline was stopped | ||
|
||
heatmaps = output.getTensor("heatmaps", dequantize=True) | ||
|
||
if len(heatmaps.shape) == 4: | ||
heatmaps = heatmaps[0] | ||
if heatmaps.shape[2] == 16: # HW_ instead of _HW | ||
heatmaps = heatmaps.transpose(2, 0, 1) | ||
_, map_h, map_w = heatmaps.shape | ||
|
||
scores = np.array([np.max(heatmap) for heatmap in heatmaps]) | ||
keypoints = np.array( | ||
[ | ||
np.unravel_index(heatmap.argmax(), heatmap.shape) | ||
for heatmap in heatmaps | ||
] | ||
) | ||
keypoints = keypoints.astype(np.float32) | ||
keypoints = keypoints[:, ::-1] / np.array( | ||
[map_w, map_h] | ||
) # normalize keypoints to [0, 1] | ||
|
||
keypoints_message = create_keypoints_message( | ||
keypoints=keypoints, | ||
scores=scores, | ||
confidence_threshold=self.score_threshold, | ||
) | ||
keypoints_message.setTimestamp(output.getTimestamp()) | ||
|
||
self.out.send(keypoints_message) |