Update to main #374

Closed
wants to merge 7 commits into from
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -26,7 +26,7 @@ concurrency:

 jobs:
   build_cuda101:
-    runs-on: ubuntu-18.04
+    runs-on: ubuntu-20.04
     strategy:
       matrix:
         python-version: [3.8]
@@ -69,7 +69,7 @@ jobs:
       - name: Install pytorch3d
         run: |
           conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y
-          conda install pytorch3d -c pytorch3d
+          pip install "git+https://github.com/facebookresearch/pytorch3d.git"
       - name: Install MMCV
         run: |
           pip install "mmcv-full>=1.3.17,<=1.5.3" -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
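The `mmcv-full>=1.3.17,<=1.5.3` pin in the MMCV step is an inclusive version range. As a rough illustration of how such a range is evaluated, the sketch below compares dotted release numbers as integer tuples. `parse_release` and `within_pin` are hypothetical helper names, and pip itself applies the fuller PEP 440 rules rather than this simplification.

```python
def parse_release(text):
    """Turn a plain dotted release string such as '1.5.3' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def within_pin(version, lower="1.3.17", upper="1.5.3"):
    """Return True when lower <= version <= upper, compared number by number."""
    return parse_release(lower) <= parse_release(version) <= parse_release(upper)


# Comparing the raw strings would rank "1.3.9" above "1.3.17";
# comparing integer tuples ranks it below, which is what the pin intends.
for candidate in ["1.3.9", "1.3.17", "1.4.8", "1.5.3", "1.6.0"]:
    print(candidate, within_pin(candidate))
```

Only plain numeric releases are handled here; pre-release or post-release suffixes would need a real version parser.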
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -8,7 +8,7 @@ concurrency:

 jobs:
   lint:
-    runs-on: ubuntu-18.04
+    runs-on: ubuntu-20.04
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python 3.8
11 changes: 7 additions & 4 deletions README.md
@@ -44,6 +44,10 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe-
 A suite of differentiable visualization tools for human parametric model rendering (including part segmentation, depth map and point clouds) and conventional 2D/3D keypoints is available.

 ## News
+- 2023-04-05: MMHuman3D [v0.11.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.11.0) is released. Major updates include:
+  - Add [ExPose](configs/expose) inference
+  - Add [PyMAF-X](configs/pymafx) inference
+  - Support [CLIFF](configs/cliff)
 - 2022-10-12: MMHuman3D [v0.10.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.10.0) is released. Major updates include:
   - Add webcam demo and real-time renderer
   - Update dataloader to speed up training
@@ -53,10 +57,6 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe-
   - Support new body model [STAR](https://star.is.tue.mpg.de/)
   - Release of [GTA-Human](https://caizhongang.github.io/projects/GTA-Human/) dataset with SPIN-FT (51.98 mm) and PARE-FT (46.84 mm) baselines! (Official)
   - Refactor registration and improve performance of SPIN to 57.54 mm
-- 2022-05-31: MMHuman3D [v0.8.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.8.0) is released. Major updates include:
-  - Support SmoothNet (added by paper authors)
-  - Fix circular import and up to 2.5x speed up in module initialization
-  - Add documentations in Chinese

 ## Benchmark and Model Zoo

@@ -92,6 +92,9 @@ Supported methods:
 - [x] [ExPose](https://expose.is.tue.mpg.de) (ECCV'2020)
 - [x] [BalancedMSE](https://sites.google.com/view/balanced-mse/home) (CVPR'2022)
 - [x] [PyMAF-X](https://www.liuyebin.com/pymaf-x/) (arXiv'2022)
+- [x] [ExPose](configs/expose) (ECCV'2020)
+- [x] [PyMAF-X](configs/pymafx) (arXiv'2022)
+- [x] [CLIFF](configs/cliff) (ECCV'2022)

 </details>

12 changes: 8 additions & 4 deletions README_CN.md
@@ -44,6 +44,10 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe-
 A full suite of differentiable visualization tools supports rendering of human parametric models (including part segmentation, depth maps, and point clouds) and visualization of conventional 2D/3D keypoints.

 ## News
+- 2023-04-05: MMHuman3D [v0.11.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.11.0) has been released. Major updates include:
+  - Add [ExPose](configs/expose) inference
+  - Add [PyMAF-X](configs/pymafx) inference
+  - Support [CLIFF](configs/cliff)
 - 2022-10-12: MMHuman3D [v0.10.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.10.0) has been released. Major updates include:
   - Support real-time rendering from a local webcam
   - Update the data loading scripts to speed up training
@@ -53,10 +57,6 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe-
   - Support the new human parametric model [STAR](https://star.is.tue.mpg.de/)
   - Open-source the [GTA-Human](https://caizhongang.github.io/projects/GTA-Human/) dataset, along with SPIN-FT (51.98 mm) and PARE-FT (46.84 mm) baselines! (Official release)
   - Refactor the registration pipeline and improve SPIN to 57.54 mm
-- 2022-05-31: MMHuman3D [v0.8.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.8.0) has been released. Major updates include:
-  - Support SmoothNet (added by the paper authors)
-  - Fix circular imports, yielding up to a 2.5x speedup
-  - Add Chinese documentation

 ## Benchmark and Model Zoo

@@ -91,6 +91,10 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe-
 - [x] [SmoothNet](https://ailingzeng.site/smoothnet) (ECCV'2022)
 - [x] [ExPose](https://expose.is.tue.mpg.de) (ECCV'2020)
 - [x] [BalancedMSE](https://sites.google.com/view/balanced-mse/home) (CVPR'2022)
+- [x] [ExPose](configs/expose) (ECCV'2020)
+- [x] [PyMAF-X](configs/pymafx) (arXiv'2022)
+- [x] [CLIFF](configs/cliff) (ECCV'2022)
+

 </details>

81 changes: 81 additions & 0 deletions configs/cliff/README.md
@@ -0,0 +1,81 @@
# CLIFF

## Introduction

We provide the config files for CLIFF: [CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation](https://arxiv.org/pdf/2208.00571.pdf).

```BibTeX

@Inproceedings{li2022cliff,
  author    = {Li, Zhihao and
               Liu, Jianzhuang and
               Zhang, Zhensong and
               Xu, Songcen and
               Yan, Youliang},
  title     = {CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation},
  booktitle = {ECCV},
  year      = {2022}
}

```

## Notes

- [SMPL](https://smpl.is.tue.mpg.de/) v1.0 is used in our experiments.
- [J_regressor_extra.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_extra.npy?versionId=CAEQHhiBgIDD6c3V6xciIGIwZDEzYWI5NTBlOTRkODU4OTE1M2Y4YTI0NTVlZGM1)
- [J_regressor_h36m.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_h36m.npy?versionId=CAEQHhiBgIDE6c3V6xciIDdjYzE3MzQ4MmU4MzQyNmRiZDA5YTg2YTI5YWFkNjRi)
- [pascal_occluders.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/pascal_occluders.npy?versionId=CAEQOhiBgMCH2fqigxgiIDY0YzRiNThkMjU1MzRjZTliMTBhZmFmYWY0MTViMTIx)
- [resnet50_a1h2_176-001a1197.pth](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth)
- [resnet50_a1h2_176-001a1197.pth (alternative download link)](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_a1h2_176-001a1197.pth)

Download the above resources and arrange them in the following file structure:

```text
mmhuman3d
├── mmhuman3d
├── docs
├── tests
├── tools
├── configs
└── data
    ├── checkpoints
    │   └── resnet50_a1h2_176-001a1197.pth
    ├── body_models
    │   ├── J_regressor_extra.npy
    │   ├── J_regressor_h36m.npy
    │   ├── smpl_mean_params.npz
    │   └── smpl
    │       ├── SMPL_FEMALE.pkl
    │       ├── SMPL_MALE.pkl
    │       └── SMPL_NEUTRAL.pkl
    ├── preprocessed_datasets
    │   ├── cliff_coco_train.npz
    │   ├── cliff_mpii_train.npz
    │   ├── h36m_mosh_train.npz
    │   ├── muco3dhp_train.npz
    │   ├── mpi_inf_3dhp_train.npz
    │   └── pw3d_test.npz
    ├── occluders
    │   └── pascal_occluders.npy
    └── datasets
        ├── coco
        ├── h36m
        ├── muco
        ├── mpi_inf_3dhp
        ├── mpii
        └── pw3d
```
```
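
Before training, it can help to confirm that the downloaded resources landed where the tree above expects them. The following is a minimal sketch using only the standard library; `REQUIRED_FILES` and `find_missing` are hypothetical names, and the list covers a subset of the files shown above rather than everything the configs may read.

```python
from pathlib import Path

# Paths copied from the directory tree above, relative to the repository root.
REQUIRED_FILES = [
    "data/checkpoints/resnet50_a1h2_176-001a1197.pth",
    "data/body_models/J_regressor_extra.npy",
    "data/body_models/J_regressor_h36m.npy",
    "data/body_models/smpl_mean_params.npz",
    "data/body_models/smpl/SMPL_FEMALE.pkl",
    "data/body_models/smpl/SMPL_MALE.pkl",
    "data/body_models/smpl/SMPL_NEUTRAL.pkl",
    "data/occluders/pascal_occluders.npy",
    "data/preprocessed_datasets/pw3d_test.npz",
]


def find_missing(repo_root, required=REQUIRED_FILES):
    """Return the required paths that do not exist as files under repo_root."""
    root = Path(repo_root)
    return [rel for rel in required if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = find_missing(".")
    for rel in missing:
        print(f"missing: {rel}")
    if not missing:
        print("layout looks complete")
```

Run it from the repository root; any path it prints still needs to be downloaded or moved into place.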

## Training
Stage 1: Train with [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py).

Stage 2: After around 150 epochs, switch to [resume.py](resume.py) and continue from the stage 1 checkpoint by passing the optional `--resume-from` argument.
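
The two stages can be sketched as a pair of command lines. This is an illustration only: it assumes the repository's `tools/train.py` entry point and a `--work-dir` option, and the work directory names and the `epoch_150.pth` checkpoint filename are hypothetical placeholders to replace with your own.

```python
import shlex


def build_train_command(config, work_dir, resume_from=None):
    """Assemble the argument list for one training stage."""
    cmd = ["python", "tools/train.py", config, "--work-dir", work_dir]
    if resume_from is not None:
        # Stage 2 continues from a checkpoint written during stage 1.
        cmd += ["--resume-from", resume_from]
    return cmd


# Hypothetical output locations; adjust to your own setup.
stage1 = build_train_command(
    "configs/cliff/resnet50_pw3d_cache.py", "work_dirs/cliff_stage1")
stage2 = build_train_command(
    "configs/cliff/resume.py",
    "work_dirs/cliff_stage2",
    resume_from="work_dirs/cliff_stage1/epoch_150.pth")

print(shlex.join(stage1))
print(shlex.join(stage2))
```

The printed lines can be pasted into a shell, or the lists can be passed to `subprocess.run` directly.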

## Results and Models

We evaluate CLIFF on 3DPW. Values are PA-MPJPE / MPJPE in mm.

| Config | 3DPW | Download |
|:---------------------------------------------------------:|:-------------:|:------:|
| Stage 1: [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py) | 48.65 / 76.49 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_cliff-8328e2e2_20230327.pth) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/20220909_142945.log) |
| Stage 2: [resume.py](resume.py) | 47.38 / 75.08 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_cliff_new-1e639f1d_20230327.pth) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/20230222_092227.log) |
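
For readers unfamiliar with the metrics in the table, MPJPE is the mean Euclidean distance between predicted and ground-truth 3D joints. The sketch below computes it in plain Python on made-up toy joints; the function name `mpjpe` and the data are illustrative, not the evaluation code used for the numbers above. PA-MPJPE applies a rigid Procrustes alignment (rotation, translation, scale) to the prediction before measuring the same distance, and that alignment step is not implemented here.

```python
import math


def mpjpe(pred_joints, gt_joints):
    """Mean Euclidean distance between matching 3D joints, in the input unit."""
    if len(pred_joints) != len(gt_joints) or not pred_joints:
        raise ValueError("need two non-empty joint lists of equal length")
    total = sum(math.dist(p, g) for p, g in zip(pred_joints, gt_joints))
    return total / len(pred_joints)


# Toy example with three joints, coordinates in millimetres.
gt = [(0.0, 0.0, 0.0), (100.0, 0.0, 0.0), (0.0, 200.0, 0.0)]
pred = [(3.0, 4.0, 0.0), (100.0, 0.0, 0.0), (0.0, 200.0, 12.0)]
print(mpjpe(pred, gt))
```

Because the alignment removes global rotation and scale errors, PA-MPJPE is normally the smaller of the two values reported for a model.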
189 changes: 189 additions & 0 deletions configs/cliff/coco.py
@@ -0,0 +1,189 @@
_base_ = ['../_base_/default_runtime.py']
use_adversarial_train = True

# evaluate
evaluation = dict(metric=['pa-mpjpe', 'mpjpe'])
# optimizer
optimizer = dict(
    backbone=dict(type='Adam', lr=1e-4),
    head=dict(type='Adam', lr=1e-4),
    # disc=dict(type='Adam', lr=1e-4)
)
optimizer_config = dict(grad_clip=2.0)
# learning policy
lr_config = dict(policy='Fixed', by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=800)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

img_resolution = (192, 256)

# model settings
model = dict(
    type='CliffImageBodyModelEstimator',
    backbone=dict(
        type='ResNet',
        depth=50,
        out_indices=[3],
        norm_eval=False,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='data/checkpoints/resnet50_a1h2_176-001a1197.pth')),
    head=dict(
        type='CliffHead',
        feat_dim=2048,
        smpl_mean_params='data/body_models/smpl_mean_params.npz'),
    body_model_train=dict(
        type='SMPL',
        keypoint_src='smpl_54',
        keypoint_dst='smpl_54',
        model_path='data/body_models/smpl',
        keypoint_approximate=True,
        extra_joints_regressor='data/body_models/J_regressor_extra.npy'),
    body_model_test=dict(
        type='SMPL',
        keypoint_src='h36m',
        keypoint_dst='h36m',
        model_path='data/body_models/smpl',
        joints_regressor='data/body_models/J_regressor_h36m.npy'),
    convention='smpl_54',
    loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100),
    loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10),
    loss_vertex=dict(type='L1Loss', loss_weight=2),
    loss_smpl_pose=dict(type='MSELoss', loss_weight=3),
    loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02),
    loss_adv=dict(
        type='GANLoss',
        gan_type='lsgan',
        real_label_val=1.0,
        fake_label_val=0.0,
        loss_weight=1),
    # disc=dict(type='SMPLDiscriminator')
)
# dataset settings
dataset_type = 'HumanImageDataset'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data_keys = [
    'has_smpl',
    'smpl_body_pose',
    'smpl_global_orient',
    'smpl_betas',
    'smpl_transl',
    'keypoints2d',
    'keypoints3d',
    'sample_idx',
    'img_h',  # extras for cliff
    'img_w',
    'focal_length',
    'center',
    'scale',
    'bbox_info',
    'crop_trans',
    'inv_trans'
]
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomChannelNoise', noise_factor=0.4),
    dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_54'),
    dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25),
    dict(type='GetBboxInfo'),
    dict(type='MeshAffine', img_res=img_resolution),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=data_keys),
    dict(
        type='Collect',
        keys=['img', *data_keys],
        meta_keys=['image_path', 'center', 'scale', 'rotation'])
]
adv_data_keys = [
    'smpl_body_pose', 'smpl_global_orient', 'smpl_betas', 'smpl_transl'
]
train_adv_pipeline = [dict(type='Collect', keys=adv_data_keys, meta_keys=[])]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0),
    dict(type='GetBboxInfo'),
    dict(type='MeshAffine', img_res=img_resolution),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=data_keys),
    dict(
        type='Collect',
        keys=['img', *data_keys],
        meta_keys=[
            'image_path', 'center', 'scale', 'rotation', 'img_h', 'img_w',
            'bbox_info'
        ])
]

inference_pipeline = [
    dict(type='MeshAffine', img_res=img_resolution),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(
        type='Collect',
        keys=['img', 'sample_idx'],
        meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

cache_files = {
    'cliff_coco': 'data/cache/cliff_coco_train_smpl_54.npz',
}
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    train=dict(
        type='AdversarialDataset',
        train_dataset=dict(
            type='MixedDataset',
            configs=[
                dict(
                    type=dataset_type,
                    dataset_name='coco',
                    data_prefix='data',
                    pipeline=train_pipeline,
                    convention='smpl_54',
                    cache_data_path=cache_files['cliff_coco'],
                    ann_file='cliff_coco_train.npz'),
            ],
            partition=[1.0],
        ),
        adv_dataset=dict(
            type='MeshDataset',
            dataset_name='cmu_mosh',
            data_prefix='data',
            pipeline=train_adv_pipeline,
            ann_file='cmu_mosh.npz')),
    val=dict(
        type=dataset_type,
        body_model=dict(
            type='GenderedSMPL',
            keypoint_src='h36m',
            keypoint_dst='h36m',
            model_path='data/body_models/smpl',
            joints_regressor='data/body_models/J_regressor_h36m.npy'),
        dataset_name='pw3d',
        data_prefix='data',
        pipeline=test_pipeline,
        ann_file='pw3d_test.npz'),
    test=dict(
        type=dataset_type,
        body_model=dict(
            type='GenderedSMPL',
            keypoint_src='h36m',
            keypoint_dst='h36m',
            model_path='data/body_models/smpl',
            joints_regressor='data/body_models/J_regressor_h36m.npy'),
        dataset_name='pw3d',
        data_prefix='data',
        pipeline=test_pipeline,
        ann_file='pw3d_test.npz'),
)
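
The config above threads `img_h`, `img_w`, `focal_length`, `center`, `scale` and `bbox_info` through the pipelines as "extras for cliff". As a rough guide to what `bbox_info` carries, the sketch below follows the formulation described in the CLIFF paper: the box centre measured from the image centre, and the box size, each divided by the focal length, with the image diagonal standing in when the focal length is unknown. `cliff_bbox_info` is a hypothetical helper, not the `GetBboxInfo` transform from this PR, whose code is not shown here and which may apply further normalization.

```python
import math


def cliff_bbox_info(center, box_size, img_w, img_h, focal_length=None):
    """Return [cx / f, cy / f, b / f], with (cx, cy) measured from the image centre."""
    if focal_length is None:
        # Assumed fallback when camera intrinsics are unknown: the image diagonal.
        focal_length = math.hypot(img_w, img_h)
    cx = center[0] - img_w / 2.0
    cy = center[1] - img_h / 2.0
    return [cx / focal_length, cy / focal_length, box_size / focal_length]


# A 200-pixel box centred at (400, 300) in a 640x480 image.
info = cliff_bbox_info(center=(400.0, 300.0), box_size=200.0, img_w=640, img_h=480)
print(info)
```

Dividing by the focal length makes the three values independent of image resolution, so the same person at the same place in the frame yields the same input at any image size.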