-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor] Add topdown pose estimator and keypoint codecs (#1493)
* add BasePoseEstimator class * add base topdown modules * add heatmap head * handle compatibility to old version state dict * add codec interface and registry * add docstrings * add codecs * codec use [N,K,C]/[K,C,H,W] as keypoints/heatmaps shape * refactor pipeline and heatmap head based on codec * raise error if gt_instances is missing * Rename classfication_loss.py to classification_loss.py fix a typo: classfication -> classification * organize tests/ * add unittest * remove deprecated unittest * add codec unittest * resolve comments * update docstrings * organize tests * add error test in ut * add test_legacy * reorganize tests/ * fix unittest Co-authored-by: Tau <[email protected]>
- Loading branch information
Showing
392 changed files
with
6,392 additions
and
3,974 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .heatmap import (generate_megvii_heatmap, generate_msra_heatmap, | ||
generate_udp_heatmap) | ||
from .codecs import MegviiHeatmap, MSRAHeatmap, UDPHeatmap | ||
from .transforms import flip_keypoints | ||
|
||
__all__ = [ | ||
'flip_keypoints', 'generate_megvii_heatmap', 'generate_msra_heatmap', | ||
'generate_udp_heatmap' | ||
] | ||
__all__ = ['flip_keypoints', 'MegviiHeatmap', 'MSRAHeatmap', 'UDPHeatmap'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .megvii_heatmap import MegviiHeatmap | ||
from .msra_heatmap import MSRAHeatmap | ||
from .udp_heatmap import UDPHeatmap | ||
|
||
__all__ = ['MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import ABCMeta, abstractmethod | ||
from typing import Any, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
|
||
class BaseKeypointCodec(metaclass=ABCMeta):
    """The base class of the keypoint codec.

    A keypoint codec is a module to encode keypoint coordinates to specific
    representation (e.g. heatmap) and vice versa. A subclass should implement
    the methods :meth:`encode` and :meth:`decode`.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: C
    """

    @abstractmethod
    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> Any:
        """Encode keypoints.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, C)
            keypoints_visible (np.ndarray, optional): Keypoint visibility in
                shape (N, K), i.e. one flag per keypoint rather than one per
                coordinate. Defaults to ``None``; how a missing visibility
                array is handled is up to the subclass

        Returns:
            Any: The encoded representation. Its type and layout are defined
            by the concrete codec
        """

    @abstractmethod
    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoints.

        Args:
            encoded (any): Encoded keypoint representation using the codec

        Returns:
            tuple:
            - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, C)
            - keypoints_visible (np.ndarray): Per-keypoint visibility (a
                codec may return a confidence score in its place) in shape
                (N, K)
        """

    def keypoints_bbox2img(self, keypoints: np.ndarray,
                           bbox_centers: np.ndarray,
                           bbox_scales: np.ndarray) -> np.ndarray:
        """Convert decoded keypoints from the bbox space to the image space.

        Topdown codecs should override this method. The base implementation
        only signals that the conversion is not available.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, C).
                The coordinate is in the bbox space
            bbox_centers (np.ndarray): Bbox centers in shape (N, 2).
                See `pipelines.GetBboxCenterScale` for details
            bbox_scales (np.ndarray): Bbox scales in shape (N, 2).
                See `pipelines.GetBboxCenterScale` for details

        Returns:
            np.ndarray: The transformed keypoints in shape (N, K, C).
            The coordinate is in the image space.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        # Name the concrete codec so a missing override is easy to track down.
        raise NotImplementedError(
            f'{self.__class__.__name__} does not implement '
            'keypoints_bbox2img')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from itertools import product | ||
from typing import Optional, Tuple | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
from mmpose.registry import KEYPOINT_CODECS | ||
from .base import BaseKeypointCodec | ||
from .utils import gaussian_blur, get_heatmap_maximum | ||
|
||
|
||
@KEYPOINT_CODECS.register_module()
class MegviiHeatmap(BaseKeypointCodec):
    """Represent keypoints as heatmaps via "Megvii" approach. See `MSPN`_
    (2019) and `CPN`_ (2018) for details.

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: C
        - image size: [w, h]
        - heatmap size: [W, H]

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        kernel_size (int): The side length of the square gaussian kernel.
            The same value is used for both axes, i.e. the kernel is
            ``(kernel_size, kernel_size)``. NOTE(review): it is passed to
            ``cv2.GaussianBlur``, which is documented to require a positive
            odd size - confirm callers respect that

    .. _`MSPN`: https://arxiv.org/abs/1901.00148
    .. _`CPN`: https://arxiv.org/abs/1711.07319
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        kernel_size: int,
    ) -> None:

        super().__init__()
        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.kernel_size = kernel_size
        # Per-axis ratio between the input image and the heatmap, used to
        # map coordinates between the two spaces.
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

    def encode(
        self,
        keypoints: np.ndarray,
        keypoints_visible: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Encode keypoints into heatmaps. Note that the original keypoint
        coordinates should be in the input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, C).
                Only single-instance input (N == 1) is supported
            keypoints_visible (np.ndarray, optional): Keypoint visibilities
                in shape (N, K). If ``None``, all keypoints are treated as
                visible. Defaults to ``None``

        Returns:
            tuple:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (K, H, W) where [W, H] is the `heatmap_size`
            - keypoint_weights (np.ndarray): The target weights in shape
                (N, K)

        Raises:
            AssertionError: If more than one instance is given.
        """

        N, K, _ = keypoints.shape
        W, H = self.heatmap_size

        assert N == 1, (
            f'{self.__class__.__name__} only support single-instance '
            'keypoint encoding')

        # The argument is optional, so fall back to "everything visible"
        # instead of failing on ``None.copy()`` below.
        if keypoints_visible is None:
            keypoints_visible = np.ones((N, K), dtype=np.float32)

        heatmaps = np.zeros((K, H, W), dtype=np.float32)
        keypoint_weights = keypoints_visible.copy()
        # Loop-invariant: the same square kernel is used for every keypoint.
        kernel_size = (self.kernel_size, self.kernel_size)

        for n, k in product(range(N), range(K)):
            # skip unlabeled keypoints
            if keypoints_visible[n, k] < 0.5:
                continue

            # get center coordinates (image space -> heatmap cell index)
            kx, ky = (keypoints[n, k] / self.scale_factor).astype(np.int64)
            if kx < 0 or kx >= W or ky < 0 or ky >= H:
                # The keypoint falls outside the heatmap: leave its heatmap
                # empty and exclude it via a zero weight.
                keypoint_weights[n, k] = 0
                continue

            # Place a unit impulse at the keypoint and spread it out.
            heatmaps[k, ky, kx] = 1.
            heatmaps[k] = cv2.GaussianBlur(heatmaps[k], kernel_size, 0)

            # normalize the heatmap so that its value at the keypoint is 255
            heatmaps[k] = heatmaps[k] / heatmaps[k, ky, kx] * 255.

        return heatmaps, keypoint_weights

    def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Decode keypoint coordinates from heatmaps. The decoded keypoint
        coordinates are in the input image space.

        Args:
            encoded (np.ndarray): Heatmaps in shape (K, H, W)

        Returns:
            tuple:
            - keypoints (np.ndarray): Decoded keypoint coordinates with a
                leading single-instance axis, i.e. in shape (1, K, C)
            - scores (np.ndarray): The keypoint scores with a leading
                single-instance axis, i.e. in shape (1, K). It usually
                represents the confidence of the keypoint prediction
        """
        # Work on a copy so that the caller's heatmaps stay untouched.
        heatmaps = gaussian_blur(encoded.copy(), self.kernel_size)
        K, H, W = heatmaps.shape

        # NOTE(review): presumably returns per-keypoint (x, y) locations of
        # the heatmap maxima and their values - confirm against the helper.
        keypoints, scores = get_heatmap_maximum(heatmaps)

        for k in range(K):
            heatmap = heatmaps[k]
            px = int(keypoints[k, 0])
            py = int(keypoints[k, 1])
            # Refine only where all four neighbours lie inside the heatmap.
            if 1 < px < W - 1 and 1 < py < H - 1:
                diff = np.array([
                    heatmap[py][px + 1] - heatmap[py][px - 1],
                    heatmap[py + 1][px] - heatmap[py - 1][px]
                ])
                # Shift by a quarter pixel towards the larger neighbour,
                # plus a constant half-pixel offset.
                keypoints[k] += (np.sign(diff) * 0.25 + 0.5)

        scores = scores / 255.0 + 0.5

        # Unsqueeze the instance dimension for single-instance results
        # and restore the keypoint scales
        keypoints = keypoints[None] * self.scale_factor
        scores = scores[None]

        return keypoints, scores
Oops, something went wrong.