Skip to content

Commit

Permalink
[python] Add batch inference emulation (#195)
Browse files Browse the repository at this point in the history
* Add batch inference emulation in python

* Refactor kp pipeline to use batch inference

* Protect entry to batch infer with await

* Fix isort

* Del unused args
  • Loading branch information
sovrasov authored Sep 13, 2024
1 parent 43d2078 commit 60ba90a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
14 changes: 1 addition & 13 deletions model_api/python/model_api/models/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any

import numpy as np
from model_api.pipelines import AsyncPipeline

from .image_model import ImageModel
from .types import ListValue
Expand Down Expand Up @@ -87,7 +86,6 @@ class TopDownKeypointDetectionPipeline:

def __init__(self, base_model: KeypointDetectionModel) -> None:
self.base_model = base_model
self.async_pipeline = AsyncPipeline(self.base_model)

def predict(
self, image: np.ndarray, detections: list[Detection]
Expand Down Expand Up @@ -125,17 +123,7 @@ def predict_crops(self, crops: list[np.ndarray]) -> list[DetectedKeypoints]:
Returns:
list[DetectedKeypoints]: per crop keypoints
"""
for i, crop in enumerate(crops):
self.async_pipeline.submit_data(crop, i)
self.async_pipeline.await_all()

num_crops = len(crops)
result = []
for j in range(num_crops):
crop_prediction, _ = self.async_pipeline.get_result(j)
result.append(crop_prediction)

return result
return self.base_model.infer_batch(crops)


def _decode_simcc(
Expand Down
36 changes: 36 additions & 0 deletions model_api/python/model_api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging as log
import re
from contextlib import contextmanager

from model_api.adapters.inference_adapter import InferenceAdapter
from model_api.adapters.onnx_adapter import ONNXRuntimeAdapter
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(self, inference_adapter, configuration=dict(), preload=False):
self.model_loaded = False
if preload:
self.load()
self.callback_fn = lambda _: None

def get_model(self):
"""Returns the ov.Model object stored in the InferenceAdapter.
Expand Down Expand Up @@ -402,6 +404,40 @@ def __call__(self, inputs):
raw_result = self.infer_sync(dict_data)
return self.postprocess(raw_result, input_meta)

def infer_batch(self, inputs):
    """Runs inference on a collection of inputs using the async infer queue.

    Each input goes through the usual preprocess -> inference -> postprocess
    pipeline. Results are collected through a temporary completion callback
    keyed by request id and returned in the same order as ``inputs``.

    Args:
        inputs (list): a list of inputs for inference

    Returns:
        list: a list of inference results, ordered to match ``inputs``
    """
    # Drain any in-flight async requests first so the temporary callback
    # only ever observes results that belong to this batch.
    self.await_all()

    completed_results = {}

    @contextmanager
    def tmp_callback():
        # Temporarily swap in a result-collecting callback, restoring the
        # previously installed one even if submission/inference raises.
        saved_callback = self.callback_fn

        def batch_infer_callback(result, request_id):
            completed_results[request_id] = result

        try:
            self.set_callback(batch_infer_callback)
            yield
        finally:
            self.set_callback(saved_callback)

    with tmp_callback():
        # The enumerate() index doubles as the request id, which lets us
        # reassemble the results in submission order below.
        for request_id, data in enumerate(inputs):
            self.infer_async(data, request_id)
        self.await_all()

    return [completed_results[i] for i in range(len(inputs))]

def load(self, force=False):
if not self.model_loaded or force:
self.model_loaded = True
Expand Down

0 comments on commit 60ba90a

Please sign in to comment.