Skip to content

Commit

Permalink
Support Posetrack18 (#220)
Browse files Browse the repository at this point in the history
* add posetrack

* update posetrack18

* add config

* init posetrack dataset

* update posetrack18

* fix dependencies

* make pred dir

* fix dependencies

* fix docs

* fix build error

* rm dataclass

* rm unnecessary cpu/cuda mapping

* add @

* fix clone to copy

* fix return values

* add posetrack test

* use url model path

* add hrnet

* use gt bbox

* add pre-trained models

* fix build

* update modelzoo

* add tests

* add tests

* add tests

* compress the jpeg images
  • Loading branch information
jin-s13 authored Oct 30, 2020
1 parent 1436a45 commit 3ef8e6d
Show file tree
Hide file tree
Showing 25 changed files with 4,366 additions and 21 deletions.
18 changes: 18 additions & 0 deletions configs/top_down/hrnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,21 @@
| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AP (E) | AP (M) | AP (H) | ckpt | log |
| :----------------- | :-----------: | :------: | :------: | :------: | :------: | :------: |:------: |:------: | :------: |
| [pose_hrnet_w32](/configs/top_down/hrnet/crowdpose/hrnet_w32_crowdpose_256x192.py) | 256x192 | 0.675 | 0.825 | 0.729 | 0.768 | 0.687 | 0.554 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192-9b538d47_20201017.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_crowdpose_256x192_20201017.log.json) |


#### Results on PoseTrack2018 val with ground-truth bounding boxes.

| Arch | Input Size | Head | Shou | Elb | Wri | Hip | Knee | Ankl | Total | ckpt | log |
| :--- | :--------: | :------: |:------: |:------: |:------: |:------: |:------: | :------: | :------: |:------: |:------: |
| [pose_hrnet_w32](/configs/top_down/hrnet/posetrack18/hrnet_w32_posetrack18_256x192.py) | 256x192 | 87.4 | 88.6 | 84.3 | 78.5 | 79.7 | 81.8 | 78.8 | 83.0 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_posetrack18_256x192-1ee951c4_20201028.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_posetrack18_256x192_20201028.log.json) |

The models are first pre-trained on COCO dataset, and then fine-tuned on PoseTrack18.


#### Results on PoseTrack2018 val with [MMDetection](https://github.com/open-mmlab/mmdetection) pre-trained [Cascade R-CNN](http://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_20e_coco/cascade_rcnn_x101_64x4d_fpn_20e_coco_20200509_224357-051557b1.pth) (X-101-64x4d-FPN) human detector.

| Arch | Input Size | Head | Shou | Elb | Wri | Hip | Knee | Ankl | Total | ckpt | log |
| :--- | :--------: | :------: |:------: |:------: |:------: |:------: |:------: | :------: | :------: |:------: |:------: |
| [pose_hrnet_w32](/configs/top_down/hrnet/posetrack18/hrnet_w32_posetrack18_256x192.py) | 256x192 | 78.0 | 82.9 | 79.5 | 73.8 | 76.9 | 76.6 | 70.2 | 76.9 | [ckpt](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_posetrack18_256x192-1ee951c4_20201028.pth) | [log](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_posetrack18_256x192_20201028.log.json) |

The models are first pre-trained on COCO dataset, and then fine-tuned on PoseTrack18.
174 changes: 174 additions & 0 deletions configs/top_down/hrnet/posetrack18/hrnet_w32_posetrack18_256x192.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
log_level = 'INFO'
load_from = 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth' # noqa: E501
resume_from = None
dist_params = dict(backend='nccl')
workflow = [('train', 1)]
checkpoint_config = dict(interval=1)
evaluation = dict(interval=1, metric='mAP', key_indicator='Total AP')

optimizer = dict(
type='Adam',
lr=5e-4,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[10, 15])
total_epochs = 20
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])

channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])

# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256))),
),
keypoint_head=dict(
type='TopDownSimpleHead',
in_channels=32,
out_channels=channel_cfg['num_output_channels'],
num_deconv_layers=0,
extra=dict(final_conv_kernel=1, ),
),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process=True,
shift_heatmap=True,
unbiased_decoding=False,
modulate_kernel=11),
loss_pose=dict(type='JointsMSELoss', use_target_weight=True))

data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
bbox_thr=1.0,
use_gt_bbox=True,
image_thr=0.4,
bbox_file='data/posetrack18/annotations/'
'posetrack18_val_human_detections.json',
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(type='TopDownGenerateTarget', sigma=2),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]

val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=[
'img',
],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]

test_pipeline = val_pipeline

data_root = 'data/posetrack18'
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
train=dict(
type='TopDownPoseTrack18Dataset',
ann_file=f'{data_root}/annotations/posetrack18_train.json',
img_prefix=f'{data_root}/',
data_cfg=data_cfg,
pipeline=train_pipeline),
val=dict(
type='TopDownPoseTrack18Dataset',
ann_file=f'{data_root}/annotations/posetrack18_val.json',
img_prefix=f'{data_root}/',
data_cfg=data_cfg,
pipeline=val_pipeline),
test=dict(
type='TopDownPoseTrack18Dataset',
ann_file=f'{data_root}/annotations/posetrack18_val.json',
img_prefix=f'{data_root}/',
data_cfg=data_cfg,
pipeline=val_pipeline),
)
18 changes: 17 additions & 1 deletion configs/top_down/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Following the common setting, the models are trained on COCO train dataset, and
| [pose_resnet_152](/configs/top_down/resnet/mpii_trb/res152_mpii_trb_256x256.py) | 256x256 | 0.897 | 0.868 | 0.879 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res152_mpii_trb_256x256-dd369ce6_20200812.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res152_mpii_trb_256x256_20200812.log.json) |


### Results on CrowdPose test with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) human detector.
#### Results on CrowdPose test with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3) human detector.

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AP (E) | AP (M) | AP (H) | ckpt | log |
| :----------------- | :-----------: | :------: | :------: | :------: | :------: | :------: |:------: |:------: | :------: |
Expand All @@ -97,6 +97,22 @@ Following the common setting, the models are trained on COCO train dataset, and
| [pose_resnet_101](/configs/top_down/resnet/crowdpose/res101_crowdpose_320x256.py) | 320x256 | 0.659 | 0.819 | 0.713 | 0.757 | 0.669 | 0.536 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256-7723ede3_20201017.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res101_crowdpose_320x256_20201017.log.json) |
| [pose_resnet_152](/configs/top_down/resnet/crowdpose/res152_crowdpose_256x192.py) | 256x192 | 0.653 | 0.815 | 0.710 | 0.751 | 0.665 | 0.526 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192-a6507ad0_20201017.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res152_crowdpose_256x192_20201017.log.json) |

#### Results on PoseTrack2018 val with ground-truth bounding boxes.

| Arch | Input Size | Head | Shou | Elb | Wri | Hip | Knee | Ankl | Total | ckpt | log |
| :--- | :--------: | :------: |:------: |:------: |:------: |:------: |:------: | :------: | :------: |:------: |:------: |
| [pose_resnet_50](/configs/top_down/resnet/posetrack18/res50_posetrack18_256x192.py) | 256x192 | 86.5 | 87.5 | 82.3 | 75.6 | 79.9 | 78.6 | 74.0 | 81.0 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res50_posetrack18_256x192-a62807c7_20201028.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res50_posetrack18_256x192_20201028.log.json) |

The models are first pre-trained on COCO dataset, and then fine-tuned on PoseTrack18.

#### Results on PoseTrack2018 val with [MMDetection](https://github.com/open-mmlab/mmdetection) pre-trained [Cascade R-CNN](http://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_rcnn_x101_64x4d_fpn_20e_coco/cascade_rcnn_x101_64x4d_fpn_20e_coco_20200509_224357-051557b1.pth) (X-101-64x4d-FPN) human detector.

| Arch | Input Size | Head | Shou | Elb | Wri | Hip | Knee | Ankl | Total | ckpt | log |
| :--- | :--------: | :------: |:------: |:------: |:------: |:------: |:------: | :------: | :------: |:------: |:------: |
| [pose_resnet_50](/configs/top_down/resnet/posetrack18/res50_posetrack18_256x192.py) | 256x192 | 78.9 | 81.9 | 77.8 | 70.8 | 75.3 | 73.2 | 66.4 | 75.2 | [ckpt](https://download.openmmlab.com/mmpose/top_down/resnet/res50_posetrack18_256x192-a62807c7_20201028.pth) | [log](https://download.openmmlab.com/mmpose/top_down/resnet/res50_posetrack18_256x192_20201028.log.json) |

The models are first pre-trained on COCO dataset, and then fine-tuned on PoseTrack18.


### 2d Hand Pose Estimation

Expand Down
144 changes: 144 additions & 0 deletions configs/top_down/resnet/posetrack18/res50_posetrack18_256x192.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
log_level = 'INFO'
load_from = 'https://download.openmmlab.com/mmpose/top_down/resnet/res50_coco_256x192-ec54d7f3_20200709.pth' # noqa: E501
resume_from = None
dist_params = dict(backend='nccl')
workflow = [('train', 1)]
checkpoint_config = dict(interval=1)
evaluation = dict(interval=1, metric='mAP', key_indicator='Total AP')

optimizer = dict(
type='Adam',
lr=5e-4,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[10, 15])
total_epochs = 20
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])

channel_cfg = dict(
num_output_channels=17,
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])

# model settings
model = dict(
type='TopDown',
pretrained=None,
backbone=dict(type='ResNet', depth=50),
keypoint_head=dict(
type='TopDownSimpleHead',
in_channels=2048,
out_channels=channel_cfg['num_output_channels'],
),
train_cfg=dict(),
test_cfg=dict(
flip_test=True,
post_process=True,
shift_heatmap=True,
unbiased_decoding=False,
modulate_kernel=11),
loss_pose=dict(type='JointsMSELoss', use_target_weight=True))

data_cfg = dict(
image_size=[192, 256],
heatmap_size=[48, 64],
num_output_channels=channel_cfg['num_output_channels'],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
soft_nms=False,
nms_thr=1.0,
oks_thr=0.9,
vis_thr=0.2,
bbox_thr=1.0,
use_gt_bbox=True,
image_thr=0.4,
bbox_file='data/posetrack18/annotations/'
'posetrack18_val_human_detections.json',
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownRandomFlip', flip_prob=0.5),
dict(
type='TopDownHalfBodyTransform',
num_joints_half_body=8,
prob_half_body=0.3),
dict(
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(type='TopDownGenerateTarget', sigma=2),
dict(
type='Collect',
keys=['img', 'target', 'target_weight'],
meta_keys=[
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
'rotation', 'bbox_score', 'flip_pairs'
]),
]

val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='TopDownAffine'),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(
type='Collect',
keys=[
'img',
],
meta_keys=[
'image_file', 'center', 'scale', 'rotation', 'bbox_score',
'flip_pairs'
]),
]

test_pipeline = val_pipeline

data_root = 'data/posetrack18'
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
train=dict(
type='TopDownPoseTrack18Dataset',
ann_file=f'{data_root}/annotations/posetrack18_train.json',
img_prefix=f'{data_root}/',
data_cfg=data_cfg,
pipeline=train_pipeline),
val=dict(
type='TopDownPoseTrack18Dataset',
ann_file=f'{data_root}/annotations/posetrack18_val.json',
img_prefix=f'{data_root}/',
data_cfg=data_cfg,
pipeline=val_pipeline),
test=dict(
type='TopDownPoseTrack18Dataset',
ann_file=f'{data_root}/annotations/posetrack18_val.json',
img_prefix=f'{data_root}/',
data_cfg=data_cfg,
pipeline=val_pipeline),
)
Loading

0 comments on commit 3ef8e6d

Please sign in to comment.