From 7996ee5b90735a6e620a741f3b2f8f2722bce7a6 Mon Sep 17 00:00:00 2001 From: oneScotch <71915686+oneScotch@users.noreply.github.com> Date: Wed, 5 Apr 2023 06:26:19 +0200 Subject: [PATCH 1/7] [Add] CLIFF (#302) * add cliff head * add function to convert from crop to full camera * add cliff annotation datasets converter * add tramsforms to get bbox information * store crop trans * cliff mesh estimator * modification to take in different resolutions * add configs * add missing comma * format correction * isort formating * correct error in cliff_head * revert unnecessary changes in cliff_head * add configs(single dataset) and small modification * configs format modification * add test for cliff head * format correction * update test file * format correction * update test file * format correction * update test file * format correction * docformatter correction * update test file * format * add README * add README * add test for cliff data converter * add test for cliff mesh estimator * update tests * merge cliff mesh estimator to mesh estimator * revert unnecessary tests * format * Revert to CliffMeshEstimator * Fix wrong class name in test * Fix linter * Fix bugs for test architecture * Fix test_data_converters.py * Update download links * Update pytorch3d install in workflow * Format * Add additional tests * Update to ubuntu-20.04 * Update to ubuntu-20.04 * Fix pickle * Fix setup.cfg * Fix setup.cfg * Change pickle5 to pickle * Fix pandas version --------- Co-authored-by: caizhongang Co-authored-by: caizhongang --- .github/workflows/build.yml | 4 +- .github/workflows/lint.yml | 2 +- configs/cliff/README.md | 81 ++ configs/cliff/coco.py | 189 ++++ configs/cliff/resnet50_pw3d_cache.py | 225 +++++ configs/cliff/resume.py | 228 +++++ mmhuman3d/data/data_converters/__init__.py | 3 +- mmhuman3d/data/data_converters/cliff.py | 121 +++ mmhuman3d/data/datasets/pipelines/__init__.py | 35 +- .../data/datasets/pipelines/transforms.py | 40 +- mmhuman3d/models/architectures/builder.py | 3 + .../architectures/cliff_mesh_estimator.py | 881 ++++++++++++++++++ mmhuman3d/models/heads/builder.py | 2 + mmhuman3d/models/heads/cliff_head.py | 98 ++ mmhuman3d/utils/geometry.py | 21 + requirements/runtime.txt | 2 +- setup.cfg | 2 +- tests/test_data_converters.py | 8 + tests/test_datasets/test_pipelines.py | 15 + .../test_cliff_mesh_estimator.py | 417 +++++++++ .../test_models/test_heads/test_cliff_head.py | 59 ++ tools/convert_datasets.py | 2 +- 22 files changed, 2402 insertions(+), 36 deletions(-) create mode 100644 configs/cliff/README.md create mode 100644 configs/cliff/coco.py create mode 100644 configs/cliff/resnet50_pw3d_cache.py create mode 100644 configs/cliff/resume.py create mode 100644 mmhuman3d/data/data_converters/cliff.py create mode 100644 mmhuman3d/models/architectures/cliff_mesh_estimator.py create mode 100644 mmhuman3d/models/heads/cliff_head.py create mode 100644 tests/test_models/test_architectures/test_cliff_mesh_estimator.py create mode 100644 tests/test_models/test_heads/test_cliff_head.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7fd2b25f..76a50c2c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -26,7 +26,7 @@ concurrency: jobs: build_cuda101: - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 strategy: matrix: python-version: [3.8] @@ -69,7 +69,7 @@ jobs: - name: Install pytorch3d run: | conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y - conda install pytorch3d -c pytorch3d + pip install 
"git+https://github.com/facebookresearch/pytorch3d.git" - name: Install MMCV run: | pip install "mmcv-full>=1.3.17,<=1.5.3" -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 822bdcd9..2208335d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -8,7 +8,7 @@ concurrency: jobs: lint: - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: Set up Python 3.8 diff --git a/configs/cliff/README.md b/configs/cliff/README.md new file mode 100644 index 00000000..8f536fa6 --- /dev/null +++ b/configs/cliff/README.md @@ -0,0 +1,81 @@ +# CLIFF + +## Introduction + +We provide the config files for CLIFF: [CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation](https://arxiv.org/pdf/2208.00571.pdf). + +```BibTeX + +@Inproceedings{li2022cliff, + author = {Li, Zhihao and + Liu, Jianzhuang and + Zhang, Zhensong and + Xu, Songcen and + Yan, Youliang}, + title = {CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation}, + booktitle = {ECCV}, + year = {2022} +} + +``` + +## Notes + +- [SMPL](https://smpl.is.tue.mpg.de/) v1.0 is used in our experiments. +- [J_regressor_extra.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_extra.npy?versionId=CAEQHhiBgIDD6c3V6xciIGIwZDEzYWI5NTBlOTRkODU4OTE1M2Y4YTI0NTVlZGM1) +- [J_regressor_h36m.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_h36m.npy?versionId=CAEQHhiBgIDE6c3V6xciIDdjYzE3MzQ4MmU4MzQyNmRiZDA5YTg2YTI5YWFkNjRi) +- [pascal_occluders.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/pascal_occluders.npy?versionId=CAEQOhiBgMCH2fqigxgiIDY0YzRiNThkMjU1MzRjZTliMTBhZmFmYWY0MTViMTIx) +- [resnet50_a1h2_176-001a1197.pth](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth) +- [resnet50_a1h2_176-001a1197.pth(alternative download link)](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_a1h2_176-001a1197.pth) + +Download the above resources and arrange them in the following file structure: + +```text +mmhuman3d +├── mmhuman3d +├── docs +├── tests +├── tools +├── configs +└── data + ├── checkpoints + │ ├── resnet50_a1h2_176-001a1197.pth + ├── body_models + │ ├── J_regressor_extra.npy + │ ├── J_regressor_h36m.npy + │ ├── smpl_mean_params.npz + │ └── smpl + │ ├── SMPL_FEMALE.pkl + │ ├── SMPL_MALE.pkl + │ └── SMPL_NEUTRAL.pkl + ├── preprocessed_datasets + │ ├── cliff_coco_train.npz + │ ├── cliff_mpii_train.npz + │ ├── h36m_mosh_train.npz + │ ├── muco3dhp_train.npz + │ ├── mpi_inf_3dhp_train.npz + │ └── pw3d_test.npz + ├── occluders + │ ├── pascal_occluders.npy + └── datasets + ├── coco + ├── h36m + ├── muco + ├── mpi_inf_3dhp + ├── mpii + └── pw3d +``` + +## Training +Stage 1: First use [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py) to train. + +Stage 2: After around 150 epoches, switch to [resume.py](resume.py) by using "--resume-from" optional argument. + +## Results and Models + +We evaluate HMR on 3DPW. Values are MPJPE/PA-MPJPE. 
+ +| Config | 3DPW | Download | +|:---------------------------------------------------------:|:-------------:|:------:| +| Stage 1: [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py) | 48.65 / 76.49 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_cliff-8328e2e2_20230327.pth) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/20220909_142945.log) +| Stage 2: [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py) | 47.38 / 75.08 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_cliff_new-1e639f1d_20230327.pth) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/20230222_092227.log) diff --git a/configs/cliff/coco.py b/configs/cliff/coco.py new file mode 100644 index 00000000..6651909e --- /dev/null +++ b/configs/cliff/coco.py @@ -0,0 +1,189 @@ +_base_ = ['../_base_/default_runtime.py'] +use_adversarial_train = True + +# evaluate +evaluation = dict(metric=['pa-mpjpe', 'mpjpe']) +# optimizer +optimizer = dict( + backbone=dict(type='Adam', lr=1e-4), + head=dict(type='Adam', lr=1e-4), + # disc=dict(type='Adam', lr=1e-4) +) +optimizer_config = dict(grad_clip=2.0) +# learning policy +lr_config = dict(policy='Fixed', by_epoch=False) +runner = dict(type='EpochBasedRunner', max_epochs=800) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +img_resolution = (192, 256) + +# model settings +model = dict( + type='CliffImageBodyModelEstimator', + backbone=dict( + type='ResNet', + depth=50, + out_indices=[3], + norm_eval=False, + norm_cfg=dict(type='SyncBN', requires_grad=True), + init_cfg=dict( + type='Pretrained', + checkpoint='data/checkpoints/resnet50_a1h2_176-001a1197.pth')), + head=dict( + type='CliffHead', + feat_dim=2048, + smpl_mean_params='data/body_models/smpl_mean_params.npz'), + body_model_train=dict( + type='SMPL', + keypoint_src='smpl_54', + keypoint_dst='smpl_54', + model_path='data/body_models/smpl', + keypoint_approximate=True, + extra_joints_regressor='data/body_models/J_regressor_extra.npy'), + body_model_test=dict( + type='SMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + convention='smpl_54', + loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100), + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10), + loss_vertex=dict(type='L1Loss', loss_weight=2), + loss_smpl_pose=dict(type='MSELoss', loss_weight=3), + loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02), + loss_adv=dict( + type='GANLoss', + gan_type='lsgan', + real_label_val=1.0, + fake_label_val=0.0, + loss_weight=1), + # disc=dict(type='SMPLDiscriminator') +) +# dataset settings +dataset_type = 'HumanImageDataset' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data_keys = [ + 'has_smpl', + 'smpl_body_pose', + 'smpl_global_orient', + 'smpl_betas', + 'smpl_transl', + 'keypoints2d', + 'keypoints3d', + 'sample_idx', + 'img_h', # extras for cliff + 'img_w', + 'focal_length', + 'center', + 'scale', + 'bbox_info', + 'crop_trans', + 'inv_trans' +] +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomChannelNoise', noise_factor=0.4), + dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_54'), + dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', 
img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] +adv_data_keys = [ + 'smpl_body_pose', 'smpl_global_orient', 'smpl_betas', 'smpl_transl' +] +train_adv_pipeline = [dict(type='Collect', keys=adv_data_keys, meta_keys=[])] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=[ + 'image_path', 'center', 'scale', 'rotation', 'img_h', 'img_w', + 'bbox_info' + ]) +] + +inference_pipeline = [ + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict( + type='Collect', + keys=['img', 'sample_idx'], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] + +cache_files = { + 'cliff_coco': 'data/cache/cliff_coco_train_smpl_54.npz', +} +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='AdversarialDataset', + train_dataset=dict( + type='MixedDataset', + configs=[ + dict( + type=dataset_type, + dataset_name='coco', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['cliff_coco'], + ann_file='cliff_coco_train.npz'), + ], + partition=[1.0], + ), + adv_dataset=dict( + type='MeshDataset', + dataset_name='cmu_mosh', + data_prefix='data', + pipeline=train_adv_pipeline, + ann_file='cmu_mosh.npz')), + val=dict( + type=dataset_type, + body_model=dict( + type='GenderedSMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + dataset_name='pw3d', + data_prefix='data', + pipeline=test_pipeline, + ann_file='pw3d_test.npz'), + test=dict( + type=dataset_type, + body_model=dict( + type='GenderedSMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + dataset_name='pw3d', + data_prefix='data', + pipeline=test_pipeline, + ann_file='pw3d_test.npz'), +) diff --git a/configs/cliff/resnet50_pw3d_cache.py b/configs/cliff/resnet50_pw3d_cache.py new file mode 100644 index 00000000..d0d8becd --- /dev/null +++ b/configs/cliff/resnet50_pw3d_cache.py @@ -0,0 +1,225 @@ +_base_ = ['../_base_/default_runtime.py'] +use_adversarial_train = True + +# evaluate +evaluation = dict(metric=['pa-mpjpe', 'mpjpe']) +# optimizer +optimizer = dict( + backbone=dict(type='Adam', lr=3e-4), + head=dict(type='Adam', lr=3e-4), + # disc=dict(type='Adam', lr=1e-4) +) +optimizer_config = dict(grad_clip=2.0) +# learning policy +lr_config = dict(policy='step', gamma=0.1, step=[100]) +runner = dict(type='EpochBasedRunner', max_epochs=250) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +img_resolution = (192, 256) + +# model settings +model = dict( + type='CliffImageBodyModelEstimator', + backbone=dict( + type='ResNet', + depth=50, + out_indices=[3], + norm_eval=False, + norm_cfg=dict(type='SyncBN', requires_grad=True), + init_cfg=dict( + type='Pretrained', + 
checkpoint='data/checkpoints/resnet50_a1h2_176-001a1197.pth')), + head=dict( + type='CliffHead', + feat_dim=2048, + smpl_mean_params='data/body_models/smpl_mean_params.npz'), + body_model_train=dict( + type='SMPL', + keypoint_src='smpl_54', + keypoint_dst='smpl_54', + model_path='data/body_models/smpl', + keypoint_approximate=True, + extra_joints_regressor='data/body_models/J_regressor_extra.npy'), + body_model_test=dict( + type='SMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + convention='smpl_54', + loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100), + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10), + loss_vertex=dict(type='L1Loss', loss_weight=2), + loss_smpl_pose=dict(type='MSELoss', loss_weight=3), + loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02), + loss_adv=dict( + type='GANLoss', + gan_type='lsgan', + real_label_val=1.0, + fake_label_val=0.0, + loss_weight=1), + # disc=dict(type='SMPLDiscriminator') +) +# dataset settings +dataset_type = 'HumanImageDataset' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data_keys = [ + 'has_smpl', + 'smpl_body_pose', + 'smpl_global_orient', + 'smpl_betas', + 'smpl_transl', + 'keypoints2d', + 'keypoints3d', + 'sample_idx', + 'img_h', # extras for cliff + 'img_w', + 'focal_length', + 'center', + 'scale', + 'bbox_info', + 'crop_trans', + 'inv_trans' +] +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomChannelNoise', noise_factor=0.4), + dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_54'), + dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] +adv_data_keys = [ + 'smpl_body_pose', 'smpl_global_orient', 'smpl_betas', 'smpl_transl' +] +train_adv_pipeline = [dict(type='Collect', keys=adv_data_keys, meta_keys=[])] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=[ + 'image_path', 'center', 'scale', 'rotation', 'img_h', 'img_w', + 'bbox_info' + ]) +] + +inference_pipeline = [ + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict( + type='Collect', + keys=['img', 'sample_idx'], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] + +cache_files = { + 'h36m': 'data/cache/h36m_mosh_train_smpl_54.npz', + 'mpi_inf_3dhp': 'data/cache/mpi_inf_3dhp_train_smpl_54.npz', + 'cliff_coco': 'data/cache/cliff_coco_train_smpl_54.npz', + 'cliff_mpii': 'data/cache/cliff_mpii_train_smpl_54.npz', + 'pw3d': 'data/cache/pw3d_train_smpl_54.npz', +} +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='AdversarialDataset', + train_dataset=dict( + type='MixedDataset', + configs=[ + dict( + type=dataset_type, + dataset_name='h36m', + data_prefix='data', + 
pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['h36m'], + ann_file='h36m_mosh_train.npz'), + dict( + type=dataset_type, + dataset_name='mpi_inf_3dhp', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['mpi_inf_3dhp'], + ann_file='mpi_inf_3dhp_train.npz'), + dict( + type=dataset_type, + dataset_name='mpii', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['cliff_mpii'], + ann_file='cliff_mpii_train.npz'), + dict( + type=dataset_type, + dataset_name='coco', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['cliff_coco'], + ann_file='cliff_coco_train.npz'), + dict( + type=dataset_type, + dataset_name='pw3d', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['pw3d'], + ann_file='pw3d_train.npz'), + ], + partition=[0.4, 0.1, 0.1, 0.2, 0.2], + ), + adv_dataset=dict( + type='MeshDataset', + dataset_name='cmu_mosh', + data_prefix='data', + pipeline=train_adv_pipeline, + ann_file='cmu_mosh.npz')), + val=dict( + type=dataset_type, + body_model=dict( + type='GenderedSMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + dataset_name='pw3d', + data_prefix='data', + pipeline=test_pipeline, + ann_file='pw3d_test.npz'), + test=dict( + type=dataset_type, + body_model=dict( + type='GenderedSMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + dataset_name='pw3d', + data_prefix='data', + pipeline=test_pipeline, + ann_file='pw3d_test.npz'), +) diff --git a/configs/cliff/resume.py b/configs/cliff/resume.py new file mode 100644 index 00000000..652de149 --- /dev/null +++ b/configs/cliff/resume.py @@ -0,0 +1,228 @@ +_base_ = ['../_base_/default_runtime.py'] +use_adversarial_train = True + +# evaluate +evaluation = dict(metric=['pa-mpjpe', 'mpjpe']) +# optimizer +optimizer = dict( + backbone=dict(type='Adam', lr=3e-4), + head=dict(type='Adam', lr=3e-4), + # disc=dict(type='Adam', lr=1e-4) +) +optimizer_config = dict(grad_clip=2.0) +# learning policy +lr_config = dict(policy='step', gamma=0.1, step=[100]) +runner = dict(type='EpochBasedRunner', max_epochs=160) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +img_resolution = (192, 256) + +# model settings +model = dict( + type='CliffImageBodyModelEstimator', + backbone=dict( + type='ResNet', + depth=50, + out_indices=[3], + norm_eval=False, + norm_cfg=dict(type='SyncBN', requires_grad=True), + init_cfg=dict( + type='Pretrained', + checkpoint='data/checkpoints/resnet50_a1h2_176-001a1197.pth')), + head=dict( + type='CliffHead', + feat_dim=2048, + smpl_mean_params='data/body_models/smpl_mean_params.npz'), + body_model_train=dict( + type='SMPL', + keypoint_src='smpl_54', + keypoint_dst='smpl_54', + model_path='data/body_models/smpl', + keypoint_approximate=True, + extra_joints_regressor='data/body_models/J_regressor_extra.npy'), + body_model_test=dict( + type='SMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + convention='smpl_54', + loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100), + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10), + 
loss_vertex=dict(type='L1Loss', loss_weight=2), + loss_smpl_pose=dict(type='MSELoss', loss_weight=3), + loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02), + loss_adv=dict( + type='GANLoss', + gan_type='lsgan', + real_label_val=1.0, + fake_label_val=0.0, + loss_weight=1), + # disc=dict(type='SMPLDiscriminator') +) +# dataset settings +dataset_type = 'HumanImageDataset' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data_keys = [ + 'has_smpl', + 'smpl_body_pose', + 'smpl_global_orient', + 'smpl_betas', + 'smpl_transl', + 'keypoints2d', + 'keypoints3d', + 'sample_idx', + 'img_h', # extras for cliff + 'img_w', + 'focal_length', + 'center', + 'scale', + 'bbox_info', + 'crop_trans', + 'inv_trans' +] +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomChannelNoise', noise_factor=0.4), + dict( + type='SyntheticOcclusion', + occluders_file='data/occluders/pascal_occluders.npy'), + dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_54'), + dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] +adv_data_keys = [ + 'smpl_body_pose', 'smpl_global_orient', 'smpl_betas', 'smpl_transl' +] +train_adv_pipeline = [dict(type='Collect', keys=adv_data_keys, meta_keys=[])] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0), + dict(type='GetBboxInfo'), + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=[ + 'image_path', 'center', 'scale', 'rotation', 'img_h', 'img_w', + 'bbox_info' + ]) +] + +inference_pipeline = [ + dict(type='MeshAffine', img_res=img_resolution), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict( + type='Collect', + keys=['img', 'sample_idx'], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] + +cache_files = { + 'h36m': 'data/cache/h36m_mosh_train_smpl_54.npz', + 'muco': 'data/cache/muco3dhp_train.npz', + 'cliff_coco': 'data/cache/cliff_coco_train_smpl_54.npz', + 'cliff_mpii': 'data/cache/cliff_mpii_train_smpl_54.npz', + 'pw3d': 'data/cache/pw3d_train_smpl_54.npz', +} +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='AdversarialDataset', + train_dataset=dict( + type='MixedDataset', + configs=[ + dict( + type=dataset_type, + dataset_name='h36m', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['h36m'], + ann_file='h36m_mosh_train.npz'), + dict( + type=dataset_type, + dataset_name='muco', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['muco'], + ann_file='muco3dhp_train.npz'), + dict( + type=dataset_type, + dataset_name='mpii', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['cliff_mpii'], + ann_file='cliff_mpii_train.npz'), + dict( + type=dataset_type, + dataset_name='coco', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + 
cache_data_path=cache_files['cliff_coco'], + ann_file='cliff_coco_train.npz'), + dict( + type=dataset_type, + dataset_name='pw3d', + data_prefix='data', + pipeline=train_pipeline, + convention='smpl_54', + cache_data_path=cache_files['pw3d'], + ann_file='pw3d_train.npz'), + ], + partition=[0.4, 0.1, 0.1, 0.2, 0.2], + ), + adv_dataset=dict( + type='MeshDataset', + dataset_name='cmu_mosh', + data_prefix='data', + pipeline=train_adv_pipeline, + ann_file='cmu_mosh.npz')), + val=dict( + type=dataset_type, + body_model=dict( + type='GenderedSMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + dataset_name='pw3d', + data_prefix='data', + pipeline=test_pipeline, + ann_file='pw3d_test.npz'), + test=dict( + type=dataset_type, + body_model=dict( + type='GenderedSMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy'), + dataset_name='pw3d', + data_prefix='data', + pipeline=test_pipeline, + ann_file='pw3d_test.npz'), +) diff --git a/mmhuman3d/data/data_converters/__init__.py b/mmhuman3d/data/data_converters/__init__.py index c056f926..7471bfa4 100644 --- a/mmhuman3d/data/data_converters/__init__.py +++ b/mmhuman3d/data/data_converters/__init__.py @@ -1,6 +1,7 @@ from .agora import AgoraConverter from .amass import AmassConverter from .builder import build_data_converter +from .cliff import CliffConverter from .coco import CocoConverter from .coco_hybrik import CocoHybrIKConverter from .coco_wholebody import CocoWholebodyConverter @@ -43,5 +44,5 @@ 'SurrealConverter', 'InstaVibeConverter', 'SpinConverter', 'VibeConverter', 'HuMManConverter', 'FFHQFlameConverter', 'ExposeCuratedFitsConverter', 'ExposeSPINSMPLXConverter', 'FreihandConverter', 'StirlingConverter', - 'EHFConverter' + 'EHFConverter', 'CliffConverter' ] diff --git a/mmhuman3d/data/data_converters/cliff.py b/mmhuman3d/data/data_converters/cliff.py new file mode 100644 index 00000000..e34e897f --- /dev/null +++ b/mmhuman3d/data/data_converters/cliff.py @@ -0,0 +1,121 @@ +import os +from typing import List + +import numpy as np + +from mmhuman3d.core.conventions.keypoints_mapping import convert_kps +from mmhuman3d.data.data_structures.human_data import HumanData +from mmhuman3d.data.data_structures.multi_human_data import MultiHumanData +from .base_converter import BaseModeConverter +from .builder import DATA_CONVERTERS + + +@DATA_CONVERTERS.register_module() +class CliffConverter(BaseModeConverter): + """CLIFF datasets converter `Carrying Location Information in Full Frames + into Human Pose and Shape Estimation' More details can be found in the + `paper. + + `__. + Args: + modes (list): 'coco', 'mpii' + for accepted modes + """ + + ACCEPTED_MODES = ['coco', 'mpii'] + + def __init__(self, modes: List = []) -> None: + super(CliffConverter, self).__init__(modes) + + # def __init__(self) -> None: + self.mapping_dict = { + 'coco': 'coco2014part_cliffGT.npz', + 'mpii': 'mpii_cliffGT.npz', + } + + def convert_by_mode(self, + dataset_path: str, + out_path: str, + mode: str, + enable_multi_human_data: bool = False) -> dict: + """ + Args: + dataset_path (str): Path to directory where spin preprocessed + npz files are stored + out_path (str): Path to directory to save preprocessed npz file + mode (str): Mode in accepted modes + enable_multi_human_data (bool): + Whether to generate a multi-human data. If set to True, + stored in MultiHumanData() format. 
+ Default: False, stored in HumanData() format. + + Returns: + dict: + A dict containing keys image_path, bbox_xywh, keypoints2d, + keypoints2d_mask,stored in HumanData() format. keypoints3d, + keypoints3d_mask, smpl are added if available. + + """ + if enable_multi_human_data: + # use MultiHumanData to store all data + human_data = MultiHumanData() + else: + # use HumanData to store all data + human_data = HumanData() + + image_path_, keypoints2d_, bbox_xywh_ = [], [], [] + + if mode in self.mapping_dict.keys(): + seq_file = self.mapping_dict[mode] + seq_path = os.path.join(dataset_path, seq_file) + + data = np.load(seq_path) + + keypoints2d_ = data['part'] + image_path_ = data['imgname'] + + # center scale to bbox + w = h = data['scale'] * 200 + x = data['center'][:, 0] - w / 2 + y = data['center'][:, 1] - h / 2 + + bbox_xywh_ = np.column_stack((x, y, w, h)) + + # convert keypoints + bbox_xywh_ = np.array(bbox_xywh_).reshape((-1, 4)) + bbox_xywh_ = np.hstack([bbox_xywh_, np.ones([bbox_xywh_.shape[0], 1])]) + keypoints2d_ = np.array(keypoints2d_).reshape((-1, 24, 3)) + keypoints2d_, keypoints2d_mask = convert_kps(keypoints2d_, 'smpl_24', + 'human_data') + + if 'S' in data: + keypoints3d_ = data['S'] + keypoints3d_ = np.array(keypoints3d_).reshape((-1, 24, 4)) + keypoints3d_, keypoints3d_mask = convert_kps( + keypoints3d_, 'smpl_24', 'human_data') + human_data['keypoints3d_mask'] = keypoints3d_mask + human_data['keypoints3d'] = keypoints3d_ + + if 'has_smpl' in data: + has_smpl = data['has_smpl'] + smpl = {} + smpl['body_pose'] = np.array(data['pose'][:, 3:]).reshape( + (-1, 23, 3)) + smpl['global_orient'] = np.array(data['pose'][:, :3]).reshape( + (-1, 3)) + smpl['betas'] = np.array(data['shape']).reshape((-1, 10)) + human_data['smpl'] = smpl + human_data['has_smpl'] = has_smpl + + human_data['image_path'] = image_path_.tolist() + human_data['bbox_xywh'] = bbox_xywh_ + human_data['keypoints2d_mask'] = keypoints2d_mask + human_data['keypoints2d'] = keypoints2d_ + human_data['config'] = mode + human_data.compress_keypoints_by_mask() + + # store the data struct + if not os.path.isdir(out_path): + os.makedirs(out_path) + out_file = os.path.join(out_path, f'cliff_{mode}_train.npz') + human_data.dump(out_file) diff --git a/mmhuman3d/data/datasets/pipelines/__init__.py b/mmhuman3d/data/datasets/pipelines/__init__.py index ae6d9dbe..6551ed63 100644 --- a/mmhuman3d/data/datasets/pipelines/__init__.py +++ b/mmhuman3d/data/datasets/pipelines/__init__.py @@ -22,6 +22,7 @@ BBoxCenterJitter, CenterCrop, ColorJitter, + GetBboxInfo, GetRandomScaleRotation, Lighting, MeshAffine, @@ -33,31 +34,11 @@ ) __all__ = [ - 'Compose', - 'to_tensor', - 'ToTensor', - 'ImageToTensor', - 'ToPIL', - 'ToNumpy', - 'Transpose', - 'Collect', - 'LoadImageFromFile', - 'CenterCrop', - 'RandomHorizontalFlip', - 'ColorJitter', - 'Lighting', - 'RandomChannelNoise', - 'GetRandomScaleRotation', - 'MeshAffine', - 'HybrIKRandomFlip', - 'HybrIKAffine', - 'GenerateHybrIKTarget', - 'RandomDPG', - 'RandomOcclusion', - 'Rotation', - 'NewKeypointsSelection', - 'Normalize', - 'SyntheticOcclusion', - 'BBoxCenterJitter', - 'SimulateLowRes', + 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', + 'Transpose', 'Collect', 'LoadImageFromFile', 'CenterCrop', + 'RandomHorizontalFlip', 'ColorJitter', 'Lighting', 'RandomChannelNoise', + 'GetRandomScaleRotation', 'MeshAffine', 'HybrIKRandomFlip', 'HybrIKAffine', + 'GenerateHybrIKTarget', 'RandomDPG', 'RandomOcclusion', 'Rotation', + 'NewKeypointsSelection', 'Normalize', 
'SyntheticOcclusion', + 'BBoxCenterJitter', 'SimulateLowRes', 'GetBboxInfo' ] diff --git a/mmhuman3d/data/datasets/pipelines/transforms.py b/mmhuman3d/data/datasets/pipelines/transforms.py index 078bcd38..da2f42f9 100644 --- a/mmhuman3d/data/datasets/pipelines/transforms.py +++ b/mmhuman3d/data/datasets/pipelines/transforms.py @@ -746,14 +746,18 @@ class MeshAffine: """ def __init__(self, img_res): - self.img_res = img_res - self.image_size = np.array([img_res, img_res]) + if isinstance(img_res, tuple): + self.image_size = img_res + else: + self.image_size = np.array([img_res, img_res]) def __call__(self, results): c = results['center'] s = results['scale'] r = results['rotation'] trans = get_affine_transform(c, s, r, self.image_size) + inv_trans = get_affine_transform(c, s, 0., self.image_size, inv=True) + crop_trans = get_affine_transform(c, s, 0., self.image_size) if 'img' in results: img = results['img'] @@ -797,6 +801,8 @@ def __call__(self, results): global_orient = _rotate_smpl_pose(global_orient, r) results['smplx_global_orient'] = global_orient + results['crop_trans'] = crop_trans + results['inv_trans'] = inv_trans return results @@ -951,3 +957,33 @@ def __call__(self, results): results['img'] = img return results + + +@PIPELINES.register_module() +class GetBboxInfo: + """Get bbox for cliff.""" + + def estimate_focal_length(self, img_h, img_w): + return (img_w * img_w + img_h * img_h)**0.5 # fov: 55 degree + + def __call__(self, results): + """(1) Get focal length from original image (2) get bbox_info from c + and s.""" + img = results['img'] + img_h, img_w = img.shape[:2] + focal_length = self.estimate_focal_length(img_h, img_w) + + results['img_h'] = img_h + results['img_w'] = img_w + results['focal_length'] = focal_length + cx, cy = results['center'] + s = results['scale'][0] + + bbox_info = np.stack([cx - img_w / 2., cy - img_h / 2., s]) + bbox_info[:2] = bbox_info[:2] / focal_length * 2.8 # [-1, 1] + bbox_info[2] = (bbox_info[2] - 0.24 * focal_length) / ( + 0.06 * focal_length) # [-1, 1] + + results['bbox_info'] = np.float32(bbox_info) + + return results diff --git a/mmhuman3d/models/architectures/builder.py b/mmhuman3d/models/architectures/builder.py index 4e504d82..15ffbc01 100644 --- a/mmhuman3d/models/architectures/builder.py +++ b/mmhuman3d/models/architectures/builder.py @@ -3,6 +3,7 @@ from mmcv.cnn import MODELS as MMCV_MODELS from mmcv.utils import Registry +from .cliff_mesh_estimator import CliffImageBodyModelEstimator from .expressive_mesh_estimator import SMPLXImageBodyModelEstimator from .hybrik import HybrIK_trainer from .mesh_estimator import ImageBodyModelEstimator, VideoBodyModelEstimator @@ -25,6 +26,8 @@ def build_from_cfg(cfg, registry, default_args=None): name='VideoBodyModelEstimator', module=VideoBodyModelEstimator) ARCHITECTURES.register_module( name='SMPLXImageBodyModelEstimator', module=SMPLXImageBodyModelEstimator) +ARCHITECTURES.register_module( + name='CliffImageBodyModelEstimator', module=CliffImageBodyModelEstimator) ARCHITECTURES.register_module(name='PyMAFX', module=PyMAFX) diff --git a/mmhuman3d/models/architectures/cliff_mesh_estimator.py b/mmhuman3d/models/architectures/cliff_mesh_estimator.py new file mode 100644 index 00000000..cea36317 --- /dev/null +++ b/mmhuman3d/models/architectures/cliff_mesh_estimator.py @@ -0,0 +1,881 @@ +from abc import ABCMeta, abstractmethod +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +import mmhuman3d.core.visualization.visualize_smpl as visualize_smpl +from 
mmhuman3d.core.conventions.keypoints_mapping import get_keypoint_idx +from mmhuman3d.models.utils import FitsDict +from mmhuman3d.utils.geometry import ( + batch_rodrigues, + cam_crop2full, + estimate_translation, + perspective_projection, + project_points, + rotation_matrix_to_angle_axis, +) +from ..backbones.builder import build_backbone +from ..body_models.builder import build_body_model +from ..discriminators.builder import build_discriminator +from ..heads.builder import build_head +from ..losses.builder import build_loss +from ..necks.builder import build_neck +from ..registrants.builder import build_registrant +from .base_architecture import BaseArchitecture + + +def set_requires_grad(nets, requires_grad=False): + """Set requies_grad for all the networks. + + Args: + nets (nn.Module | list[nn.Module]): A list of networks or a single + network. + requires_grad (bool): Whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +class BodyModelEstimator(BaseArchitecture, metaclass=ABCMeta): + """BodyModelEstimator Architecture. + + Args: + backbone (dict | None, optional): Backbone config dict. Default: None. + neck (dict | None, optional): Neck config dict. Default: None + head (dict | None, optional): Regressor config dict. Default: None. + disc (dict | None, optional): Discriminator config dict. + Default: None. + registration (dict | None, optional): Registration config dict. + Default: None. + body_model_train (dict | None, optional): SMPL config dict during + training. Default: None. + body_model_test (dict | None, optional): SMPL config dict during + test. Default: None. + convention (str, optional): Keypoints convention. Default: "human_data" + loss_keypoints2d (dict | None, optional): Losses config dict for + 2D keypoints. Default: None. + loss_keypoints3d (dict | None, optional): Losses config dict for + 3D keypoints. Default: None. + loss_vertex (dict | None, optional): Losses config dict for mesh + vertices. Default: None + loss_smpl_pose (dict | None, optional): Losses config dict for smpl + pose. Default: None + loss_smpl_betas (dict | None, optional): Losses config dict for smpl + betas. Default: None + loss_camera (dict | None, optional): Losses config dict for predicted + camera parameters. Default: None + loss_adv (dict | None, optional): Losses config for adversial + training. Default: None. + loss_segm_mask (dict | None, optional): Losses config for predicted + part segmentation. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
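+
+    Example:
+        An abridged config for the CLIFF estimator (taken from
+        configs/cliff/resnet50_pw3d_cache.py in this patch; see that file
+        for the full set of losses, body models and data paths):
+
+        >>> model = dict(
+        ...     type='CliffImageBodyModelEstimator',
+        ...     backbone=dict(type='ResNet', depth=50, out_indices=[3]),
+        ...     head=dict(
+        ...         type='CliffHead',
+        ...         feat_dim=2048,
+        ...         smpl_mean_params='data/body_models/smpl_mean_params.npz'),
+        ...     convention='smpl_54',
+        ...     loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10))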
+ """ + + def __init__(self, + backbone: Optional[Union[dict, None]] = None, + neck: Optional[Union[dict, None]] = None, + head: Optional[Union[dict, None]] = None, + disc: Optional[Union[dict, None]] = None, + registration: Optional[Union[dict, None]] = None, + body_model_train: Optional[Union[dict, None]] = None, + body_model_test: Optional[Union[dict, None]] = None, + convention: Optional[str] = 'human_data', + loss_keypoints2d: Optional[Union[dict, None]] = None, + loss_keypoints3d: Optional[Union[dict, None]] = None, + loss_vertex: Optional[Union[dict, None]] = None, + loss_smpl_pose: Optional[Union[dict, None]] = None, + loss_smpl_betas: Optional[Union[dict, None]] = None, + loss_camera: Optional[Union[dict, None]] = None, + loss_adv: Optional[Union[dict, None]] = None, + loss_segm_mask: Optional[Union[dict, None]] = None, + init_cfg: Optional[Union[list, dict, None]] = None): + super(BodyModelEstimator, self).__init__(init_cfg) + self.backbone = build_backbone(backbone) + self.neck = build_neck(neck) + self.head = build_head(head) + self.disc = build_discriminator(disc) + + self.body_model_train = build_body_model(body_model_train) + self.body_model_test = build_body_model(body_model_test) + self.convention = convention + + # TODO: support HMR+ + + self.registration = registration + if registration is not None: + self.fits_dict = FitsDict(fits='static') + self.registration_mode = self.registration['mode'] + self.registrant = build_registrant(registration['registrant']) + else: + self.registrant = None + + self.loss_keypoints2d = build_loss(loss_keypoints2d) + self.loss_keypoints3d = build_loss(loss_keypoints3d) + + self.loss_vertex = build_loss(loss_vertex) + self.loss_smpl_pose = build_loss(loss_smpl_pose) + self.loss_smpl_betas = build_loss(loss_smpl_betas) + self.loss_adv = build_loss(loss_adv) + self.loss_camera = build_loss(loss_camera) + self.loss_segm_mask = build_loss(loss_segm_mask) + set_requires_grad(self.body_model_train, False) + set_requires_grad(self.body_model_test, False) + + def train_step(self, data_batch, optimizer, **kwargs): + """Train step function. + + In this function, the detector will finish the train step following + the pipeline: + 1. get fake and real SMPL parameters + 2. optimize discriminator (if have) + 3. optimize generator + If `self.train_cfg.disc_step > 1`, the train step will contain multiple + iterations for optimizing discriminator with different input data and + only one iteration for optimizing generator after `disc_step` + iterations for discriminator. + Args: + data_batch (torch.Tensor): Batch of data as input. + optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for + generator and discriminator (if have). + Returns: + outputs (dict): Dict with loss, information for logger, + the number of samples. 
+ """ + if self.backbone is not None: + img = data_batch['img'] + features = self.backbone(img) + else: + features = data_batch['features'] + + if self.neck is not None: + features = self.neck(features) + + # NOTE: features and bbox_info taken as input for Cliff + bbox_info = data_batch['bbox_info'] + predictions = self.head(features, bbox_info) + targets = self.prepare_targets(data_batch) + + # optimize discriminator (if have) + if self.disc is not None: + self.optimize_discrinimator(predictions, data_batch, optimizer) + + if self.registration is not None: + targets = self.run_registration(predictions, targets) + + losses = self.compute_losses(predictions, targets) + # optimizer generator part + if self.disc is not None: + adv_loss = self.optimize_generator(predictions) + losses.update(adv_loss) + + loss, log_vars = self._parse_losses(losses) + for key in optimizer.keys(): + optimizer[key].zero_grad() + loss.backward() + for key in optimizer.keys(): + optimizer[key].step() + + outputs = dict( + loss=loss, + log_vars=log_vars, + num_samples=len(next(iter(data_batch.values())))) + return outputs + + def run_registration( + self, + predictions: dict, + targets: dict, + threshold: Optional[float] = 10.0, + focal_length: Optional[float] = 5000.0, + img_res: Optional[Union[Tuple[int], int]] = 224) -> dict: + """Run registration on 2D keypoinst in predictions to obtain SMPL + parameters as pseudo ground truth. + + Args: + predictions (dict): predicted SMPL parameters are used for + initialization. + targets (dict): existing ground truths with 2D keypoints + threshold (float, optional): the threshold to update fits + dictionary. Default: 10.0. + focal_length (tuple(int) | int, optional): camera focal_length + img_res (int, optional): image resolution + + Returns: + targets: contains additional SMPL parameters + """ + + img_metas = targets['img_metas'] + dataset_name = [meta['dataset_name'] for meta in img_metas + ] # name of the dataset the image comes from + + indices = targets['sample_idx'].squeeze() + is_flipped = targets['is_flipped'].squeeze().bool( + ) # flag that indicates whether image was flipped + # during data augmentation + rot_angle = targets['rotation'].squeeze( + ) # rotation angle used for data augmentation Q + gt_betas = targets['smpl_betas'].float() + gt_global_orient = targets['smpl_global_orient'].float() + gt_pose = targets['smpl_body_pose'].float().view(-1, 69) + + pred_rotmat = predictions['pred_pose'].detach().clone() + pred_betas = predictions['pred_shape'].detach().clone() + pred_cam = predictions['pred_cam'].detach().clone() + pred_cam_t = torch.stack([ + pred_cam[:, 1], pred_cam[:, 2], 2 * focal_length / + (img_res * pred_cam[:, 0] + 1e-9) + ], + dim=-1) + + gt_keypoints_2d = targets['keypoints2d'].float() + num_keypoints = gt_keypoints_2d.shape[1] + + has_smpl = targets['has_smpl'].view( + -1).bool() # flag that indicates whether SMPL parameters are valid + batch_size = has_smpl.shape[0] + device = has_smpl.device + + # Get GT vertices and model joints + # Note that gt_model_joints is different from gt_joints as + # it comes from SMPL + gt_out = self.body_model_train( + betas=gt_betas, body_pose=gt_pose, global_orient=gt_global_orient) + # TODO: support more convention + assert num_keypoints == 49 + gt_model_joints = gt_out['joints'] + gt_vertices = gt_out['vertices'] + + # Get current best fits from the dictionary + opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(), + rot_angle.cpu(), + is_flipped.cpu())] + + opt_pose = opt_pose.to(device) + opt_betas 
= opt_betas.to(device) + opt_output = self.body_model_train( + betas=opt_betas, + body_pose=opt_pose[:, 3:], + global_orient=opt_pose[:, :3]) + opt_joints = opt_output['joints'] + opt_vertices = opt_output['vertices'] + + gt_keypoints_2d_orig = gt_keypoints_2d.clone() + # Estimate camera translation given the model joints and 2D keypoints + # by minimizing a weighted least squares loss + gt_cam_t = estimate_translation( + gt_model_joints, + gt_keypoints_2d_orig, + focal_length=focal_length, + img_size=img_res) + + opt_cam_t = estimate_translation( + opt_joints, + gt_keypoints_2d_orig, + focal_length=focal_length, + img_size=img_res) + + with torch.no_grad(): + loss_dict = self.registrant.evaluate( + global_orient=opt_pose[:, :3], + body_pose=opt_pose[:, 3:], + betas=opt_betas, + transl=opt_cam_t, + keypoints2d=gt_keypoints_2d_orig[:, :, :2], + keypoints2d_conf=gt_keypoints_2d_orig[:, :, 2], + reduction_override='none') + opt_joint_loss = loss_dict['keypoint2d_loss'].sum(dim=-1).sum(dim=-1) + + if self.registration_mode == 'in_the_loop': + # Convert predicted rotation matrices to axis-angle + pred_rotmat_hom = torch.cat([ + pred_rotmat.detach().view(-1, 3, 3).detach(), + torch.tensor([0, 0, 1], dtype=torch.float32, + device=device).view(1, 3, 1).expand( + batch_size * 24, -1, -1) + ], + dim=-1) + pred_pose = rotation_matrix_to_angle_axis( + pred_rotmat_hom).contiguous().view(batch_size, -1) + # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation, + # so manually hack it + pred_pose[torch.isnan(pred_pose)] = 0.0 + + registrant_output = self.registrant( + keypoints2d=gt_keypoints_2d_orig[:, :, :2], + keypoints2d_conf=gt_keypoints_2d_orig[:, :, 2], + init_global_orient=pred_pose[:, :3], + init_transl=pred_cam_t, + init_body_pose=pred_pose[:, 3:], + init_betas=pred_betas, + return_joints=True, + return_verts=True, + return_losses=True) + new_opt_vertices = registrant_output[ + 'vertices'] - pred_cam_t.unsqueeze(1) + new_opt_joints = registrant_output[ + 'joints'] - pred_cam_t.unsqueeze(1) + + new_opt_global_orient = registrant_output['global_orient'] + new_opt_body_pose = registrant_output['body_pose'] + new_opt_pose = torch.cat( + [new_opt_global_orient, new_opt_body_pose], dim=1) + + new_opt_betas = registrant_output['betas'] + new_opt_cam_t = registrant_output['transl'] + new_opt_joint_loss = registrant_output['keypoint2d_loss'].sum( + dim=-1).sum(dim=-1) + + # Will update the dictionary for the examples where the new loss + # is less than the current one + update = (new_opt_joint_loss < opt_joint_loss) + + opt_joint_loss[update] = new_opt_joint_loss[update] + opt_vertices[update, :] = new_opt_vertices[update, :] + opt_joints[update, :] = new_opt_joints[update, :] + opt_pose[update, :] = new_opt_pose[update, :] + opt_betas[update, :] = new_opt_betas[update, :] + opt_cam_t[update, :] = new_opt_cam_t[update, :] + + self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(), + is_flipped.cpu(), + update.cpu())] = (opt_pose.cpu(), opt_betas.cpu()) + + # Replace extreme betas with zero betas + opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0. 
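+        # betas are roughly in units of standard deviations of the SMPL shape
+        # space, so |beta| > 3 on any component almost certainly comes from a
+        # degenerate fit; zeroing the vector falls back to the mean shape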
+ + # Replace the optimized parameters with the ground truth parameters, + # if available + opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :] + opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :] + opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :] + opt_pose[has_smpl, 3:] = gt_pose[has_smpl, :] + opt_pose[has_smpl, :3] = gt_global_orient[has_smpl, :] + opt_betas[has_smpl, :] = gt_betas[has_smpl, :] + + # Assert whether a fit is valid by comparing the joint loss with + # the threshold + valid_fit = (opt_joint_loss < threshold).to(device) + valid_fit = valid_fit | has_smpl + targets['valid_fit'] = valid_fit + + targets['opt_vertices'] = opt_vertices + targets['opt_cam_t'] = opt_cam_t + targets['opt_joints'] = opt_joints + targets['opt_pose'] = opt_pose + targets['opt_betas'] = opt_betas + + return targets + + def optimize_discrinimator(self, predictions: dict, data_batch: dict, + optimizer: dict): + """Optimize discrinimator during adversarial training.""" + set_requires_grad(self.disc, True) + fake_data = self.make_fake_data(predictions, requires_grad=False) + real_data = self.make_real_data(data_batch) + fake_score = self.disc(fake_data) + real_score = self.disc(real_data) + + disc_losses = {} + disc_losses['real_loss'] = self.loss_adv( + real_score, target_is_real=True, is_disc=True) + disc_losses['fake_loss'] = self.loss_adv( + fake_score, target_is_real=False, is_disc=True) + loss_disc, log_vars_d = self._parse_losses(disc_losses) + + optimizer['disc'].zero_grad() + loss_disc.backward() + optimizer['disc'].step() + + def optimize_generator(self, predictions: dict): + """Optimize generator during adversarial training.""" + set_requires_grad(self.disc, False) + fake_data = self.make_fake_data(predictions, requires_grad=True) + pred_score = self.disc(fake_data) + loss_adv = self.loss_adv( + pred_score, target_is_real=True, is_disc=False) + loss = dict(adv_loss=loss_adv) + return loss + + def compute_keypoints3d_loss( + self, + pred_keypoints3d: torch.Tensor, + gt_keypoints3d: torch.Tensor, + has_keypoints3d: Optional[torch.Tensor] = None): + """Compute loss for 3d keypoints.""" + keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1) + keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3) + pred_keypoints3d = pred_keypoints3d.float() + gt_keypoints3d = gt_keypoints3d[:, :, :3].float() + + # currently, only mpi_inf_3dhp and h36m have 3d keypoints + # both datasets have right_hip_extra and left_hip_extra + right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention) + left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention) + gt_pelvis = (gt_keypoints3d[:, right_hip_idx, :] + + gt_keypoints3d[:, left_hip_idx, :]) / 2 + pred_pelvis = (pred_keypoints3d[:, right_hip_idx, :] + + pred_keypoints3d[:, left_hip_idx, :]) / 2 + + gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, None, :] + pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, None, :] + loss = self.loss_keypoints3d( + pred_keypoints3d, gt_keypoints3d, reduction_override='none') + + # If has_keypoints3d is not None, then computes the losses on the + # instances that have ground-truth keypoints3d. + # But the zero confidence keypoints will be included in mean. + # Otherwise, only compute the keypoints3d + # which have positive confidence. 
+ + # has_keypoints3d is None when the key has_keypoints3d + # is not in the datasets + if has_keypoints3d is None: + + valid_pos = keypoints3d_conf > 0 + if keypoints3d_conf[valid_pos].numel() == 0: + return torch.Tensor([0]).type_as(gt_keypoints3d) + loss = torch.sum(loss * keypoints3d_conf) + loss /= keypoints3d_conf[valid_pos].numel() + else: + + keypoints3d_conf = keypoints3d_conf[has_keypoints3d == 1] + if keypoints3d_conf.shape[0] == 0: + return torch.Tensor([0]).type_as(gt_keypoints3d) + loss = loss[has_keypoints3d == 1] + loss = (loss * keypoints3d_conf).mean() + return loss + + def compute_keypoints2d_loss( + self, + pred_keypoints3d: torch.Tensor, + pred_cam: torch.Tensor, + gt_keypoints2d: torch.Tensor, + img_res: Optional[int] = 224, + focal_length: Optional[int] = 5000, + has_keypoints2d: Optional[torch.Tensor] = None): + """Compute loss for 2d keypoints.""" + keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1) + keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2) + gt_keypoints2d = gt_keypoints2d[:, :, :2].float() + pred_keypoints2d = project_points( + pred_keypoints3d, + pred_cam, + focal_length=focal_length, + img_res=img_res) + # Normalize keypoints to [-1,1] + # The coordinate origin of pred_keypoints_2d is + # the center of the input image. + pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1) + # The coordinate origin of gt_keypoints_2d is + # the top left corner of the input image. + gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1 + loss = self.loss_keypoints2d( + pred_keypoints2d, gt_keypoints2d, reduction_override='none') + + # If has_keypoints2d is not None, then computes the losses on the + # instances that have ground-truth keypoints2d. + # But the zero confidence keypoints will be included in mean. + # Otherwise, only compute the keypoints2d + # which have positive confidence. 
+ # has_keypoints2d is None when the key has_keypoints2d + # is not in the datasets + + if has_keypoints2d is None: + valid_pos = keypoints2d_conf > 0 + if keypoints2d_conf[valid_pos].numel() == 0: + return torch.Tensor([0]).type_as(gt_keypoints2d) + loss = torch.sum(loss * keypoints2d_conf) + loss /= keypoints2d_conf[valid_pos].numel() + else: + keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1] + if keypoints2d_conf.shape[0] == 0: + return torch.Tensor([0]).type_as(gt_keypoints2d) + loss = loss[has_keypoints2d == 1] + loss = (loss * keypoints2d_conf).mean() + + return loss + + def compute_keypoints2d_loss_cliff( + self, + pred_keypoints3d: torch.Tensor, + pred_cam: torch.Tensor, + gt_keypoints2d: torch.Tensor, + camera_center: torch.Tensor, + focal_length: torch.Tensor, + trans: torch.Tensor, + img_res: Optional[int] = 224, + has_keypoints2d: Optional[torch.Tensor] = None): + """Compute loss for 2d keypoints.""" + keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1) + keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2) + gt_keypoints2d = gt_keypoints2d[:, :, :2].float() + + device = gt_keypoints2d.device + batch_size, num_keypoints = pred_keypoints3d.shape[0:2] + + pred_keypoints2d = perspective_projection( + pred_keypoints3d, + rotation=torch.eye(3, device=device).unsqueeze(0).expand( + batch_size, -1, -1), + translation=pred_cam, + focal_length=focal_length, + camera_center=camera_center) + + pred_keypoints2d = torch.cat( + (pred_keypoints2d, torch.ones(batch_size, num_keypoints, + 1).to(device)), + dim=2) + # trans @ pred_keypoints2d2 + pred_keypoints2d = torch.einsum('bij,bkj->bki', trans, + pred_keypoints2d) + + # The coordinate origin of pred_keypoints_2d and gt_keypoints_2d is + # the top left corner of the input image. + pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1) - 1 + gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1 + loss = self.loss_keypoints2d( + pred_keypoints2d, gt_keypoints2d, reduction_override='none') + + # If has_keypoints2d is not None, then computes the losses on the + # instances that have ground-truth keypoints2d. + # But the zero confidence keypoints will be included in mean. + # Otherwise, only compute the keypoints2d + # which have positive confidence. 
+ # has_keypoints2d is None when the key has_keypoints2d + # is not in the datasets + + if has_keypoints2d is None: + valid_pos = keypoints2d_conf > 0 + if keypoints2d_conf[valid_pos].numel() == 0: + return torch.Tensor([0]).type_as(gt_keypoints2d) + loss = torch.sum(loss * keypoints2d_conf) + loss /= keypoints2d_conf[valid_pos].numel() + else: + keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1] + if keypoints2d_conf.shape[0] == 0: + return torch.Tensor([0]).type_as(gt_keypoints2d) + loss = loss[has_keypoints2d == 1] + loss = (loss * keypoints2d_conf).mean() + + return loss + + def compute_vertex_loss(self, pred_vertices: torch.Tensor, + gt_vertices: torch.Tensor, has_smpl: torch.Tensor): + """Compute loss for vertices.""" + gt_vertices = gt_vertices.float() + conf = has_smpl.float().view(-1, 1, 1) + conf = conf.repeat(1, gt_vertices.shape[1], gt_vertices.shape[2]) + loss = self.loss_vertex( + pred_vertices, gt_vertices, reduction_override='none') + valid_pos = conf > 0 + if conf[valid_pos].numel() == 0: + return torch.Tensor([0]).type_as(gt_vertices) + loss = torch.sum(loss * conf) / conf[valid_pos].numel() + return loss + + def compute_smpl_pose_loss(self, pred_rotmat: torch.Tensor, + gt_pose: torch.Tensor, has_smpl: torch.Tensor): + """Compute loss for smpl pose.""" + conf = has_smpl.float().view(-1) + valid_pos = conf > 0 + if conf[valid_pos].numel() == 0: + return torch.Tensor([0]).type_as(gt_pose) + pred_rotmat = pred_rotmat[valid_pos] + gt_pose = gt_pose[valid_pos] + conf = conf[valid_pos] + gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3) + loss = self.loss_smpl_pose( + pred_rotmat, gt_rotmat, reduction_override='none') + loss = loss.view(loss.shape[0], -1).mean(-1) + loss = torch.mean(loss * conf) + return loss + + def compute_smpl_betas_loss(self, pred_betas: torch.Tensor, + gt_betas: torch.Tensor, + has_smpl: torch.Tensor): + """Compute loss for smpl betas.""" + conf = has_smpl.float().view(-1) + valid_pos = conf > 0 + if conf[valid_pos].numel() == 0: + return torch.Tensor([0]).type_as(gt_betas) + pred_betas = pred_betas[valid_pos] + gt_betas = gt_betas[valid_pos] + conf = conf[valid_pos] + loss = self.loss_smpl_betas( + pred_betas, gt_betas, reduction_override='none') + loss = loss.view(loss.shape[0], -1).mean(-1) + loss = torch.mean(loss * conf) + return loss + + def compute_camera_loss(self, cameras: torch.Tensor): + """Compute loss for predicted camera parameters.""" + loss = self.loss_camera(cameras) + return loss + + def compute_part_segmentation_loss(self, + pred_heatmap: torch.Tensor, + gt_vertices: torch.Tensor, + gt_keypoints2d: torch.Tensor, + gt_model_joints: torch.Tensor, + has_smpl: torch.Tensor, + img_res: Optional[int] = 224, + focal_length: Optional[int] = 500): + """Compute loss for part segmentations.""" + device = gt_keypoints2d.device + gt_keypoints2d_valid = gt_keypoints2d[has_smpl == 1] + batch_size = gt_keypoints2d_valid.shape[0] + + gt_vertices_valid = gt_vertices[has_smpl == 1] + gt_model_joints_valid = gt_model_joints[has_smpl == 1] + + if batch_size == 0: + return torch.Tensor([0]).type_as(gt_keypoints2d) + gt_cam_t = estimate_translation( + gt_model_joints_valid, + gt_keypoints2d_valid, + focal_length=focal_length, + img_size=img_res, + ) + + K = torch.eye(3) + K[0, 0] = focal_length + K[1, 1] = focal_length + K[2, 2] = 1 + K[0, 2] = img_res / 2. + K[1, 2] = img_res / 2. 
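+        # K is a pinhole intrinsic matrix: fx = fy = focal_length and the
+        # principal point sits at the centre of the img_res x img_res render;
+        # the next line adds a batch dimension before passing it to render_smpl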
+ K = K[None, :, :] + + R = torch.eye(3)[None, :, :] + device = gt_keypoints2d.device + gt_sem_mask = visualize_smpl.render_smpl( + verts=gt_vertices_valid, + R=R, + K=K, + T=gt_cam_t, + render_choice='part_silhouette', + resolution=img_res, + return_tensor=True, + body_model=self.body_model_train, + device=device, + in_ndc=False, + convention='pytorch3d', + projection='perspective', + no_grad=True, + batch_size=batch_size, + verbose=False, + ) + gt_sem_mask = torch.flip(gt_sem_mask, [1, 2]).squeeze(-1).detach() + pred_heatmap_valid = pred_heatmap[has_smpl == 1] + ph, pw = pred_heatmap_valid.size(2), pred_heatmap_valid.size(3) + h, w = gt_sem_mask.size(1), gt_sem_mask.size(2) + if ph != h or pw != w: + pred_heatmap_valid = F.interpolate( + input=pred_heatmap_valid, size=(h, w), mode='bilinear') + + loss = self.loss_segm_mask(pred_heatmap_valid, gt_sem_mask) + return loss + + def compute_losses(self, predictions: dict, targets: dict): + """Compute losses.""" + pred_betas = predictions['pred_shape'].view(-1, 10) + pred_pose = predictions['pred_pose'].view(-1, 24, 3, 3) + pred_cam_crop = predictions['pred_cam'].view(-1, 3) + + # NOTE: convert cam parameters from the crop to the full camera + img_h, img_w = targets['img_h'], targets['img_w'] + center, scale, focal_length = targets['center'], targets[ + 'scale'][:, 0], targets['focal_length'].squeeze(dim=1) + full_img_shape = torch.hstack((img_h, img_w)) + pred_cam = cam_crop2full(pred_cam_crop, center, scale, full_img_shape, + focal_length).to(torch.float32) + + gt_keypoints3d = targets['keypoints3d'] + # this should be in full frame + gt_keypoints2d = targets['keypoints2d'] + # pred_pose N, 24, 3, 3 + if self.body_model_train is not None: + pred_output = self.body_model_train( + betas=pred_betas, + body_pose=pred_pose[:, 1:], + global_orient=pred_pose[:, 0].unsqueeze(1), + pose2rot=False, + num_joints=gt_keypoints2d.shape[1]) + pred_keypoints3d = pred_output['joints'] + pred_vertices = pred_output['vertices'] + + # NOTE: use crop_trans to contain full -> crop so that pred keypoints + # are normalized to bbox + camera_center = torch.hstack((img_w, img_h)) / 2 + trans = targets['crop_trans'].float() + + # TODO: temp solution + if 'valid_fit' in targets: + has_smpl = targets['valid_fit'].view(-1) + # global_orient = targets['opt_pose'][:, :3].view(-1, 1, 3) + gt_pose = targets['opt_pose'] + gt_betas = targets['opt_betas'] + gt_vertices = targets['opt_vertices'] + else: + has_smpl = targets['has_smpl'].view(-1) + gt_pose = targets['smpl_body_pose'] + global_orient = targets['smpl_global_orient'].view(-1, 1, 3) + gt_pose = torch.cat((global_orient, gt_pose), dim=1).float() + gt_betas = targets['smpl_betas'].float() + + # gt_pose N, 72 + if self.body_model_train is not None: + gt_output = self.body_model_train( + betas=gt_betas, + body_pose=gt_pose[:, 3:], + global_orient=gt_pose[:, :3], + num_joints=gt_keypoints2d.shape[1]) + gt_vertices = gt_output['vertices'] + gt_model_joints = gt_output['joints'] + if 'has_keypoints3d' in targets: + has_keypoints3d = targets['has_keypoints3d'].squeeze(-1) + else: + has_keypoints3d = None + if 'has_keypoints2d' in targets: + has_keypoints2d = targets['has_keypoints2d'].squeeze(-1) + else: + has_keypoints2d = None + if 'pred_segm_mask' in predictions: + pred_segm_mask = predictions['pred_segm_mask'] + losses = {} + if self.loss_keypoints3d is not None: + losses['keypoints3d_loss'] = self.compute_keypoints3d_loss( + pred_keypoints3d, + gt_keypoints3d, + has_keypoints3d=has_keypoints3d) + if self.loss_keypoints2d 
is not None: + losses['keypoints2d_loss'] = self.compute_keypoints2d_loss_cliff( + pred_keypoints3d, + pred_cam, + gt_keypoints2d, + camera_center, + focal_length, + trans, + has_keypoints2d=has_keypoints2d) + if self.loss_vertex is not None: + losses['vertex_loss'] = self.compute_vertex_loss( + pred_vertices, gt_vertices, has_smpl) + if self.loss_smpl_pose is not None: + losses['smpl_pose_loss'] = self.compute_smpl_pose_loss( + pred_pose, gt_pose, has_smpl) + if self.loss_smpl_betas is not None: + losses['smpl_betas_loss'] = self.compute_smpl_betas_loss( + pred_betas, gt_betas, has_smpl) + if self.loss_camera is not None: + losses['camera_loss'] = self.compute_camera_loss(pred_cam) + if self.loss_segm_mask is not None: + losses['loss_segm_mask'] = self.compute_part_segmentation_loss( + pred_segm_mask, gt_vertices, gt_keypoints2d, gt_model_joints, + has_smpl) + + return losses + + @abstractmethod + def make_fake_data(self, predictions, requires_grad): + pass + + @abstractmethod + def make_real_data(self, data_batch): + pass + + @abstractmethod + def prepare_targets(self, data_batch): + pass + + def forward_train(self, **kwargs): + """Forward function for general training. + + For mesh estimation, we do not use this interface. + """ + raise NotImplementedError('This interface should not be used in ' + 'current training schedule. Please use ' + '`train_step` for training.') + + @abstractmethod + def forward_test(self, img, img_metas, **kwargs): + """Defines the computation performed at every call when testing.""" + pass + + +class CliffImageBodyModelEstimator(BodyModelEstimator): + + def make_fake_data(self, predictions: dict, requires_grad: bool): + pred_cam = predictions['pred_cam'] + pred_pose = predictions['pred_pose'] + pred_betas = predictions['pred_shape'] + if requires_grad: + fake_data = (pred_cam, pred_pose, pred_betas) + else: + fake_data = (pred_cam.detach(), pred_pose.detach(), + pred_betas.detach()) + return fake_data + + def make_real_data(self, data_batch: dict): + transl = data_batch['adv_smpl_transl'].float() + global_orient = data_batch['adv_smpl_global_orient'] + body_pose = data_batch['adv_smpl_body_pose'] + betas = data_batch['adv_smpl_betas'].float() + pose = torch.cat((global_orient, body_pose), dim=-1).float() + real_data = (transl, pose, betas) + return real_data + + def prepare_targets(self, data_batch: dict): + # Image Mesh Estimator does not need extra process for ground truth + return data_batch + + def forward_test(self, img: torch.Tensor, img_metas: dict, **kwargs): + """Defines the computation performed at every call when testing.""" + if self.backbone is not None: + features = self.backbone(img) + else: + features = kwargs['features'] + + if self.neck is not None: + features = self.neck(features) + + # NOTE: extras for Cliff inference + bbox_info = kwargs['bbox_info'] + predictions = self.head(features, bbox_info) + pred_pose = predictions['pred_pose'] + pred_betas = predictions['pred_shape'] + pred_cam_crop = predictions['pred_cam'].view(-1, 3) + + # convert the camera parameters from the crop camera to the full camera + img_h, img_w = kwargs['img_h'], kwargs['img_w'] + center, scale, focal_length = kwargs['center'], kwargs[ + 'scale'][:, 0], kwargs['focal_length'].squeeze(dim=1) + full_img_shape = torch.hstack((img_h, img_w)) + + pred_cam = cam_crop2full(pred_cam_crop, center, scale, full_img_shape, + focal_length).to(torch.float32) + + pred_output = self.body_model_test( + betas=pred_betas, + body_pose=pred_pose[:, 1:], + global_orient=pred_pose[:, 
0].unsqueeze(1), + pose2rot=False) + + pred_vertices = pred_output['vertices'] + pred_keypoints_3d = pred_output['joints'] + all_preds = {} + all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy() + all_preds['smpl_pose'] = pred_pose.detach().cpu().numpy() + all_preds['smpl_beta'] = pred_betas.detach().cpu().numpy() + all_preds['camera'] = pred_cam.detach().cpu().numpy() + all_preds['vertices'] = pred_vertices.detach().cpu().numpy() + image_path = [] + for img_meta in img_metas: + image_path.append(img_meta['image_path']) + all_preds['image_path'] = image_path + all_preds['image_idx'] = kwargs['sample_idx'] + return all_preds diff --git a/mmhuman3d/models/heads/builder.py b/mmhuman3d/models/heads/builder.py index 5e15c8ef..bdaa7e12 100644 --- a/mmhuman3d/models/heads/builder.py +++ b/mmhuman3d/models/heads/builder.py @@ -2,6 +2,7 @@ from mmcv.utils import Registry +from .cliff_head import CliffHead from .expose_head import ExPoseBodyHead, ExPoseFaceHead, ExPoseHandHead from .hmr_head import HMRHead from .hybrik_head import HybrIKHead @@ -16,6 +17,7 @@ HEADS.register_module(name='ExPoseBodyHead', module=ExPoseBodyHead) HEADS.register_module(name='ExPoseHandHead', module=ExPoseHandHead) HEADS.register_module(name='ExPoseFaceHead', module=ExPoseFaceHead) +HEADS.register_module(name='CliffHead', module=CliffHead) HEADS.register_module(name='PyMAFXHead', module=PyMAFXHead) HEADS.register_module(name='Regressor', module=Regressor) diff --git a/mmhuman3d/models/heads/cliff_head.py b/mmhuman3d/models/heads/cliff_head.py new file mode 100644 index 00000000..037e37d1 --- /dev/null +++ b/mmhuman3d/models/heads/cliff_head.py @@ -0,0 +1,98 @@ +import numpy as np +import torch +import torch.nn as nn +from mmcv.runner.base_module import BaseModule + +from mmhuman3d.utils.geometry import rot6d_to_rotmat + + +class CliffHead(BaseModule): + + def __init__(self, + feat_dim, + smpl_mean_params=None, + npose=144, + nbeta=10, + ncam=3, + nbbox=3, + hdim=1024, + init_cfg=None): + super(CliffHead, self).__init__(init_cfg=init_cfg) + self.fc1 = nn.Linear(feat_dim + nbbox + npose + nbeta + ncam, hdim) + self.drop1 = nn.Dropout() + self.fc2 = nn.Linear(hdim, hdim) + self.drop2 = nn.Dropout() + self.decpose = nn.Linear(hdim, npose) + self.decshape = nn.Linear(hdim, nbeta) + self.deccam = nn.Linear(hdim, ncam) + + nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) + nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) + nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) + + if smpl_mean_params is None: + init_pose = torch.zeros([1, npose]) + init_shape = torch.zeros([1, nbeta]) + init_cam = torch.FloatTensor([[1, 0, 0]]) + else: + mean_params = np.load(smpl_mean_params) + init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0) + init_shape = torch.from_numpy( + mean_params['shape'][:].astype('float32')).unsqueeze(0) + init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0) + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.register_buffer('init_cam', init_cam) + + def forward(self, + x, + bbox_info, + init_pose=None, + init_shape=None, + init_cam=None, + n_iter=3): + + # inherited from hmr head, only support one layer feature + if isinstance(x, list) or isinstance(x, tuple): + x = x[-1] + + output_seq = False + if len(x.shape) == 4: + # use feature from the last layer of the backbone + # apply global average pooling on the feature map + x = x.mean(dim=-1).mean(dim=-1) + elif len(x.shape) == 3: + # temporal feature + raise 
NotImplementedError + + batch_size = x.shape[0] + if init_pose is None: + init_pose = self.init_pose.expand(batch_size, -1) + if init_shape is None: + init_shape = self.init_shape.expand(batch_size, -1) + if init_cam is None: + init_cam = self.init_cam.expand(batch_size, -1) + + pred_pose = init_pose + pred_shape = init_shape + pred_cam = init_cam + for i in range(n_iter): + xc = torch.cat([x, bbox_info, pred_pose, pred_shape, pred_cam], 1) + xc = self.fc1(xc) + xc = self.drop1(xc) + xc = self.fc2(xc) + xc = self.drop2(xc) + pred_pose = self.decpose(xc) + pred_pose + pred_shape = self.decshape(xc) + pred_shape + pred_cam = self.deccam(xc) + pred_cam + + pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3) + + if output_seq: + raise NotImplementedError + output = { + 'pred_pose': pred_rotmat, + 'pred_shape': pred_shape, + 'pred_cam': pred_cam + } + return output diff --git a/mmhuman3d/utils/geometry.py b/mmhuman3d/utils/geometry.py index 88dcdccd..09a8cef4 100644 --- a/mmhuman3d/utils/geometry.py +++ b/mmhuman3d/utils/geometry.py @@ -417,6 +417,27 @@ def weak_perspective_projection(points, scale, translation): return projected_points +def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length): + """convert the camera parameters from the crop camera to the full camera. + + :param crop_cam: shape=(N, 3) weak perspective camera in cropped + img coordinates (s, tx, ty) + :param center: shape=(N, 2) bbox coordinates (c_x, c_y) + :param scale: shape=(N, 1) square bbox resolution (b / 200) + :param full_img_shape: shape=(N, 2) original image height and width + :param focal_length: shape=(N,) + :return: + """ + img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1] + cx, cy, b = center[:, 0], center[:, 1], scale + bs = b * crop_cam[:, 0] + 1e-9 + tz = 2 * focal_length / bs + tx = (2 * (cx - img_w / 2.) / bs) + crop_cam[:, 1] + ty = (2 * (cy - img_h / 2.) 
/ bs) + crop_cam[:, 2] + full_cam = torch.stack([tx, ty, tz], dim=-1) + return full_cam + + def projection(pred_joints, pred_camera, iwp_mode=True): """Project 3D points on the image plane based on the given camera info, Identity rotation and Weak Perspective (IWP) camera is used when diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 40562562..cb4edc25 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -8,7 +8,7 @@ h5py matplotlib numpy opencv-python -pandas +pandas<2.0.0 pickle5 plyfile rtree diff --git a/setup.cfg b/setup.cfg index 8bc3c131..a5899aa2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,6 +15,6 @@ multi_line_output = 3 include_trailing_comma = true known_standard_library = pkg_resources,setuptools known_first_party = mmhuman3d -known_third_party =PIL,cdflib,colormap,cv2,einops,h5py,matplotlib,mmcv,mpl_toolkits,numpy,openpifpaf,pickle5,plyfile,pytest,pytorch3d,pytorch_sphinx_theme,scipy,skimage,smplx,surrogate,torch,tqdm,trimesh,vedo +known_third_party = PIL,cdflib,colormap,cv2,einops,h5py,matplotlib,mmcv,mpl_toolkits,numpy,openpifpaf,pickle5,plyfile,pytest,pytorch3d,pytorch_sphinx_theme,scipy,skimage,smplx,surrogate,torch,tqdm,trimesh,vedo no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_data_converters.py b/tests/test_data_converters.py index 4210e2fb..b9e77f80 100644 --- a/tests/test_data_converters.py +++ b/tests/test_data_converters.py @@ -323,6 +323,14 @@ def test_multi_human_data_preprocess(): assert os.path.exists('/tmp/preprocessed_npzs/' + 'crowdpose_test.npz') assert os.path.exists('/tmp/preprocessed_npzs/' + 'crowdpose_trainval.npz') + CLIFF_ROOT = os.path.join(root_path, 'eft') + cfg = dict(type='CliffConverter', modes=['coco', 'mpii']) + data_converter = build_data_converter(cfg) + data_converter.convert( + CLIFF_ROOT, output_path, enable_multi_human_data=True) + assert os.path.exists('/tmp/preprocessed_npzs/' + 'cliff_coco_train.npz') + assert os.path.exists('/tmp/preprocessed_npzs/' + 'cliff_mpii_train.npz') + def test_preprocessed_npz(): npz_folder = '/tmp/preprocessed_npzs' diff --git a/tests/test_datasets/test_pipelines.py b/tests/test_datasets/test_pipelines.py index 4e0daf61..f14a003a 100644 --- a/tests/test_datasets/test_pipelines.py +++ b/tests/test_datasets/test_pipelines.py @@ -2,6 +2,7 @@ import pytest from mmhuman3d.data.datasets.pipelines import ( + GetBboxInfo, LoadImageFromFile, SyntheticOcclusion, ) @@ -57,3 +58,17 @@ def test_synthetic_occlusion(): results = pipeline(results) assert results['img'].shape == (224, 224, 3) + + +def test_get_bbox_inf(): + pipeline = GetBboxInfo() + results = { + 'img': np.ones((224, 224, 3)), + 'center': np.array([100, 100]), + 'scale': np.array([10, 10]) + } + pipeline(results=results) + assert 'img_h' in results + assert 'img_w' in results + assert 'focal_length' in results + assert 'bbox_info' in results diff --git a/tests/test_models/test_architectures/test_cliff_mesh_estimator.py b/tests/test_models/test_architectures/test_cliff_mesh_estimator.py new file mode 100644 index 00000000..2896d787 --- /dev/null +++ b/tests/test_models/test_architectures/test_cliff_mesh_estimator.py @@ -0,0 +1,417 @@ +import torch + +from mmhuman3d.core.cameras import build_cameras +from mmhuman3d.models.architectures.cliff_mesh_estimator import \ + CliffImageBodyModelEstimator # noqa: E501 +from mmhuman3d.models.body_models.builder import build_body_model +from mmhuman3d.utils.geometry import project_points + + +def test_cliff_image_body_mesh_estimator(): 
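+ # Build a full CLIFF image-based estimator (ResNet-50 backbone plus
+ # CliffHead) with its training/test body models and losses, then check
+ # that every configured sub-module has been instantiated.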
+ backbone = dict( + type='ResNet', + depth=50, + out_indices=[3], + norm_eval=False, + norm_cfg=dict(type='SyncBN', requires_grad=True), + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')) + head = dict( + type='CliffHead', + feat_dim=2048, + smpl_mean_params='data/body_models/smpl_mean_params.npz') + body_model_train = dict( + type='SMPL', + keypoint_src='smpl_54', + keypoint_dst='smpl_54', + model_path='data/body_models/smpl', + keypoint_approximate=True, + extra_joints_regressor='data/body_models/J_regressor_extra.npy') + body_model_test = dict( + type='SMPL', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/smpl', + joints_regressor='data/body_models/J_regressor_h36m.npy') + convention = 'smpl_54' + loss_keypoints3d = dict(type='SmoothL1Loss', loss_weight=100) + loss_keypoints2d = dict(type='SmoothL1Loss', loss_weight=10) + loss_vertex = dict(type='L1Loss', loss_weight=2) + loss_smpl_pose = dict(type='MSELoss', loss_weight=3) + loss_smpl_betas = dict(type='MSELoss', loss_weight=0.02) + loss_adv = dict( + type='GANLoss', + gan_type='lsgan', + real_label_val=1.0, + fake_label_val=0.0, + loss_weight=1) + model = CliffImageBodyModelEstimator( + backbone=backbone, + head=head, + body_model_train=body_model_train, + body_model_test=body_model_test, + convention=convention, + loss_keypoints3d=loss_keypoints3d, + loss_keypoints2d=loss_keypoints2d, + loss_vertex=loss_vertex, + loss_smpl_pose=loss_smpl_pose, + loss_smpl_betas=loss_smpl_betas, + loss_adv=loss_adv) + assert model.backbone is not None + assert model.head is not None + assert model.body_model_train is not None + assert model.body_model_test is not None + assert model.convention == 'smpl_54' + assert model.loss_keypoints3d is not None + assert model.loss_keypoints2d is not None + assert model.loss_vertex is not None + assert model.loss_smpl_pose is not None + assert model.loss_smpl_betas is not None + assert model.loss_adv is not None + + +def test_compute_keypoints3d_loss(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', + loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100)) + + pred_keypoints3d = torch.zeros((32, 54, 3)) + gt_keypoints3d = torch.zeros((32, 54, 4)) + loss_empty = model.compute_keypoints3d_loss(pred_keypoints3d, + gt_keypoints3d) + assert loss_empty == 0 + + pred_keypoints3d = torch.randn((32, 54, 3)) + gt_keypoints3d = torch.randn((32, 54, 4)) + gt_keypoints3d[:, :, 3] = torch.sigmoid(gt_keypoints3d[:, :, 3]) + loss = model.compute_keypoints3d_loss(pred_keypoints3d, gt_keypoints3d) + assert loss > 0 + + has_keypoints3d = torch.ones(32) + loss = model.compute_keypoints3d_loss( + pred_keypoints3d, gt_keypoints3d, has_keypoints3d=has_keypoints3d) + assert loss > 0 + has_keypoints3d = torch.zeros(32) + loss = model.compute_keypoints3d_loss( + pred_keypoints3d, gt_keypoints3d, has_keypoints3d=has_keypoints3d) + assert loss == 0 + + +def test_compute_keypoints2d_loss_cliff(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10)) + + pred_keypoints3d = torch.zeros((32, 54, 3)) + gt_keypoints2d = torch.zeros((32, 54, 3)) + pred_cam = torch.randn((32, 3)) + camera_center = torch.randn((32, 2)) + trans = torch.randn((32, 2, 3)) + focal_length = 5000 + loss_empty = model.compute_keypoints2d_loss_cliff(pred_keypoints3d, + pred_cam, gt_keypoints2d, + camera_center, + focal_length, trans) + assert loss_empty == 0 + + pred_keypoints3d = torch.randn((32, 54, 3)) + gt_keypoints2d = 
torch.randn((32, 54, 3)) + gt_keypoints2d[:, :, 2] = torch.sigmoid(gt_keypoints2d[:, :, 2]) + pred_cam = torch.randn((32, 3)) + loss = model.compute_keypoints2d_loss_cliff(pred_keypoints3d, pred_cam, + gt_keypoints2d, camera_center, + focal_length, trans) + assert loss > 0 + + has_keypoints2d = torch.ones((32)) + loss = model.compute_keypoints2d_loss_cliff( + pred_keypoints3d, + pred_cam, + gt_keypoints2d, + camera_center, + focal_length, + trans, + has_keypoints2d=has_keypoints2d) + assert loss > 0 + + has_keypoints2d = torch.zeros((32)) + loss = model.compute_keypoints2d_loss_cliff( + pred_keypoints3d, + pred_cam, + gt_keypoints2d, + camera_center, + focal_length, + trans, + has_keypoints2d=has_keypoints2d) + assert loss == 0 + + +def test_compute_keypoints2d_loss(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10)) + + pred_keypoints3d = torch.zeros((32, 54, 3)) + gt_keypoints2d = torch.zeros((32, 54, 3)) + pred_cam = torch.randn((32, 3)) + loss_empty = model.compute_keypoints2d_loss(pred_keypoints3d, pred_cam, + gt_keypoints2d) + assert loss_empty == 0 + + pred_keypoints3d = torch.randn((32, 54, 3)) + gt_keypoints2d = torch.randn((32, 54, 3)) + gt_keypoints2d[:, :, 2] = torch.sigmoid(gt_keypoints2d[:, :, 2]) + pred_cam = torch.randn((32, 3)) + loss = model.compute_keypoints2d_loss(pred_keypoints3d, pred_cam, + gt_keypoints2d) + assert loss > 0 + + has_keypoints2d = torch.ones((32)) + loss = model.compute_keypoints2d_loss( + pred_keypoints3d, + pred_cam, + gt_keypoints2d, + has_keypoints2d=has_keypoints2d) + assert loss > 0 + + has_keypoints2d = torch.zeros((32)) + loss = model.compute_keypoints2d_loss( + pred_keypoints3d, + pred_cam, + gt_keypoints2d, + has_keypoints2d=has_keypoints2d) + assert loss == 0 + + +def test_compute_vertex_loss(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', loss_vertex=dict(type='L1Loss', loss_weight=2)) + + pred_vertices = torch.randn((32, 4096, 3)) + gt_vertices = torch.randn((32, 4096, 3)) + has_smpl = torch.zeros((32)) + loss_empty = model.compute_vertex_loss(pred_vertices, gt_vertices, + has_smpl) + assert loss_empty == 0 + + pred_vertices = torch.randn((32, 4096, 3)) + gt_vertices = torch.randn((32, 4096, 3)) + has_smpl = torch.ones((32)) + loss = model.compute_vertex_loss(pred_vertices, gt_vertices, has_smpl) + assert loss > 0 + + +def test_compute_smpl_pose_loss(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', + loss_smpl_pose=dict(type='MSELoss', loss_weight=3)) + + pred_rotmat = torch.randn((32, 24, 3, 3)) + gt_pose = torch.randn((32, 24, 3)) + has_smpl = torch.zeros((32)) + loss_empty = model.compute_smpl_pose_loss(pred_rotmat, gt_pose, has_smpl) + assert loss_empty == 0 + + pred_rotmat = torch.randn((32, 24, 3, 3)) + gt_pose = torch.randn((32, 24, 3)) + has_smpl = torch.ones((32)) + loss = model.compute_smpl_pose_loss(pred_rotmat, gt_pose, has_smpl) + assert loss > 0 + + +def test_compute_part_segm_loss(): + N = 1 + random_body_pose = torch.rand((N, 69)) + body_model_train = dict( + type='SMPL', + keypoint_src='smpl_54', + keypoint_dst='smpl_49', + model_path='data/body_models/smpl', + extra_joints_regressor='data/body_models/J_regressor_extra.npy') + body_model = build_body_model(body_model_train) + + body_model_output = body_model(body_pose=random_body_pose, ) + gt_model_joins = body_model_output['joints'].detach() + cam = torch.ones(N, 3) + gt_keypoints2d = project_points( + gt_model_joins, cam, focal_length=5000, img_res=224) + 
loss_segm_mask = dict(type='CrossEntropyLoss', loss_weight=60) + + gt_keypoints2d = torch.cat([gt_keypoints2d, torch.ones(N, 49, 1)], dim=-1) + model = CliffImageBodyModelEstimator( + body_model_train=body_model_train, + loss_segm_mask=loss_segm_mask, + ) + gt_vertices = torch.randn(N, 6890, 3) + pred_heatmap = torch.zeros(N, 25, 224, 224) + pred_heatmap[:, 0, :, :] = 1 + has_smpl = torch.ones((N)) + + loss = model.compute_part_segmentation_loss( + pred_heatmap, + gt_vertices, + has_smpl=has_smpl, + gt_keypoints2d=gt_keypoints2d, + gt_model_joints=gt_model_joins) + assert loss > 0 + + +def test_compute_smpl_betas_loss(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', + loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02)) + + pred_betas = torch.randn((32, 10)) + gt_betas = torch.randn((32, 10)) + has_smpl = torch.zeros((32)) + loss_empty = model.compute_smpl_betas_loss(pred_betas, gt_betas, has_smpl) + assert loss_empty == 0 + + pred_betas = torch.randn((32, 10)) + gt_betas = torch.randn((32, 10)) + has_smpl = torch.ones((32)) + loss = model.compute_smpl_betas_loss(pred_betas, gt_betas, has_smpl) + assert loss > 0 + + +def test_compute_camera_loss(): + model = CliffImageBodyModelEstimator( + convention='smpl_54', + loss_camera=dict(type='CameraPriorLoss', loss_weight=60), + ) + + pred_cameras = torch.randn((32, 3)) + loss = model.compute_camera_loss(pred_cameras) + assert loss > 0 + + +def test_compute_losses(): + N = 32 + predictions = {} + predictions['pred_shape'] = torch.randn(N, 10) + predictions['pred_pose'] = torch.randn(N, 24, 3, 3) + predictions['pred_cam'] = torch.randn(N, 3) + + targets = {} + targets['keypoints3d'] = torch.randn(N, 45, 4) + targets['keypoints2d'] = torch.randn(N, 45, 3) + targets['has_smpl'] = torch.ones(N) + targets['smpl_body_pose'] = torch.randn(N, 23, 3) + targets['smpl_global_orient'] = torch.randn(N, 3) + targets['smpl_betas'] = torch.randn(N, 10) + targets['img_h'] = torch.ones(N, 1) * 256 + targets['img_w'] = torch.ones(N, 1) * 192 + targets['center'] = torch.randn(N, 2) + targets['scale'] = torch.randn(N, 1) + targets['focal_length'] = torch.randn(N, 1) + targets['crop_trans'] = torch.randn(N, 2, 3) + + model = CliffImageBodyModelEstimator(convention='smpl_54') + loss = model.compute_losses(predictions, targets) + assert loss == {} + + model = CliffImageBodyModelEstimator( + convention='smpl_45', + body_model_train=dict( + type='SMPL', + keypoint_src='smpl_45', + keypoint_dst='smpl_45', + model_path='data/body_models/smpl'), + loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100), + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10), + loss_vertex=dict(type='L1Loss', loss_weight=2), + loss_smpl_pose=dict(type='MSELoss', loss_weight=3), + loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02), + loss_camera=dict(type='CameraPriorLoss', loss_weight=60)) + + loss = model.compute_losses(predictions, targets) + assert 'keypoints3d_loss' in loss + assert 'keypoints2d_loss' in loss + assert 'vertex_loss' in loss + assert 'smpl_pose_loss' in loss + assert 'smpl_betas_loss' in loss + assert 'camera_loss' in loss + + +def test_run_registration(): + batch_size = 2 + body_model = dict( + type='SMPL', + keypoint_src='smpl_54', + keypoint_dst='smpl_49', + keypoint_approximate=True, + model_path='data/body_models/smpl', + extra_joints_regressor='data/body_models/J_regressor_extra.npy', + batch_size=batch_size) + + camera = build_cameras( + dict( + type='PerspectiveCameras', + convention='opencv', + in_ndc=False, + focal_length=5000, + 
image_size=(224, 224), + principal_point=(112, 112))) + + registrant = dict( + type='SMPLify', + body_model=body_model, + num_epochs=1, + stages=[ + dict( + num_iter=1, + fit_global_orient=True, + fit_transl=True, + fit_body_pose=False, + fit_betas=False) + ], + optimizer=dict(type='Adam', lr=1e-2, betas=(0.9, 0.999)), + keypoints2d_loss=dict( + type='KeypointMSELoss', + loss_weight=1.0, + reduction='sum', + sigma=100), + device=torch.device('cpu'), + camera=camera) + + registration = dict(mode='in_the_loop', registrant=registrant) + + model = CliffImageBodyModelEstimator( + body_model_train=body_model, registration=registration) + assert model.registrant is not None + assert model.fits_dict is not None + + transl = torch.Tensor([0, 0, 1]).view(1, 3).expand(batch_size, -1) + + predictions = dict( + pred_pose=torch.zeros((batch_size, 24, 3, 3)), + pred_shape=torch.zeros((batch_size, 10)), + pred_cam=transl, + ) + + # generate 2D keypoints + smpl = build_body_model(body_model) + keypoints3d = smpl(transl=transl)['joints'].detach() + keypoints2d_xyd = camera.transform_points_screen(keypoints3d) + keypoints2d = keypoints2d_xyd[..., :2] + keypoints2d_conf = torch.ones(*keypoints2d.shape[:2], 1) + keypoints2d = torch.cat([keypoints2d, keypoints2d_conf], dim=-1) + + targets = dict( + img_metas=[dict(dataset_name='coco'), + dict(dataset_name='h36m')], + sample_idx=torch.zeros((batch_size, 1), dtype=torch.int), + is_flipped=torch.tensor([0, 1], dtype=torch.int), + rotation=torch.tensor([0.0, 0.1]), + smpl_betas=torch.zeros((batch_size, 10)), + smpl_global_orient=torch.zeros((batch_size, 3)), + smpl_body_pose=torch.zeros((batch_size, 69)), + keypoints2d=keypoints2d, + has_smpl=torch.tensor([0, 1], dtype=torch.int)) + + model.run_registration(predictions=predictions, targets=targets) + assert 'valid_fit' in targets + assert 'opt_vertices' in targets + assert 'opt_cam_t' in targets + assert 'opt_joints' in targets + assert 'opt_pose' in targets + assert 'opt_betas' in targets diff --git a/tests/test_models/test_heads/test_cliff_head.py b/tests/test_models/test_heads/test_cliff_head.py new file mode 100644 index 00000000..691add98 --- /dev/null +++ b/tests/test_models/test_heads/test_cliff_head.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest +import torch + +from mmhuman3d.models.heads.builder import CliffHead + + +def test_cliff_head(): + # initialize models + model = CliffHead( + feat_dim=2048, + smpl_mean_params='data/body_models/smpl_mean_params.npz') + + # image feature from backbone + batch_size = 32 + bbox_info = [-0.5, 0.2, 1.5] + bbox_info = torch.FloatTensor([bbox_info] * batch_size) + x0_shape = (batch_size, 2048, 7, 7) + x0 = _demo_head_inputs(x0_shape) + x0 = torch.tensor(x0).float() + y0 = model(x0, bbox_info) + assert y0['pred_pose'].shape == (batch_size, 24, 3, 3) + assert y0['pred_shape'].shape == (batch_size, 10) + assert y0['pred_cam'].shape == (batch_size, 3) + + # image feature from multi-layer backbone + x1_1_shape = (batch_size, 1024, 14, 14) + x1_2_shape = (batch_size, 2048, 7, 7) + x1 = [_demo_head_inputs(x1_1_shape), _demo_head_inputs(x1_2_shape)] + y1 = model(x1, bbox_info) + assert y1['pred_pose'].shape == (batch_size, 24, 3, 3) + assert y1['pred_shape'].shape == (batch_size, 10) + assert y1['pred_cam'].shape == (batch_size, 3) + + # test temporal feature + T = 16 + x_temp_shape = (batch_size, T, 1024) + x_temp = _demo_head_inputs(x_temp_shape) + with pytest.raises(NotImplementedError): + model(x_temp, bbox_info) + + # test other cases + model_wo_smpl_mean_params = 
CliffHead(feat_dim=2048) + assert model_wo_smpl_mean_params.init_pose.shape == (1, 144) + assert model_wo_smpl_mean_params.init_shape.shape == (1, 10) + assert model_wo_smpl_mean_params.init_cam.shape == (1, 3) + + +def _demo_head_inputs(input_shape=(1, 3, 64, 64)): + """Create a superset of inputs needed to run models. + + Args: + input_shape (tuple): input batch dimensions. + Default: (1, 3, 64, 64). + """ + features = np.random.random(input_shape) + features = torch.FloatTensor(features) + + return features diff --git a/tools/convert_datasets.py b/tools/convert_datasets.py index 83d50ffd..bd06fbcd 100644 --- a/tools/convert_datasets.py +++ b/tools/convert_datasets.py @@ -57,7 +57,7 @@ gta_human=dict(type='GTAHumanConverter', prefix='gta_human'), humman=dict( type='HuMManConverter', modes=['train', 'test'], prefix='humman'), -) + cliff=dict(type='CliffConverter', modes=['coco', 'mpii'])) def parse_args(): From 03b4666ad7071cd6d1d49f652faffba1ef679be5 Mon Sep 17 00:00:00 2001 From: Zhongang Cai <62529255+caizhongang@users.noreply.github.com> Date: Wed, 5 Apr 2023 15:12:13 +0800 Subject: [PATCH 2/7] Bump version to v0.11.0 (#333) * Bump version to v0.11.0 * Update version to v0.11.0 --- README.md | 11 +++++++---- README_CN.md | 12 ++++++++---- mmhuman3d/version.py | 3 +-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index d7a0596c..23eb0aea 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,10 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe- A suite of differentiale visualization tools for human parametric model rendering (including part segmentation, depth map and point clouds) and conventional 2D/3D keypoints are available. ## News +- 2023-04-05: MMHuman3D [v0.11.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.11.0) is released. Major updates include: + - Add [ExPose](configs/expose) inference + - Add [PyMAF-X](configs/pymafx) inference + - Support [CLIFF](configs/cliff) - 2022-10-12: MMHuman3D [v0.10.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.10.0) is released. Major updates include: - Add webcam demo and real-time renderer - Update dataloader to speed up training @@ -53,10 +57,6 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe- - Support new body model [STAR](https://star.is.tue.mpg.de/) - Release of [GTA-Human](https://caizhongang.github.io/projects/GTA-Human/) dataset with SPIN-FT (51.98 mm) and PARE-FT (46.84 mm) baselines! (Official) - Refactor registration and improve performance of SPIN to 57.54 mm -- 2022-05-31: MMHuman3D [v0.8.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.8.0) is released. 
Major updates include: - - Support SmoothNet (added by paper authors) - - Fix circular import and up to 2.5x speed up in module initialization - - Add documentations in Chinese ## Benchmark and Model Zoo @@ -92,6 +92,9 @@ Supported methods: - [x] [ExPose](https://expose.is.tue.mpg.de) (ECCV'2020) - [x] [BalancedMSE](https://sites.google.com/view/balanced-mse/home) (CVPR'2022) - [x] [PyMAF-X](https://www.liuyebin.com/pymaf-x/) (arXiv'2022) +- [x] [ExPose](configs/expose) (ECCV'2020) +- [x] [PyMAF-X](configs/pymafx) (arXiv'2022) +- [x] [CLIFF](configs/cliff) (ECCV'2022) diff --git a/README_CN.md b/README_CN.md index a04dfc7f..678afce8 100644 --- a/README_CN.md +++ b/README_CN.md @@ -44,6 +44,10 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe- 一整套可微的可视化工具支持人体参数化模型的渲染(包括部分分割,深度图以及点云)和传统 2D/3D 关键点的可视化。 ## 最新进展 +- 2023-04-05: MMHuman3D [v0.11.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.11.0) 已经发布. 主要更新包括: + - 增加 [ExPose](configs/expose) 推理 + - 增加 [PyMAF-X](configs/pymafx) 推理 + - 支持 [CLIFF](configs/cliff) - 2022-10-12: MMHuman3D [v0.10.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.10.0) 已经发布. 主要更新包括: - 支持调用本地摄像头实时渲染 - 更新数据载入脚本,进而实现训练加速 @@ -53,10 +57,6 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe- - 支持新的人体参数化模型 [STAR](https://star.is.tue.mpg.de/) - 开源 [GTA-Human](https://caizhongang.github.io/projects/GTA-Human/) 数据集,以及 SPIN-FT (51.98 mm) 和 PARE-FT (46.84 mm) 基线! (官方开源) - 重构配准管线并提升SPIN至 57.54 mm -- 2022-05-31: MMHuman3D [v0.8.0](https://github.com/open-mmlab/mmhuman3d/releases/tag/v0.8.0) 已经发布. 主要更新包括: - - 支持 SmoothNet(由论文作者添加) - - 修复循环引用问题,获得最多2.5倍速度提升 - - 增加中文版文档 ## 基准与模型库 @@ -91,6 +91,10 @@ https://user-images.githubusercontent.com/62529255/144362861-e794b404-c48f-4ebe- - [x] [SmoothNet](https://ailingzeng.site/smoothnet) (ECCV'2022) - [x] [ExPose](https://expose.is.tue.mpg.de) (ECCV'2020) - [x] [BalancedMSE](https://sites.google.com/view/balanced-mse/home) (CVPR'2022) +- [x] [ExPose](configs/expose) (ECCV'2020) +- [x] [PyMAF-X](configs/pymafx) (arXiv'2022) +- [x] [CLIFF](configs/cliff) (ECCV'2022) + diff --git a/mmhuman3d/version.py b/mmhuman3d/version.py index c64e135f..a62d3c7b 100644 --- a/mmhuman3d/version.py +++ b/mmhuman3d/version.py @@ -1,7 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. 
-# __version__ = '0.9.0' -__version__ = '0.10.0' +__version__ = '0.11.0' def parse_version_info(version_str): From 9edb2113299377e0b7086f00c434a69f1f4d05fb Mon Sep 17 00:00:00 2001 From: Yining Li Date: Wed, 26 Apr 2023 00:12:23 +0800 Subject: [PATCH 3/7] [Fix] update mmcv maximum version to 1.8.0 (#323) * update mmcv maximum version to 0.8.0 * change to mmcv 1.6.0 --------- Co-authored-by: ttxskk --- mmhuman3d/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmhuman3d/__init__.py b/mmhuman3d/__init__.py index af1b11e3..ddc69792 100644 --- a/mmhuman3d/__init__.py +++ b/mmhuman3d/__init__.py @@ -16,7 +16,7 @@ def digit_version(version_str): mmcv_minimum_version = '1.3.17' -mmcv_maximum_version = '1.6.1' +mmcv_maximum_version = '1.8.0' mmcv_version = digit_version(mmcv.__version__) From 2774577a436ee4377a97fcf69529195a0ce9bd27 Mon Sep 17 00:00:00 2001 From: Zhongang Cai <62529255+caizhongang@users.noreply.github.com> Date: Tue, 20 Jun 2023 16:03:38 +0800 Subject: [PATCH 4/7] Add helper script for GTA-Human visualization (#357) --- configs/gta_human/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/configs/gta_human/README.md b/configs/gta_human/README.md index 24166909..0a04d172 100644 --- a/configs/gta_human/README.md +++ b/configs/gta_human/README.md @@ -64,6 +64,11 @@ Hence, `gta_human_4x.npz` is used as the training, it may be obtained in two way Please refer to [getting_started.md](../../docs/getting_started.md) for training and evaluation on MMHuman3D. +## Visualization + +We prepared a script with a sample sequence that should help you with the SMPL overlay. +You may download it from [here](https://drive.google.com/file/d/11osPM67KiQN6NbdJo3plcgNoPNxfJ_j7/view?usp=share_link). + ## Notes For different base models, you can find detailed data preparation steps in each subfolder. From 46dc586303daa1476f5982d20979339fd65740e8 Mon Sep 17 00:00:00 2001 From: ttxskk Date: Mon, 3 Jul 2023 15:15:58 +0800 Subject: [PATCH 5/7] [Upgrade] Specify the numpy version less than 1.24. (#361) * downgrade numpy version to 1.23.1 * specify numpy version to <=1.23.1 * specify numpy version as <1.24 * fix lint --- configs/gta_human/README.md | 2 +- requirements/runtime.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/gta_human/README.md b/configs/gta_human/README.md index 0a04d172..f2b5e733 100644 --- a/configs/gta_human/README.md +++ b/configs/gta_human/README.md @@ -66,7 +66,7 @@ Please refer to [getting_started.md](../../docs/getting_started.md) for training ## Visualization -We prepared a script with a sample sequence that should help you with the SMPL overlay. +We prepared a script with a sample sequence that should help you with the SMPL overlay. You may download it from [here](https://drive.google.com/file/d/11osPM67KiQN6NbdJo3plcgNoPNxfJ_j7/view?usp=share_link). 
## Notes diff --git a/requirements/runtime.txt b/requirements/runtime.txt index cb4edc25..813f02a9 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -6,7 +6,7 @@ easydev einops h5py matplotlib -numpy +numpy<1.24 opencv-python pandas<2.0.0 pickle5 From 0e1f101b87e8a793b3656bca34f4e079a7e6f6cd Mon Sep 17 00:00:00 2001 From: WEI CHEN <77597327+Wei-Chen-hub@users.noreply.github.com> Date: Mon, 3 Jul 2023 19:48:50 +0800 Subject: [PATCH 6/7] Update description of HumanData (English and Chinese) (#356) * Update human_data.md * Update human_data.md * Update human_data.md * fix * minor docs update English & Chinese * minor docs update English & Chinese --------- Co-authored-by: wei-chen-hub --- docs/human_data.md | 66 +++++++++++++++++++++++++++++++++++----- docs_zh-CN/human_data.md | 63 ++++++++++++++++++++++++++++++++------ 2 files changed, 112 insertions(+), 17 deletions(-) diff --git a/docs/human_data.md b/docs/human_data.md index 7b11804f..4956e93a 100644 --- a/docs/human_data.md +++ b/docs/human_data.md @@ -6,19 +6,69 @@ HumanData is a subclass of python built-in class dict, containing single-view, i ### Key/Value definition -#### The keys and values supported by HumanData are described as below. +#### Paths: +Image path is included, and optionally path of segmentation map and depth image can be included if provided by dataset. - image_path: (N, ), list of str, each element is a relative path from the root folder (exclusive) to the image. +- segmentation (optional): (N, ), list of str, each element is a relative path from the root folder (exclusive) to the segmentation map. +- depth_path (optional): (N, ), list of str, each element is a relative path from the root folder (exclusive) to the depth image. + +#### Keypoints: + +Following keys should be included in `HumanData` if applicable. For each dictionary key of keypoints,a corresponding dictionart key of mask should be included,stating which keypoint is valid. For example `keypoints3d_original` should correspond to `keypoints3d_original_mask`. + +In `HumanData`, keypoints are stored as `HUMAN_DATA` format, which includes 190 joints. We provide keypoints format (for both 2d and 3d keypoints) convention for many datasets, please see [keypoints_convention](../docs/keypoints_convention.md). + +- keypoints3d_smpl / keypoints3d_smplx: (N, 190, 4), numpy array, `smplx / smplx` 3d joints with confidence, joints from each datasets are mapped to HUMAN_DATA joints. +- keypoints3d_original: (N, 190, 4), numpy array, 3d joints with confidence which provided by the dataset originally, joints from each datasets are mapped to HUMAN_DATA joints. +- keypoints2d_smpl / keypoints2d_smplx: (N, 190, 3), numpy array, `smpl / smplx` 2d joints with confidence, joints from each datasets are mapped to HUMAN_DATA joints. +- keypoints2d_original: (N, 190, 3), numpy array, 2d joints with confidence which provided by the dataset originally, joints from each datasets are mapped to HUMAN_DATA joints. +- (mask sample) keypoints2d_smpl_mask: (190, ), numpy array, mask for which keypoint is valid in `keypoints2d_smpl`. 0 means that the joint in this position cannot be found in original dataset. + +#### Bounding Box: + +Bounding box of body (smpl), face and hand (smplx), which data type is `[x_min, y_min, width, height, confidence]`,and should not exceed the image boundary. - bbox_xywh: (N, 5), numpy array, bounding box with confidence, coordinates of bottom-left point x, y, width w and height h of bbox, score at last. 
+- face_bbox_xywh, lhand_bbox_xywh, rhand_bbox_xywh (optional): (N, 5), numpy array, should be included if `smplx` data is provided, and is derived from smplx2d keypoints. Have the same srtucture as above. + +#### Human Pose and Shape Parameters: + +Normally saved as smpl/smplx. +- smpl: (1, ), dict, keys are `['body_pose': numpy array, (N, 23, 3), 'global_orient': numpy array, (N, 3), 'betas': numpy array, (N, 10), 'transl': numpy array, (N, 3)]`. +- smplx: (1, ), dict, keys are `['body_pose': numpy array, (N, 21, 3),'global_orient': numpy array, (N, 3), 'betas': numpy array, (N, 10), 'transl': numpy array, (N, 3), 'left_hand_pose': numpy array, (N, 15, 3), 'right_hand_pose': numpy array, (N, 15, 3), 'expression': numpy array (N, 10), 'leye_pose': numpy array (N, 3), 'reye_pose': (N, 3), 'jaw_pose': numpy array (N, 3)]`. + + +#### Other keys + - config: (), str, the flag name of config for individual dataset. -- keypoints2d: (N, 190, 3), numpy array, 2d joints of smplx model with confidence, joints from each datasets are mapped to HUMAN_DATA joints. -- keypoints3d: (N, 190, 4), numpy array, 3d joints of smplx model with confidence. Same as above. -- smpl: (1, ), dict, keys are ['body_pose': numpy array, (N, 23, 3), 'global_orient': numpy array, (N, 3), 'betas': numpy array, (N, 10), 'transl': numpy array, (N, 3)]. -- smplx: (1, ), dict, keys are ['body_pose': numpy array, (N, 21, 3),'global_orient': numpy array, (N, 3), 'betas': numpy array, (N, 10), 'transl': numpy array, (N, 3), 'left_hand_pose': numpy array, (N, 15, 3), 'right_hand_pose': numpy array, (N, 15, 3), 'expression': numpy array (N, 10), 'leye_pose': numpy array (N, 3), 'reye_pose': (N, 3), 'jaw_pose': numpy array (N, 3)]. - meta: (1, ), dict, its keys are meta data from dataset like 'gender'. -- keypoints2d_mask: (190, ), numpy array, mask for which keypoint is valid in keypoints2d. 0 means that the joint in this position cannot be found in original dataset. -- keypoints3d_mask: (190, ), numpy array, mask for which keypoint is valid in keypoints3d. 0 means that the joint in this position cannot be found in original dataset. -- misc: (1, ), dict, keys and values are defined by user. The space misc takes(sys.getsizeof(misc)) shall be no more than 6MB. +- misc: (1, ), dict, keys and values are designed to describe the different settings for each dataset. Can also be defined by user. The space misc takes (sys.getsizeof(misc)) shall be no more than 6MB. + +#### Suggestion for WHAT to include in `HumanData['misc']`: + +Miscellaneous contains the info of different settings for each dataset, including camaera type, source of keypoints annotation, bounding box etc. Aims to faclitate different usage of data. +`HumanData['misc']` is a dictionary and its keys are described as following: +- kps3d_root_aligned: Bool, stating that if keypoints3d is root-aligned,root_alignment is not preferred for HumanData. If this key does not exist, root_aligenment is by default to be `False`. +- flat_hand_mean:Bool, applicable for smplx data,for most datasets `flat_hand_mean=False`. +- bbox_source:source of bounding box,`bbox_soruce='keypoints2d_smpl' or 'keypoints2d_smplx' or 'keypoints2d_original'`,describing which type of keypoints are used to derive the bounding box,OR `bbox_source='provide_by_dataset'` shows that bounding box if provided by dataset. 
(For example, from some detection module rather than keypoints) +- bbox_body_scale: applicable if bounding box is derived by keypoints,stating the zoom-in scale of bounding scale from smpl/smplx/2d_gt keypoints,we suggest `bbox_body_scale=1.2`. +- bbox_hand_scale, bbox_face_scale: applicable if bounding box is derived by smplx keypoints,stating the zoom-in scale of bounding scale from smplx/2d_gt keypoints,we suggest `bbox_hand_scale=1.0, bbox_face_scale=1.0` +- smpl_source / smplx_source: describing the source of smpl/smplx annotations,`'original', 'nerual_annot', 'eft', 'osx_annot', 'cliff_annot'`. +- cam_param_type: describing the type of camera parameters,`cam_param_type='prespective' or 'predicted_camera' or 'eft_camera'` +- principal_point, focal_length: (1, 2), numpy array,applicable if camera parameters are same across the whole dataset, which is the case for some synthetic datasets. +- image_shape: (1, 2), numpy array,applicable if image shape are same across the whole dataset. + +#### Suggestion for WHAT to include in `HumanData['meta']`: + +- gender: (N, ), list of str, each element represents the gender for an smpl/smplx instance. (key not required if dataset use gender-neutral model) +- height (width):(N, ), list of str, each element represents the height (width) of an image, `image_shape=(width, height): (N, 2)` is not suggested as width and height might need to be referenced in different orders. (keys should be in `HumanData['misc']` if image shape are same across the dataset) +- other keys,applicable if the key value is different,and have some impact on keypoints or smpl/smplx (2d and 3d),For example, `focal_length` and `principal_point`, focal_length = (N, 2), principal_point = (N, 2) + +#### Some other info of HumanData + +- All annotations are transformed from world space to opencv camera space, for space transformation we use: + + ```from mmhuman3d.models.body_models.utils import transform_to_camera_frame, batch_transform_to_camera_frame``` #### Key check in HumanData. diff --git a/docs_zh-CN/human_data.md b/docs_zh-CN/human_data.md index f8855669..a866f453 100644 --- a/docs_zh-CN/human_data.md +++ b/docs_zh-CN/human_data.md @@ -5,21 +5,66 @@ `HumanData`是Python内置字典的子类,主要用于存放包含人体的单视角图像的信息。它具有通用的基础结构,也兼容具有新特性的客制化数据。 原生的`HumanData`包含`numpy.ndarray`或其他的Python内置的数据结构,但不包含`torch.Tensor`的数据。可以使用`human_data.to()`将其转换为`torch.Tensor`(支持CPU和GPU)。 -### `Key/Value`的定义 +### `Key/Value`的定义:如下是对`HumanData`支持的`Key`和`Value`的描述. -#### 如下是对`HumanData`支持的`Key`和`Value`的描述. +#### 路径: +通常包含图片路径,如果数据集有提供额外的深度或者分割图,也可以记录下来。 - image_path: (N, ), 字符串组成的列表, 每一个元素是图像相对于根目录的路径。 +- segmantation_path (可选): (N, ), 字符串组成的列表, 每一个元素是图像分割图相对于根目录的路径。 +- depth_path (可选): (N, ), 字符串组成的列表, 每一个元素是图像深度图相对于根目录的路径。 + +#### 关键点: + +以下关键点keys如果适用,则应包含在HumanData中。任何一个关键点的key,应存在一个mask,表示其中哪些关键点有效。如`keypoints3d_original`应对应`keypoints3d_original_mask`。 +`HumanData` 中的关键点存储格式为`HUMAN_DATA`, 包含190个关键点。MMHuman3d中提供了很多常用关键点格式的转换(2d及3d均支持), 详见 [keypoints_convention](../docs_zh-CN/keypoints_convention.md). 
+- keypoints3d_smpl / keypoints3d_smplx: (N, 190, 4), numpy array, `smplx / smplx`模型的3d关节点与置信度, 每一个数据集的关节点映射到了`HUMAN_DATA`的关节点。 +- keypoints3d_original: (N, 190, 4), numpy array, 由数据集本身提供的3d关节点与置信度, 每一个数据集的关节点映射到了`HUMAN_DATA`的关节点。 +- keypoints2d_smpl / keypoints2d_smplx: (N, 190, 3), numpy array, `smpl / smplx`模型的2d关节点与置信度, 每一个数据集的关节点映射到了`HUMAN_DATA`的关节点。 +- keypoints2d_original: (N, 190, 3), numpy array, 由数据集本身提供的2d关节点与置信度, 每一个数据集的关节点映射到了`HUMAN_DATA`的关节点。 +- (mask示例) keypoints2d_smpl_mask: (190, ), numpy array, 表示`keypoints2d_smpl`中关键点是否有效的掩膜。 0表示该位置的关键点在原始数据集中无法找到。 + +#### 检测框: + +身体(smpl),手脸(smplx)的检测框,标注为`[x_min, y_min, width, height, confidence]`,且不应超出图片。 - bbox_xywh: (N, 5), numpy array, 边界框的置信度, 边界框左下角点的坐标x和y, 边界框的宽w和高h, 置信度得分放置在最后。 -- config: (), 字符串, 单个数据集的配置的标志。 -- keypoints2d: (N, 190, 3), numpy array, `smplx`模型的2d关节点与置信度, 每一个数据集的关节点映射到了`HUMAN_DATA`的关节点。 -- keypoints3d: (N, 190, 4), numpy array, `smplx`模型的3d关节点与置信度, 每一个数据集的关节点映射到了`HUMAN_DATA`的关节点。 +- face_bbox_xywh, lhand_bbox_xywh, rhand_bbox_xywh(可选): (N, 5), numpy array, 如果数据标注中含有`smplx`, 则应包括这三个key,由smplx2d关键点得出,格式同上。 + +#### 人体模型参数: + +通常以smpl/smplx格式存储。 - smpl: (1, ), 字典, `keys` 分别为 ['body_pose': numpy array, (N, 23, 3), 'global_orient': numpy array, (N, 3), 'betas': numpy array, (N, 10), 'transl': numpy array, (N, 3)]. - smplx: (1, ), 字典, `keys` 分别为 ['body_pose': numpy array, (N, 21, 3),'global_orient': numpy array, (N, 3), 'betas': numpy array, (N, 10), 'transl': numpy array, (N, 3), 'left_hand_pose': numpy array, (N, 15, 3), 'right_hand_pose': numpy array, (N, 15, 3), 'expression': numpy array (N, 10), 'leye_pose': numpy array (N, 3), 'reye_pose': (N, 3), 'jaw_pose': numpy array (N, 3)]. -- meta: (1, ), 字典, `keys` 为数据集中类似性别的元数据。 -- keypoints2d_mask: (190, ), numpy array, 表示`keypoints2d`中关键点是否有效的掩膜。 0表示该位置的关键点在原始数据集中无法找到。 -- keypoints3d_mask: (190, ), numpy array, 表示`keypoints3d`中关键点是否有效的掩膜。 0表示该位置的关键点在原始数据集中无法找到。 -- misc: (1, ), 字典, `keys`和`values`由用户定义。`misc`占用的空间(可以通过`sys.getsizeof(misc)`获取)不能超过6MB。 + +#### 其它keys + +- config: (), 字符串, 单个数据集的配置的标志。 +- meta: (1, ), 字典, `keys`为数据集中的各种元数据。 +- misc: (1, ), 字典, `keys`为数据集中各种独特设定,也可以由用户自定义。`misc`占用的空间(可以通过`sys.getsizeof(misc)`获取)不能超过6MB。 + +#### `HumanData['misc']`中建议(可能)包含的内容: +Miscellaneous部分中包含了每个数据集的独特设定,包括相机种类,关键点标注来源,检测框来源,是否包含smpl/smplx标注等等,用于便利数据读取。 +`HumanData['misc']`中包含一个dictionary,建议包含的key如下所示: +- kps3d_root_aligned: Bool 描述keypoints3d是否经过root align,建议不进行root_alignment,如果不包含这个key,则默认没有进行过root_aligenment +- flat_hand_mean:Bool 对于smplx标注的数据,应该存在此项,大多数数据集中`flat_hand_mean=False` +- bbox_source:描述检测框的来源,`bbox_soruce='keypoints2d_smpl' or 'keypoints2d_smplx' or 'keypoints2d_original'`,描述检测框是由哪种关键点得出的,或者`bbox_source='provide_by_dataset'`表示检测框由数据集直接给出(比如用其自带检测器生成而不是由关键点推导得出) +- bbox_body_scale: 如果检测框由关键点推导得出,则应包含此项,描述由smpl/smplx/2d_gt关键点推导出的身体检测框的放大比例,建议`bbox_body_scale=1.2` +- bbox_hand_scale, bbox_face_scale: 如果检测框由关键点推导得出,则应包含这两项,描述由smpl/smplx/2d_gt关键点推导出的身体检测框的放大比例,建议`bbox_hand_scale=1.0, bbox_face_scale=1.0` +- smpl_source / smplx_source: 描述smpl/smplx的来源,`'original', 'nerual_annot', 'eft', 'osx_annot', 'cliff_annot'`, 来描述smpl/smnplx是来源于数据集提供,或者其它标注来源 +- cam_param_type: 描述相机参数的种类,`cam_param_type='prespective' or 'predicted_camera' or 'eft_camera'` +- principal_point, focal_length: (1, 2), numpy array,如果数据集中相机参数恒定,则应包含这两项,通常适用于生成数据集。 +- image_shape: (1, 2), numpy array,如果数据集中图片大小恒定,则应包含此项。 + +#### `HumanData['meta']`中建议(可能)包含的内容: +- gender: (N, ), 字符串组成的列表, 每一个元素是smplx模型的性别(中性则不必标注) +- height(width):(N, ), 字符串组成的列表, 每一个元素是图片的高(或宽),这里不推荐使用`image_shape=(width, height): 
(N, 2)`,因为有时需要按反顺序读取图片格式。(数据集图片分辨率一致则应标注在`HumanData['misc']`中) +- 其它有标识性的key,若数据集中该key不一致,且会影响keypoints or smpl/smplx,则建议标注,如focal_length与principal_point, focal_length = (N, 2), principal_point = (N, 2) + +#### 关于HumanData的一些说明 + +- 所有数据标注均已从世界坐标转移到opencv相机空间,进行smpl/smplx的相机空间转换可以用 + +```from mmhuman3d.models.body_models.utils import transform_to_camera_frame, batch_transform_to_camera_frame``` #### 检查`HumanData`中的`key`. From 9431addec32f7fbeffa1786927a854c0ab79d9ea Mon Sep 17 00:00:00 2001 From: wendaizhou <57831849+wendaizhou@users.noreply.github.com> Date: Mon, 10 Jul 2023 10:32:20 +0800 Subject: [PATCH 7/7] Update STAR lbs & render (#355) * Update STAR lbs & render * Update star forward * Add helper script for GTA-Human visualization (#357) * [Upgrade] Specify the numpy version less than 1.24. (#361) * downgrade numpy version to 1.23.1 * specify numpy version to <=1.23.1 * specify numpy version as <1.24 * fix lint * Update description of HumanData (English and Chinese) (#356) * Update human_data.md * Update human_data.md * Update human_data.md * fix * minor docs update English & Chinese * minor docs update English & Chinese --------- Co-authored-by: wei-chen-hub * fix linting problem * Update STAR model --------- Co-authored-by: wendaizhou Co-authored-by: Zhongang Cai <62529255+caizhongang@users.noreply.github.com> Co-authored-by: ttxskk Co-authored-by: WEI CHEN <77597327+Wei-Chen-hub@users.noreply.github.com> Co-authored-by: wei-chen-hub --- configs/hmr/star_pw3d.py | 152 +++++++ demo/estimate_star.py | 421 ++++++++++++++++++ .../core/conventions/segmentation/__init__.py | 6 +- .../core/renderer/torch3d_renderer/meshes.py | 4 +- .../core/visualization/visualize_smpl.py | 20 +- mmhuman3d/models/body_models/star.py | 128 ++++-- 6 files changed, 681 insertions(+), 50 deletions(-) create mode 100755 configs/hmr/star_pw3d.py create mode 100755 demo/estimate_star.py mode change 100644 => 100755 mmhuman3d/core/conventions/segmentation/__init__.py mode change 100644 => 100755 mmhuman3d/core/renderer/torch3d_renderer/meshes.py mode change 100644 => 100755 mmhuman3d/core/visualization/visualize_smpl.py mode change 100644 => 100755 mmhuman3d/models/body_models/star.py diff --git a/configs/hmr/star_pw3d.py b/configs/hmr/star_pw3d.py new file mode 100755 index 00000000..575b8a57 --- /dev/null +++ b/configs/hmr/star_pw3d.py @@ -0,0 +1,152 @@ +_base_ = ['../_base_/default_runtime.py'] +use_adversarial_train = True + +# evaluate +evaluation = dict(metric=['pa-mpjpe', 'mpjpe']) +# optimizer +optimizer = dict( + backbone=dict(type='Adam', lr=2.5e-4), head=dict(type='Adam', lr=2.5e-4)) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='Fixed', by_epoch=False) +runner = dict(type='EpochBasedRunner', max_epochs=100) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +img_res = 224 + +# model settings +model = dict( + type='ImageBodyModelEstimator', + backbone=dict( + type='ResNet', + depth=50, + out_indices=[3], + norm_eval=False, + norm_cfg=dict(type='BN', requires_grad=True), + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + head=dict(type='HMRHead', feat_dim=2048), + body_model_train=dict( + type='STAR', + keypoint_src='smpl', + keypoint_dst='star', + model_path='data/body_models/star', + keypoint_approximate=True), + body_model_test=dict( + type='STAR', + keypoint_src='smpl', + keypoint_dst='star', + model_path='data/body_models/star'), + convention='star', + 
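+ # Loss terms and weights below follow the HMR-style mesh estimator
+ # setup, here applied to the STAR body model outputs.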
loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100), + loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10), + loss_vertex=dict(type='L1Loss', loss_weight=2), + loss_smpl_pose=dict(type='MSELoss', loss_weight=3), + loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02)) +# dataset settings +dataset_type = 'HumanImageDataset' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +data_keys = [ + 'has_smpl', 'smpl_body_pose', 'smpl_global_orient', 'smpl_betas', + 'smpl_transl', 'keypoints2d', 'keypoints3d', 'sample_idx', 'has_smpl', + 'has_keypoints2d', 'has_keypoints3d' +] +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='RandomChannelNoise', noise_factor=0.4), + dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='star'), + dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25), + dict(type='MeshAffine', img_res=img_res), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='RandomErasing'), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0), + dict(type='MeshAffine', img_res=img_res), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=data_keys), + dict( + type='Collect', + keys=['img', *data_keys], + meta_keys=['image_path', 'center', 'scale', 'rotation']) +] + +inference_pipeline = [ + dict(type='MeshAffine', img_res=img_res), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict( + type='Collect', + keys=['img', 'sample_idx'], + meta_keys=['image_path', 'center', 'scale', 'rotation', 'origin_img']) +] + +data = dict( + samples_per_gpu=32, + workers_per_gpu=1, + train=dict( + type='MixedDataset', + configs=[ + dict( + type=dataset_type, + dataset_name='pw3d', + data_prefix='data', + pipeline=train_pipeline, + convention='star', + ann_file='star.npz'), + dict( + type=dataset_type, + dataset_name='mpi_inf_3dhp', + data_prefix='data', + pipeline=train_pipeline, + convention='star', + ann_file='mpi_inf_3dhp_1_4.npz'), + dict( + type=dataset_type, + dataset_name='h36m', + data_prefix='data', + pipeline=train_pipeline, + convention='star', + ann_file='h36m_train_new.npz'), + ], + partition=[0.4, 0.3, 0.3], + ), + val=dict( + type=dataset_type, + dataset_name='h36m', + body_model=dict( + type='STAR', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/star'), + data_prefix='data', + pipeline=test_pipeline, + convention='star', + ann_file='h36m_test.npz'), + test=dict( + type=dataset_type, + dataset_name='h36m', + body_model=dict( + type='STAR', + keypoint_src='h36m', + keypoint_dst='h36m', + model_path='data/body_models/star'), + data_prefix='data', + pipeline=test_pipeline, + convention='star', + ann_file='h36m_test.npz'), +) diff --git a/demo/estimate_star.py b/demo/estimate_star.py new file mode 100755 index 00000000..fbe05182 --- /dev/null +++ b/demo/estimate_star.py @@ -0,0 +1,421 @@ +import os +import os.path as osp +import shutil +import warnings +from argparse import ArgumentParser +from pathlib import Path + +import mmcv +import numpy as np +import torch + +from mmhuman3d.apis import ( + feature_extract, + inference_image_based_model, + inference_video_based_model, + init_model, +) +from 
mmhuman3d.core.visualization.visualize_smpl import visualize_smpl_hmr +from mmhuman3d.data.data_structures.human_data import HumanData +from mmhuman3d.utils.demo_utils import ( + extract_feature_sequence, + get_speed_up_interval, + prepare_frames, + process_mmdet_results, + process_mmtracking_results, + smooth_process, + speed_up_interpolate, + speed_up_process, +) +from mmhuman3d.utils.ffmpeg_utils import array_to_images + +try: + from mmdet.apis import inference_detector, init_detector + + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + +try: + from mmtrack.apis import inference_mot + from mmtrack.apis import init_model as init_tracking_model + + has_mmtrack = True +except (ImportError, ModuleNotFoundError): + has_mmtrack = False + + +def get_tracking_result(args, frames_iter, mesh_model, extractor): + tracking_model = init_tracking_model( + args.tracking_config, None, device=args.device.lower()) + + max_track_id = 0 + max_instance = 0 + result_list = [] + frame_id_list = [] + + for i, frame in enumerate(mmcv.track_iter_progress(frames_iter)): + mmtracking_results = inference_mot(tracking_model, frame, frame_id=i) + + # keep the person class bounding boxes. + result, max_track_id, instance_num = \ + process_mmtracking_results( + mmtracking_results, + max_track_id=max_track_id, + bbox_thr=args.bbox_thr) + + # extract features from the input video or image sequences + if mesh_model.cfg.model.type == 'VideoBodyModelEstimator' \ + and extractor is not None: + result = feature_extract( + extractor, frame, result, args.bbox_thr, format='xyxy') + + # drop the frame with no detected results + if result == []: + continue + + # update max_instance + if instance_num > max_instance: + max_instance = instance_num + + # vis bboxes + if args.draw_bbox: + bboxes = [res['bbox'] for res in result] + bboxes = np.vstack(bboxes) + mmcv.imshow_bboxes( + frame, bboxes, top_k=-1, thickness=2, show=False) + + result_list.append(result) + frame_id_list.append(i) + + return max_track_id, max_instance, frame_id_list, result_list + + +def nonlinearWeight(x, thre_low, a1, a2, a3, a4): + ratio = 0 + if abs(x).mean() >= thre_low: + ratio = (a1 / a3 + a2 * np.power(x, 2)) / np.sqrt(a4 / (a3 * a3) + + np.power(x, 4)) + # if ratio > 1: + # ratio = 1 + return ratio + + +def get_detection_result(args, frames_iter, mesh_model, extractor): + person_det_model = init_detector( + args.det_config, args.det_checkpoint, device=args.device.lower()) + frame_id_list = [] + result_list = [] + pre_bbox = None + for i, frame in enumerate(mmcv.track_iter_progress(frames_iter)): + mmdet_results = inference_detector(person_det_model, frame) + # keep the person class bounding boxes. 
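+        # Each detection entry carries a 'bbox' of [x1, y1, x2, y2, score] (xyxy format).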
+ results = process_mmdet_results( + mmdet_results, cat_id=args.det_cat_id, bbox_thr=args.bbox_thr) + + # smooth + if pre_bbox is not None: + cur_bbox = results[0]['bbox'] + dist_tl = np.array([(cur_bbox[0] - pre_bbox[0])**2, + (cur_bbox[1] - pre_bbox[1])**2]) + delta_tl = np.array(dist_tl / + (np.array(pre_bbox[:2]) + 1e-7)).sum() + ratio_tl = nonlinearWeight(delta_tl, 0, 0.2, 0.8, 120, 1) + dist_br = np.array([(cur_bbox[2] - pre_bbox[2])**2, + (cur_bbox[3] - pre_bbox[3])**2]) + delta_br = np.array(dist_br / + (np.array(pre_bbox[2:4]) + 1e-7)).sum() + ratio_br = nonlinearWeight(delta_br, 0, 0.2, 0.8, 120, 1) + results[0]['bbox'] = np.array([ + ratio_tl * cur_bbox[0] + (1 - ratio_tl) * pre_bbox[0], + ratio_tl * cur_bbox[1] + (1 - ratio_tl) * pre_bbox[1], + ratio_br * cur_bbox[2] + (1 - ratio_br) * pre_bbox[2], + ratio_br * cur_bbox[3] + (1 - ratio_br) * pre_bbox[3], + cur_bbox[4] + ]) + pre_bbox = results[0]['bbox'] + + # extract features from the input video or image sequences + if mesh_model.cfg.model.type == 'VideoBodyModelEstimator' \ + and extractor is not None: + results = feature_extract( + extractor, frame, results, args.bbox_thr, format='xyxy') + # drop the frame with no detected results + if results == []: + continue + # vis bboxes + if args.draw_bbox: + bboxes = [res['bbox'] for res in results] + bboxes = np.vstack(bboxes) + mmcv.imshow_bboxes( + frame, bboxes, top_k=-1, thickness=2, show=False) + frame_id_list.append(i) + result_list.append(results) + + frame_num = len(result_list) + x = np.array([i[0]['bbox'] for i in result_list]) + x = smooth_process( + x[:, np.newaxis], smooth_type='savgol').reshape(frame_num, 5) + for idx, result in enumerate(result_list): + result[0]['bbox'] = x[idx, :] + + return frame_id_list, result_list + + +def single_person_with_mmdet(args, frames_iter): + """Estimate smpl parameters from single-person + images with mmdetection + Args: + args (object): object of argparse.Namespace. + frames_iter (np.ndarray,): prepared frames + + """ + + mesh_model, extractor = init_model( + args.mesh_reg_config, + args.mesh_reg_checkpoint, + device=args.device.lower()) + + pred_cams, verts, smpl_poses, smpl_betas, bboxes_xyxy = \ + [], [], [], [], [] + + frame_id_list, result_list = \ + get_detection_result(args, frames_iter, mesh_model, extractor) + + frame_num = len(frame_id_list) + # speed up + if args.speed_up_type: + speed_up_interval = get_speed_up_interval(args.speed_up_type) + speed_up_frames = (frame_num - + 1) // speed_up_interval * speed_up_interval + + for i, result in enumerate(mmcv.track_iter_progress(result_list)): + frame_id = frame_id_list[i] + if mesh_model.cfg.model.type == 'VideoBodyModelEstimator': + if args.speed_up_type: + warnings.warn( + 'Video based models do not support speed up. 
' + 'By default we will inference with original speed.', + UserWarning) + feature_results_seq = extract_feature_sequence( + result_list, frame_idx=i, causal=True, seq_len=16, step=1) + mesh_results = inference_video_based_model( + mesh_model, + extracted_results=feature_results_seq, + with_track_id=False) + elif mesh_model.cfg.model.type == 'ImageBodyModelEstimator': + if args.speed_up_type and i % speed_up_interval != 0 \ + and i <= speed_up_frames: + mesh_results = [{ + 'bbox': np.zeros((5)), + 'camera': np.zeros((3)), + 'smpl_pose': np.zeros((24, 3, 3)), + 'smpl_beta': np.zeros((10)), + 'vertices': np.zeros((6890, 3)), + 'keypoints_3d': np.zeros((17, 3)), + }] + else: + mesh_results = inference_image_based_model( + mesh_model, + frames_iter[frame_id], + result, + bbox_thr=args.bbox_thr, + format='xyxy') + else: + raise Exception( + f'{mesh_model.cfg.model.type} is not supported yet') + + smpl_betas.append(mesh_results[0]['smpl_beta']) + smpl_pose = mesh_results[0]['smpl_pose'] + smpl_poses.append(smpl_pose) + pred_cams.append(mesh_results[0]['camera']) + verts.append(mesh_results[0]['vertices']) + bboxes_xyxy.append(mesh_results[0]['bbox']) + + smpl_poses = np.array(smpl_poses) + smpl_betas = np.array(smpl_betas) + pred_cams = np.array(pred_cams) + verts = np.array(verts) + bboxes_xyxy = np.array(bboxes_xyxy) + + # release GPU memory + del mesh_model + del extractor + torch.cuda.empty_cache() + + # speed up + if args.speed_up_type: + smpl_poses = speed_up_process( + torch.tensor(smpl_poses).to(args.device.lower()), + args.speed_up_type) + + selected_frames = np.arange(0, len(frames_iter), speed_up_interval) + smpl_poses, smpl_betas, pred_cams, bboxes_xyxy = speed_up_interpolate( + selected_frames, speed_up_frames, smpl_poses, smpl_betas, + pred_cams, bboxes_xyxy) + + # smooth + if args.smooth_type is not None: + smpl_poses = smooth_process( + smpl_poses.reshape(frame_num, 24, 9), + smooth_type=args.smooth_type).reshape(frame_num, 24, 3, 3) + verts = smooth_process(verts, smooth_type=args.smooth_type) + pred_cams = smooth_process( + pred_cams[:, np.newaxis], + smooth_type=args.smooth_type).reshape(frame_num, 3) + + if args.output is not None: + body_pose_, global_orient_, smpl_betas_, verts_, pred_cams_, \ + bboxes_xyxy_, image_path_, person_id_, frame_id_ = \ + [], [], [], [], [], [], [], [], [] + human_data = HumanData() + frames_folder = osp.join(args.output, 'images') + os.makedirs(frames_folder, exist_ok=True) + array_to_images( + np.array(frames_iter)[frame_id_list], output_folder=frames_folder) + + for i, img_i in enumerate(sorted(os.listdir(frames_folder))): + body_pose_.append(smpl_poses[i][1:]) + global_orient_.append(smpl_poses[i][:1]) + smpl_betas_.append(smpl_betas[i]) + verts_.append(verts[i]) + pred_cams_.append(pred_cams[i]) + bboxes_xyxy_.append(bboxes_xyxy[i]) + image_path_.append(os.path.join('images', img_i)) + person_id_.append(0) + frame_id_.append(frame_id_list[i]) + + smpl = {} + smpl['body_pose'] = np.array(body_pose_).reshape((-1, 23, 3)) + smpl['global_orient'] = np.array(global_orient_).reshape((-1, 3)) + smpl['betas'] = np.array(smpl_betas_).reshape((-1, 10)) + human_data['smpl'] = smpl + human_data['verts'] = verts_ + human_data['pred_cams'] = pred_cams_ + human_data['bboxes_xyxy'] = bboxes_xyxy_ + human_data['image_path'] = image_path_ + human_data['person_id'] = person_id_ + human_data['frame_id'] = frame_id_ + human_data.dump(osp.join(args.output, 'inference_result.npz')) + + if args.show_path is not None: + if args.output is not None: + frames_folder = 
os.path.join(args.output, 'images') + else: + frames_folder = osp.join(Path(args.show_path).parent, 'images') + os.makedirs(frames_folder, exist_ok=True) + array_to_images( + np.array(frames_iter)[frame_id_list], + output_folder=frames_folder) + + body_model_config = dict(model_path='data/body_models', type='star') + visualize_smpl_hmr( + poses=smpl_poses, + betas=smpl_betas, + cam_transl=pred_cams, + bbox=bboxes_xyxy, + output_path=args.show_path, + render_choice=args.render_choice, + resolution=frames_iter[0].shape[:2], + origin_frames=frames_folder, + body_model_config=body_model_config, + overwrite=True, + palette=args.palette, + read_frames_batch=True) + if args.output is None: + shutil.rmtree(frames_folder) + + +def main(args): + # prepare input + frames_iter = prepare_frames(args.input_path) + single_person_with_mmdet(args, frames_iter) + + +if __name__ == '__main__': + + parser = ArgumentParser() + parser.add_argument( + 'mesh_reg_config', + type=str, + default=None, + help='Config file for mesh regression') + parser.add_argument( + 'mesh_reg_checkpoint', + type=str, + default=None, + help='Checkpoint file for mesh regression') + parser.add_argument( + '--single_person_demo', + action='store_true', + help='Single person demo with MMDetection') + parser.add_argument('--det_config', help='Config file for detection') + parser.add_argument( + '--det_checkpoint', help='Checkpoint file for detection') + parser.add_argument( + '--det_cat_id', + type=int, + default=1, + help='Category id for bounding box detection model') + parser.add_argument('--tracking_config', help='Config file for tracking') + + parser.add_argument( + '--body_model_dir', + type=str, + default='data/body_models/', + help='Body models file path') + parser.add_argument( + '--input_path', type=str, default=None, help='Input path') + parser.add_argument( + '--output', + type=str, + default=None, + help='directory to save output result file') + parser.add_argument( + '--show_path', + type=str, + default=None, + help='directory to save rendered images or video') + parser.add_argument( + '--render_choice', + type=str, + default='hq', + help='Render choice parameters') + parser.add_argument( + '--palette', type=str, default='segmentation', help='Color theme') + parser.add_argument( + '--bbox_thr', + type=float, + default=0.97, + help='Bounding box score threshold') + parser.add_argument( + '--draw_bbox', + action='store_true', + help='Draw a bbox for each detected instance') + parser.add_argument( + '--smooth_type', + type=str, + default=None, + help='Smooth the data through the specified type.' + 'Select in [oneeuro,gaus1d,savgol].') + parser.add_argument( + '--speed_up_type', + type=str, + default=None, + help='Speed up data processing through the specified type.' + 'Select in [deciwatch].') + parser.add_argument( + '--focal_length', type=float, default=5000., help='Focal lenght') + parser.add_argument( + '--device', + choices=['cpu', 'cuda'], + default='cuda', + help='device used for testing') + args = parser.parse_args() + + if args.single_person_demo: + assert has_mmdet, 'Please install mmdet to run the demo.' 
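+        # The single-person pipeline needs both a detector config and checkpoint.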
+ assert args.det_config is not None + assert args.det_checkpoint is not None + + main(args) diff --git a/mmhuman3d/core/conventions/segmentation/__init__.py b/mmhuman3d/core/conventions/segmentation/__init__.py old mode 100644 new mode 100755 index 7de351c1..31949615 --- a/mmhuman3d/core/conventions/segmentation/__init__.py +++ b/mmhuman3d/core/conventions/segmentation/__init__.py @@ -10,13 +10,17 @@ def __init__(self, model_type='smpl') -> None: self.DICT = SMPL_SEGMENTATION_DICT self.super_set = SMPL_SUPER_SET self.NUM_VERTS = 6890 + elif model_type == 'star': + self.DICT = SMPL_SEGMENTATION_DICT + self.super_set = SMPL_SUPER_SET + self.NUM_VERTS = 6890 elif model_type == 'smplx': self.DICT = SMPLX_SEGMENTATION_DICT self.super_set = SMPLX_SUPER_SET self.NUM_VERTS = 10475 else: raise ValueError(f'Wrong model_type: {model_type}.' - f' Should be in {["smpl", "smplx"]}') + f' Should be in {["smpl", "smplx", "star"]}') self.model_type = model_type self.len = len(list(self.DICT)) diff --git a/mmhuman3d/core/renderer/torch3d_renderer/meshes.py b/mmhuman3d/core/renderer/torch3d_renderer/meshes.py old mode 100644 new mode 100755 index 12f5b44c..ca73546d --- a/mmhuman3d/core/renderer/torch3d_renderer/meshes.py +++ b/mmhuman3d/core/renderer/torch3d_renderer/meshes.py @@ -7,7 +7,7 @@ from pytorch3d.renderer.mesh.textures import TexturesBase from pytorch3d.structures import Meshes, list_to_padded, padded_to_list -from mmhuman3d.models.body_models.builder import SMPL, SMPLX +from mmhuman3d.models.body_models.builder import SMPL, SMPLX, STAR from mmhuman3d.utils.mesh_utils import \ join_meshes_as_batch as _join_meshes_as_batch from .builder import build_renderer @@ -41,7 +41,7 @@ class ParametricMeshes(Meshes): Will use the textures directly from the meshes. """ # TODO: More model class to be added (FLAME, MANO) - MODEL_CLASSES = {'smpl': SMPL, 'smplx': SMPLX} + MODEL_CLASSES = {'smpl': SMPL, 'smplx': SMPLX, 'star': STAR} def __init__(self, verts: Union[List[torch.Tensor], torch.Tensor] = None, diff --git a/mmhuman3d/core/visualization/visualize_smpl.py b/mmhuman3d/core/visualization/visualize_smpl.py old mode 100644 new mode 100755 index 09c0bfd8..e38ffcea --- a/mmhuman3d/core/visualization/visualize_smpl.py +++ b/mmhuman3d/core/visualization/visualize_smpl.py @@ -184,9 +184,9 @@ def _prepare_body_model(body_model, body_model_config): model_path = body_model_config.get('model_path', None) model_type = body_model_config.get('type').lower() - if model_type not in ['smpl', 'smplx']: + if model_type not in ['smpl', 'smplx', 'star']: raise ValueError(f'Do not support {model_type}, please choose' - f' in `smpl` or `smplx.') + f' in `smpl`, `smplx` or `star`.') if model_path and osp.isdir(model_path): model_path = osp.join(model_path, model_type) @@ -481,8 +481,8 @@ def render_smpl( mask: Optional[Union[np.ndarray, List[int]]] = None, vis_kp_index: bool = False, verbose: bool = False) -> Union[None, torch.Tensor]: - """Render SMPL or SMPL-X mesh or silhouette into differentiable tensors, - and export video or images. + """Render SMPL, SMPL-X or STAR mesh or silhouette into differentiable + tensors, and export video or images. 
Args: # smpl parameters: @@ -748,10 +748,16 @@ def render_smpl( body_model = _prepare_body_model(body_model, body_model_config) model_type = body_model.name().replace('-', '').lower() - assert model_type in ['smpl', 'smplx'] + assert model_type in ['smpl', 'smplx', 'star'] - vertices, joints, num_frames, num_person = _prepare_mesh( - poses, betas, transl, verts, start, end, body_model) + if model_type in ['smpl', 'smplx']: + vertices, joints, num_frames, num_person = _prepare_mesh( + poses, betas, transl, verts, start, end, body_model) + elif model_type == 'star': + model_output = body_model(body_pose=poses, betas=betas, transl=transl) + vertices = model_output['vertices'] + num_frames = poses.shape[0] + num_person = 1 # star temporarily only support single person end = num_frames if end is None else end vertices = vertices.view(num_frames, num_person, -1, 3) num_verts = vertices.shape[-2] diff --git a/mmhuman3d/models/body_models/star.py b/mmhuman3d/models/body_models/star.py old mode 100644 new mode 100755 index 403f5c5b..f4176f5b --- a/mmhuman3d/models/body_models/star.py +++ b/mmhuman3d/models/body_models/star.py @@ -8,15 +8,14 @@ import torch.nn as nn from mmhuman3d.core.conventions.keypoints_mapping import convert_kps -from mmhuman3d.utils.transforms import ( - aa_to_rotmat, - make_homegeneous_rotmat_batch, -) +from mmhuman3d.utils.geometry import rotation_matrix_to_angle_axis +from mmhuman3d.utils.transforms import aa_to_rotmat class STAR(nn.Module): - - NUM_BODY_JOINTS = 24 + NUM_JOINTS = 24 + NUM_VERTS = 6890 + NUM_FACES = 13776 def __init__(self, model_path: str, @@ -60,7 +59,7 @@ def __init__(self, create_body_pose: bool, optional Flag for creating a member variable for the pose of the body. (default = True) - body_pose: torch.tensor, optional, Bx(3*24) + body_pose: torch.tensor, optional, Bx(3*23) The default value for the body pose variable. (default = None) num_betas: int, optional @@ -139,6 +138,9 @@ def __init__(self, self.register_buffer( 'faces', torch.from_numpy(star_model['f'].astype(np.int64))) self.f = star_model['f'] + self.register_buffer('faces_tensor', + torch.from_numpy(star_model['f'].astype( + np.int64))) # alias for face tensor in render # Kinematic tree of the model self.register_buffer( @@ -151,11 +153,10 @@ def __init__(self, } self.register_buffer( 'parent', - torch.tensor([ + torch.LongTensor([ id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1]) - ], - dtype=torch.int64)) + ])) if create_global_orient: if global_orient is None: @@ -175,7 +176,7 @@ def __init__(self, if create_body_pose: if body_pose is None: default_body_pose = torch.zeros( - [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) + [batch_size, self.NUM_JOINTS, 3, 3], dtype=dtype) else: if torch.is_tensor(body_pose): default_body_pose = body_pose.clone().detach() @@ -213,25 +214,23 @@ def __init__(self, self.R = None def forward(self, - global_orient: Optional[torch.Tensor] = None, body_pose: Optional[torch.Tensor] = None, + global_orient: Optional[torch.Tensor] = None, betas: Optional[torch.Tensor] = None, transl: Optional[torch.Tensor] = None, + gender: Optional[str] = None, return_verts: bool = True, - return_full_pose: bool = True) -> torch.Tensor: + return_full_pose: bool = True, + **kwargs) -> torch.Tensor: """Forward pass for the STAR model. Args: - global_orient: torch.tensor, optional, shape Bx3 - Global orientation (rotation) of the body. If given, ignore the - member variable and use it as the global rotation of the body. 
- Useful if someone wishes to predicts this with an external - model. (default=None) - body_pose: torch.Tensor, shape Bx(J*3) + body_pose: torch.Tensor, shape Bx23x3x3 tensor. Pose parameters for the STAR model. It should be a tensor that contains joint rotations in axis-angle format. If given, ignore the member variable and use it as the body parameters. (default=None) + global_orient: torch.Tensor, shape Bx1x3x3 tensor. (default=None) betas: torch.Tensor, shape Bx10 Shape parameters for the STAR model. If given, ignore the member variable and use it as shape parameters. (default=None) @@ -243,15 +242,18 @@ def forward(self, output: Contains output parameters and attributes corresponding to other body models. """ - global_orient = ( - global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose + device = body_pose.device + + if body_pose.shape[1] % 24 != 0: + body_pose = torch.cat((global_orient, body_pose), dim=1) betas = betas if betas is not None else self.betas - apply_transl = transl is not None or hasattr(self, 'transl') if transl is None and hasattr(self, 'transl'): transl = self.transl batch_size = body_pose.shape[0] + if body_pose.shape[1] == 72: + body_pose = body_pose.view(batch_size, -1, 3) v_template = self.v_template[None, :] shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand( batch_size, -1, -1) @@ -259,8 +261,12 @@ def forward(self, v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor]) - pose_quat = self.normalize_quaternion(body_pose.view(-1, 3)).view( - batch_size, -1) + if len(body_pose.shape) == 4: + # the shape of body pose is rot matrix, convert to angle_axis + body_pose = rotation_matrix_to_angle_axis( + body_pose.view(-1, 3, 3)).view(batch_size, -1, 3) + + pose_quat = self.quat_feat(body_pose.view(-1, 3)).view(batch_size, -1) pose_feat = torch.cat((pose_quat[:, 4:], beta[:, 1]), 1) R = aa_to_rotmat(body_pose.view(-1, 3)).view(batch_size, 24, 3, 3) @@ -270,24 +276,50 @@ def forward(self, v_posed = v_shaped + torch.matmul( posedirs, pose_feat[:, :, None]).view(-1, 6890, 3) - root_transform = make_homegeneous_rotmat_batch( + J_ = J.clone() + J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] + G_ = torch.cat([R, J_[:, :, :, None]], dim=-1) + pad_row = torch.FloatTensor([0, 0, 0, + 1]).to(device).view(1, 1, 1, 4).expand( + batch_size, 24, -1, -1) + G_ = torch.cat([G_, pad_row], dim=2) + G = [G_[:, 0].clone()] + for i in range(1, 24): + G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :])) + G = torch.stack(G, dim=1) + + rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], + dim=2).view(batch_size, 24, 4, 1) + zeros = torch.zeros(batch_size, 24, 4, 3).to(device) + rest = torch.cat([zeros, rest], dim=-1) + rest = torch.matmul(G, rest) + G = G - rest + T = torch.matmul(self.weights, + G.permute(1, 0, 2, 3).contiguous().view(24, -1)).view( + 6890, batch_size, 4, 4).transpose(0, 1) + rest_shape_h = torch.cat( + [v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1) + v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] + v = v + transl[:, None, :] + v.f = self.f + v.v_posed = v_posed + v.v_shaped = v_shaped + + root_transform = self.with_zeros( torch.cat((R[:, 0], J[:, 0][:, :, None]), 2)) results = [root_transform] for i in range(0, self.parent.shape[0]): - transform_i = make_homegeneous_rotmat_batch( + transform_i = self.with_zeros( torch.cat((R[:, i + 1], J[:, i + 1][:, :, 
None] - J[:, self.parent[i]][:, :, None]), 2)) curr_res = torch.matmul(results[self.parent[i]], transform_i) results.append(curr_res) results = torch.stack(results, dim=1) posed_joints = results[:, :, :3, 3] - - if apply_transl: - posed_joints += transl[:, None, :] - v_posed += transl[:, None, :] + v.J_transformed = posed_joints + transl[:, None, :] joints, joint_mask = convert_kps( - posed_joints, + v.J_transformed, src=self.keypoint_src, dst=self.keypoint_dst, approximate=self.keypoint_approximate) @@ -296,6 +328,9 @@ def forward(self, joint_mask, dtype=torch.uint8, device=joints.device) joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1) + global_orient = body_pose[:, 0, :][:, None] + body_pose = body_pose[:, 1:, :] + output = dict( global_orient=global_orient, body_pose=body_pose, @@ -305,23 +340,33 @@ def forward(self, betas=beta) if return_verts: - output['vertices'] = v_posed + output['vertices'] = v if return_full_pose: output['full_pose'] = torch.cat([global_orient, body_pose], dim=1) return output @classmethod - def normalize_quaternion(self, theta: torch.Tensor) -> torch.Tensor: - """Computes a normalized quaternion ([0,0,0,0] when the body is in rest - pose) given joint angles. + def with_zeros(self, input): + """Appends a row of [0,0,0,1] to a batch size x 3 x 4 Tensor. - Args: - theta (torch.Tensor): A tensor of joints axis angles, - batch size x number of joints x 3 + :param input: A tensor of dimensions batch size x 3 x 4 + :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1) + """ + batch_size = input.shape[0] + row_append = torch.FloatTensor(([0.0, 0.0, 0.0, 1.0])).to(input.device) + row_append.requires_grad = False + padded_tensor = torch.cat( + [input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], 1) + return padded_tensor - Returns: - quat (torch.Tensor) + @classmethod + def quat_feat(self, theta): + """Computes a normalized quaternion ([0,0,0,0] when the body is in + rest pose) given joint angles. + + :param theta: A tensor of joints axis angles. + :return: """ l1norm = torch.norm(theta + 1e-8, p=2, dim=1) angle = torch.unsqueeze(l1norm, -1) @@ -331,3 +376,6 @@ def normalize_quaternion(self, theta: torch.Tensor) -> torch.Tensor: v_sin = torch.sin(angle) quat = torch.cat([v_sin * normalized, v_cos - 1], dim=1) return quat + + def name(self) -> str: + return 'STAR'
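
For context, the snippet below is a minimal usage sketch of the updated STAR forward pass; it is not part of the patch. It assumes the STAR model files have been downloaded to `data/body_models/star` and that the constructor accepts the same keyword arguments used in `configs/hmr/star_pw3d.py` (`model_path`, `keypoint_src`, `keypoint_dst`, `keypoint_approximate`). Poses follow the rotation-matrix convention documented in the new `forward` docstring (`global_orient`: Bx1x3x3, `body_pose`: Bx23x3x3); internally they are concatenated, converted to axis-angle, and run through the quaternion-feature LBS path shown above.

```python
# Minimal sketch under the assumptions noted above; not part of this patch.
import torch

from mmhuman3d.models.body_models.builder import STAR

star = STAR(
    model_path='data/body_models/star',  # assumed location of the STAR model files
    keypoint_src='smpl',
    keypoint_dst='star',
    keypoint_approximate=True)

batch_size = 2
# Rest pose: identity rotation matrices for the 23 body joints and the root.
body_pose = torch.eye(3).expand(batch_size, 23, 3, 3).contiguous()
global_orient = torch.eye(3).expand(batch_size, 1, 3, 3).contiguous()
betas = torch.zeros(batch_size, 10)
transl = torch.zeros(batch_size, 3)

output = star(
    body_pose=body_pose,
    global_orient=global_orient,
    betas=betas,
    transl=transl)

# 'vertices' is (B, 6890, 3); the forward pass also converts joints to the
# 'star' keypoint convention via convert_kps.
print(output['vertices'].shape)
```

The same call path is exercised end to end by `demo/estimate_star.py`, which renders the predictions through `visualize_smpl_hmr` with `body_model_config=dict(model_path='data/body_models', type='star')`.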