Skip to content

Commit

Permalink
Storage order fix in XFeat parser. (#153)
Browse files Browse the repository at this point in the history
* Storage order fix in XFeat parser.

* Docs update.

* Pre-commit fix.
  • Loading branch information
kkeroo authored Dec 17, 2024
1 parent 2797776 commit 97ef45a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ repos:
hooks:
- id: mdformat
additional_dependencies:
- mdformat-gfm
- mdformat-gfm==0.3.6
- mdformat-toc
30 changes: 15 additions & 15 deletions depthai_nodes/ml/parsers/xfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,29 +199,29 @@ def extractTensors(
self, output: dai.NNData
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extracts the tensors from the output. It returns the features, keypoints, and
heatmaps. It also handles the reshaping of the tensors.
heatmaps. It also handles the reshaping of the tensors by requesting the NCHW
storage order.
@param output: Output from the Neural Network node.
@type output: dai.NNData
@return: Tuple of features, keypoints, and heatmaps.
@rtype: Tuple[np.ndarray, np.ndarray, np.ndarray]
"""
feats = output.getTensor(self.output_layer_feats, dequantize=True).astype(
np.float32
)
feats = output.getTensor(
self.output_layer_feats,
dequantize=True,
storageOrder=dai.TensorInfo.StorageOrder.NCHW,
).astype(np.float32)
keypoints = output.getTensor(
self.output_layer_keypoints, dequantize=True
self.output_layer_keypoints,
dequantize=True,
storageOrder=dai.TensorInfo.StorageOrder.NCHW,
).astype(np.float32)
heatmaps = output.getTensor(
self.output_layer_heatmaps,
dequantize=True,
storageOrder=dai.TensorInfo.StorageOrder.NCHW,
).astype(np.float32)
heatmaps = output.getTensor(self.output_layer_heatmaps, dequantize=True).astype(
np.float32
)

if len(feats.shape) == 3:
feats = feats.reshape((1,) + feats.shape).transpose(0, 3, 1, 2)
if len(keypoints.shape) == 3:
keypoints = keypoints.reshape((1,) + keypoints.shape).transpose(0, 3, 1, 2)
if len(heatmaps.shape) == 3:
heatmaps = heatmaps.reshape((1,) + heatmaps.shape).transpose(0, 3, 1, 2)

return feats, keypoints, heatmaps

Expand Down

0 comments on commit 97ef45a

Please sign in to comment.