Skip to content

Commit

Permalink
Reland Reland "Port DW Pose preprocessor" (#1892)
Browse files Browse the repository at this point in the history
* Port DW Pose preprocessor (#1856)

* ➕ Add dependencies

* 🚧 wip

* 🚧 wip

* 🚧 download models

* 🚧 Minor fixes

* 🔧 update gitignore

* 🐛 Fix normalization issue

* 🚧 load DW model only when DW preprocessor is selected

* ✅ Change test config

* 🎨 nits

* 🐛 Fix A1111 safe torch issue

📝 v1.1.235 (#1859)

Revert "Port DW Pose preprocessor (#1856)" (#1860)

This reverts commit 0d3310f.

Reland "Port DW Pose preprocessor" (#1861)

* Revert "Revert "Port DW Pose preprocessor (#1856)" (#1860)"

This reverts commit 17e100e.

* 🐛 Fix install.py

📝 v1.1.236 (#1862)

:bug: Delay import of mmpose (#1866)

:memo: v1.1.237 (#1868)

:bug: Fix all keypoints invalid issue (#1871)

:bug: lazy import

:construction: update test expectation

:construction: Switch to onnx

:construction: solve onnx package issue

:wrench: Check cuda in more efficient way

:art: Format code

:wrench: Make onnx runtime optional

* Use cv2 to load and run model on cpu

* nit
  • Loading branch information
huchenlei authored Aug 9, 2023
1 parent 4fd548b commit 4fa9190
Show file tree
Hide file tree
Showing 14 changed files with 703 additions and 41 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,5 @@ annotator/downloads/

# test results and expectations
web_tests/results/
web_tests/expectations/
web_tests/expectations/
*_diff.png
68 changes: 52 additions & 16 deletions annotator/openpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,19 @@
from .body import Body, BodyResult, Keypoint
from .hand import Hand
from .face import Face
from .types import PoseResult, HandResult, FaceResult
from modules import devices
from annotator.annotator_path import models_path

from typing import NamedTuple, Tuple, List, Callable, Union, Optional
from typing import Tuple, List, Callable, Union, Optional

body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth"
face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth"

HandResult = List[Keypoint]
FaceResult = List[Keypoint]
remote_onnx_det = "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx"
remote_onnx_pose = "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx"

class PoseResult(NamedTuple):
    """Detection result for one pose: a body plus optional hands and face."""

    # Body detection result.
    body: BodyResult
    # Hand keypoint lists; None is permitted for either hand.
    left_hand: Optional[HandResult]
    right_hand: Optional[HandResult]
    # Face keypoint list; None is permitted.
    face: Optional[FaceResult]

def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
"""
Expand Down Expand Up @@ -162,8 +158,7 @@ def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[flo
'canvas_height': canvas_height,
'canvas_width': canvas_width,
}, indent=4)



class OpenposeDetector:
"""
A class for detecting human poses in images using the Openpose model.
Expand All @@ -179,6 +174,8 @@ def __init__(self):
self.hand_estimation = None
self.face_estimation = None

self.dw_pose_estimation = None

def load_model(self):
"""
Load the Openpose body, hand, and face models.
Expand All @@ -202,10 +199,25 @@ def load_model(self):
self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
self.face_estimation = Face(face_modelpath)

def load_dw_model(self):
    """
    Build the DW Pose estimator and store it on `self.dw_pose_estimation`.

    Each ONNX file is looked up under `self.model_dir` and downloaded from
    its remote URL only when it is not already present on disk.
    """
    from .wholebody import Wholebody  # DW Pose

    def fetch_if_missing(filename: str, remote_url: str):
        # Resolve the expected on-disk location first; download only on a miss.
        target = os.path.join(self.model_dir, filename)
        if os.path.exists(target):
            return target
        from basicsr.utils.download_util import load_file_from_url
        load_file_from_url(remote_url, model_dir=self.model_dir)
        return target

    det_model_file = fetch_if_missing("yolox_l.onnx", remote_onnx_det)
    pose_model_file = fetch_if_missing("dw-ll_ucoco_384.onnx", remote_onnx_pose)
    self.dw_pose_estimation = Wholebody(det_model_file, pose_model_file)

def unload_model(self):
"""
Unload the Openpose models by moving them to the CPU.
Note: DW Pose models always run on CPU, so no need to `unload` them.
"""
if self.body_estimation is not None:
self.body_estimation.model.to("cpu")
Expand Down Expand Up @@ -302,10 +314,29 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
), left_hand, right_hand, face))

return results


def detect_poses_dw(self, oriImg) -> List[PoseResult]:
    """
    Detect poses in the given image using DW Pose:
    https://github.com/IDEA-Research/DWPose

    The DW Pose models are loaded on first use and then reused for every
    subsequent call.

    Args:
        oriImg (numpy.ndarray): The input image for pose detection.

    Returns:
        List[PoseResult]: A list of PoseResult objects containing the detected poses.
    """
    from .wholebody import Wholebody  # DW Pose

    # `dw_pose_estimation` starts out as None; only build the estimator when
    # it has not been created yet, instead of re-reading both ONNX models on
    # every detection.
    if self.dw_pose_estimation is None:
        self.load_dw_model()

    with torch.no_grad():
        # Run on a copy so the caller's image is never modified.
        keypoints_info = self.dw_pose_estimation(oriImg.copy())
        return Wholebody.format_result(keypoints_info)

def __call__(
self, oriImg, include_body=True, include_hand=False, include_face=False,
json_pose_callback: Callable[[str], None] = None,
self, oriImg, include_body=True, include_hand=False, include_face=False,
use_dw_pose=False, json_pose_callback: Callable[[str], None] = None,
):
"""
Detect and draw poses in the given image.
Expand All @@ -315,14 +346,19 @@ def __call__(
include_body (bool, optional): Whether to include body keypoints. Defaults to True.
include_hand (bool, optional): Whether to include hand keypoints. Defaults to False.
include_face (bool, optional): Whether to include face keypoints. Defaults to False.
use_dw_pose (bool, optional): Whether to use DW pose detection algorithm. Defaults to False.
json_pose_callback (Callable, optional): A callback that accepts the pose JSON string.
Returns:
numpy.ndarray: The image with detected and drawn poses.
"""
H, W, _ = oriImg.shape
poses = self.detect_poses(oriImg, include_hand, include_face)

if use_dw_pose:
poses = self.detect_poses_dw(oriImg)
else:
poses = self.detect_poses(oriImg, include_hand, include_face)

if json_pose_callback:
json_pose_callback(encode_poses_as_json(poses, H, W))
return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)

return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
19 changes: 1 addition & 18 deletions annotator/openpose/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,7 @@

from . import util
from .model import bodypose_model

class Keypoint(NamedTuple):
    """A single 2D keypoint with an optional score and identifier."""

    # Keypoint coordinates.
    x: float
    y: float
    # Score attached to the keypoint; defaults to 1.0 when not supplied.
    score: float = 1.0
    # Identifier of the keypoint; -1 is the default when none is supplied.
    id: int = -1


class BodyResult(NamedTuple):
    """Body detection result: a keypoint list plus aggregate score fields."""

    # Note: `Union` is used instead of the `|` operator because the latter is
    # a Python 3.10 feature.
    # Annotator code should stay Python 3.8 compatible, as the controlnet
    # repo uses a Python 3.8 environment.
    # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
    keypoints: List[Union[Keypoint, None]]
    # Aggregate score for the body; defaults to 0.0.
    total_score: float = 0.0
    # Part count for the body; defaults to 0.
    total_parts: int = 0

from .types import Keypoint, BodyResult

class Body(object):
def __init__(self, model_path):
Expand Down
124 changes: 124 additions & 0 deletions annotator/openpose/cv_ox_det.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import cv2
import numpy as np

def nms(boxes, scores, nms_thr):
    """
    Single-class non-maximum suppression implemented in Numpy.

    Args:
        boxes: Array indexed as `boxes[:, 0..3]` holding x1, y1, x2, y2.
        scores: One score per box.
        nms_thr: Boxes whose overlap with an already kept box exceeds this
            value are discarded.

    Returns:
        list: Indices of the kept boxes, highest score first.
    """
    left, top = boxes[:, 0], boxes[:, 1]
    right, bottom = boxes[:, 2], boxes[:, 3]
    box_areas = (right - left + 1) * (bottom - top + 1)

    # Candidate indices sorted by descending score.
    remaining = scores.argsort()[::-1]

    kept = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)

        # Intersection rectangle between the best box and every other candidate.
        inter_left = np.maximum(left[best], left[rest])
        inter_top = np.maximum(top[best], top[rest])
        inter_right = np.minimum(right[best], right[rest])
        inter_bottom = np.minimum(bottom[best], bottom[rest])

        inter_w = np.maximum(0.0, inter_right - inter_left + 1)
        inter_h = np.maximum(0.0, inter_bottom - inter_top + 1)
        intersection = inter_w * inter_h
        overlap = intersection / (box_areas[best] + box_areas[rest] - intersection)

        # Carry on only with candidates that do not overlap the best box too much.
        remaining = rest[overlap <= nms_thr]

    return kept

def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """
    Class-aware multiclass NMS implemented in Numpy.

    Each score column is treated as one class: boxes scoring above
    `score_thr` for that class go through `nms`, and the survivors are
    collected as rows of `[box columns..., score, class index]`.

    Returns:
        numpy.ndarray | None: All surviving rows stacked together, or None
        when no box survives for any class.
    """
    per_class_dets = []
    class_count = scores.shape[1]

    for class_id in range(class_count):
        class_scores = scores[:, class_id]
        above_thr = class_scores > score_thr
        if above_thr.sum() == 0:
            continue

        candidate_scores = class_scores[above_thr]
        candidate_boxes = boxes[above_thr]
        survivors = nms(candidate_boxes, candidate_scores, nms_thr)
        if len(survivors) == 0:
            continue

        # Trailing column that tags every surviving row with its class index.
        class_column = np.ones((len(survivors), 1)) * class_id
        per_class_dets.append(
            np.concatenate(
                [candidate_boxes[survivors], candidate_scores[survivors, None], class_column], 1
            )
        )

    if len(per_class_dets) == 0:
        return None
    return np.concatenate(per_class_dets, 0)

def demo_postprocess(outputs, img_size, p6=False):
    """
    Decode raw grid-relative predictions in place.

    For every stride a grid of cell coordinates is built from
    `img_size // stride`. The first two channels of `outputs` are offset by
    the cell coordinate and multiplied by the stride; channels 2:4 are
    exponentiated and multiplied by the stride.

    Args:
        outputs: Prediction array whose second-to-last axis enumerates all
            grid cells of all strides; modified in place.
        img_size: (height, width) used to derive the grid sizes.
        p6: When True, a stride of 64 is used in addition to 8, 16 and 32.

    Returns:
        The same `outputs` array that was passed in.
    """
    strides = [8, 16, 32, 64] if p6 else [8, 16, 32]

    cell_grids = []
    cell_strides = []
    for stride in strides:
        rows = img_size[0] // stride
        cols = img_size[1] // stride
        xs, ys = np.meshgrid(np.arange(cols), np.arange(rows))
        cells = np.stack((xs, ys), 2).reshape(1, -1, 2)
        cell_grids.append(cells)
        # One stride value per cell, shaped so it broadcasts over the channels.
        cell_strides.append(np.full((cells.shape[0], cells.shape[1], 1), stride))

    all_cells = np.concatenate(cell_grids, 1)
    all_strides = np.concatenate(cell_strides, 1)

    outputs[..., :2] = (outputs[..., :2] + all_cells) * all_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * all_strides
    return outputs

def preprocess(img, input_size, swap=(2, 0, 1)):
    """
    Resize an image to fit `input_size` while keeping its aspect ratio.

    The resized image is written into the top-left corner of a canvas
    filled with the value 114; the canvas axes are then reordered by
    `swap` and the result is converted to contiguous float32.

    Args:
        img: Input image array (3-dimensional images get a 3-channel canvas).
        input_size: Target (height, width).
        swap: Axis order applied to the padded canvas.

    Returns:
        tuple: (padded float32 array, resize ratio that was applied).
    """
    if len(img.shape) == 3:
        canvas_shape = (input_size[0], input_size[1], 3)
    else:
        canvas_shape = input_size
    canvas = np.full(canvas_shape, 114, dtype=np.uint8)

    src_h, src_w = img.shape[0], img.shape[1]
    # Largest ratio at which both dimensions still fit inside the target size.
    r = min(input_size[0] / src_h, input_size[1] / src_w)
    new_h, new_w = int(src_h * r), int(src_w * r)

    # cv2.resize takes the destination size as (width, height).
    shrunk = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
    canvas[:new_h, :new_w] = shrunk

    reordered = np.ascontiguousarray(canvas.transpose(swap), dtype=np.float32)
    return reordered, r

def inference_detector(session, oriImg):
    """
    Run the detection network on an image and return the filtered boxes.

    Args:
        session: Network object exposing `getUnconnectedOutLayersNames`,
            `setInput` and `forward` (presumably a `cv2.dnn` net -- confirm
            against the caller).
        oriImg: Input image array.

    Returns:
        numpy.ndarray: Boxes as rows of (x1, y1, x2, y2) in the coordinates
        of `oriImg`, restricted to class index 0 with a score above 0.3.
        An empty array of shape (0, 4) is returned when nothing survives
        NMS.
    """
    input_shape = (640, 640)
    img, ratio = preprocess(oriImg, input_shape)

    # Add a leading batch axis for the network input.
    blob = img[None, :, :, :]
    out_names = session.getUnconnectedOutLayersNames()
    session.setInput(blob)
    output = session.forward(out_names)

    predictions = demo_postprocess(output[0], input_shape)[0]

    boxes = predictions[:, :4]
    # Column 4 scales the per-class columns that follow it.
    scores = predictions[:, 4:5] * predictions[:, 5:]

    # Convert (center x, center y, width, height) to corner coordinates.
    boxes_xyxy = np.ones_like(boxes)
    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
    # Undo the resize applied in preprocessing.
    boxes_xyxy /= ratio

    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
    if dets is None:
        # Nothing survived NMS. Return the same "no boxes" shape the filter
        # below yields when it rejects every row, rather than failing on an
        # unbound local.
        return np.empty((0, 4))

    final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
    # Keep confident detections of class index 0 only.
    keep_mask = (final_scores > 0.3) & (final_cls_inds == 0)
    return final_boxes[keep_mask]
Loading

0 comments on commit 4fa9190

Please sign in to comment.