Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add python version of tilers #93

Merged
merged 27 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9f43116
Add tilers
sovrasov Jun 27, 2023
75f9fa7
Add xai for detection tiling
sovrasov Jun 30, 2023
26956b5
Recover the initial saliency map dimension
sovrasov Jun 30, 2023
9c02fce
Del type annotation for consistency
sovrasov Jul 1, 2023
b49d60d
Update detection tiler implementation
sovrasov Jul 3, 2023
c5da08e
Update tiler docstrings
sovrasov Jul 3, 2023
e1c69be
Del mata parameter
sovrasov Jul 4, 2023
8cd7acf
Fix detection tiler to support iseg
sovrasov Jul 4, 2023
532a909
Add explain to IS tiler
sovrasov Jul 5, 2023
c7a6673
Update vars naming in is wrappers
sovrasov Jul 5, 2023
d0af62e
UFix black
sovrasov Jul 5, 2023
3d9e235
Del unnecessary meta params
sovrasov Jul 5, 2023
5df7e9e
Fix IS tiling explain
sovrasov Jul 5, 2023
0cf63df
Enable 6 outputs for MRCNN
sovrasov Jul 5, 2023
d40a869
Apply max normalization for IS sal maps
sovrasov Jul 5, 2023
c84dc20
Adjust maskrcnn cpp wrapper to tiled verion of model
sovrasov Jul 6, 2023
cfdf887
Add data for tiling
sovrasov Jul 6, 2023
018fce4
Add a getter for internal model in tiler
sovrasov Jul 6, 2023
1c43887
Add tests for tiling
sovrasov Jul 6, 2023
2bd074c
Fix black in tests
sovrasov Jul 6, 2023
f7da575
Skip tilers in cpp tests
sovrasov Jul 6, 2023
e3d8e76
Eliminate one redundand assignment
sovrasov Jul 6, 2023
a0bd304
Fix isort
sovrasov Jul 6, 2023
6e99648
Fix black in is tiler
sovrasov Jul 6, 2023
85de6b9
Move masks collection to the main collection loop
sovrasov Jul 6, 2023
97ac9d4
Fix typo
sovrasov Jul 6, 2023
22dce9a
Del extra masks check
sovrasov Jul 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions model_api/cpp/models/src/instance_segmentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,27 @@ cv::Mat segm_postprocess(const SegmentedObject& box, const cv::Mat& unpadded, in
std::vector<cv::Mat_<std::uint8_t>> average_and_normalize(const std::vector<std::vector<cv::Mat>>& saliency_maps) {
std::vector<cv::Mat_<std::uint8_t>> aggregated;
aggregated.reserve(saliency_maps.size());
for (const std::vector<cv::Mat>& per_class_maps : saliency_maps) {
if (per_class_maps.empty()) {
for (const std::vector<cv::Mat>& per_object_maps : saliency_maps) {
if (per_object_maps.empty()) {
aggregated.emplace_back();
} else {
cv::Mat_<double> saliency_map{per_class_maps.front().size()};
for (const cv::Mat& per_class_map : per_class_maps) {
if (saliency_map.size != per_class_map.size) {
cv::Mat_<double> saliency_map{per_object_maps.front().size()};
for (const cv::Mat& per_object_map : per_object_maps) {
if (saliency_map.size != per_object_map.size) {
throw std::runtime_error("saliency_maps must have same size");
} if (per_class_map.channels() != 1) {
} if (per_object_map.channels() != 1) {
throw std::runtime_error("saliency_maps must have one channel");
} if (per_class_map.type() != CV_8U) {
} if (per_object_map.type() != CV_8U) {
throw std::runtime_error("saliency_maps must have type CV_8U");
}
}
for (int row = 0; row < saliency_map.rows; ++row) {
for (int col = 0; col < saliency_map.cols; ++col) {
double sum = 0.0;
for (const cv::Mat& per_class_map : per_class_maps) {
sum += per_class_map.at<std::uint8_t>(row, col);
std::uint8_t max_val = 0;
for (const cv::Mat& per_object_map : per_object_maps) {
max_val = std::max(max_val, per_object_map.at<std::uint8_t>(row, col));
}
saliency_map.at<double>(row, col) = sum / per_class_maps.size();
saliency_map.at<double>(row, col) = max_val;
}
}
double min, max;
Expand Down Expand Up @@ -245,8 +245,8 @@ void MaskRCNNModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
filtered.push_back({output.get_any_name(), output.get_partial_shape().get_max_shape().size()});
}
}
if (filtered.size() != 3) {
throw std::logic_error(std::string{"MaskRCNNModel model wrapper supports topologies with "} + saliency_map_name + ", " + feature_vector_name + " and 3 other outputs");
if (filtered.size() != 3 && filtered.size() != 4) {
throw std::logic_error(std::string{"MaskRCNNModel model wrapper supports topologies with "} + saliency_map_name + ", " + feature_vector_name + " and 3 or 4 other outputs");
}
outputNames.resize(3);
for (const NameRank& name_rank : filtered) {
Expand All @@ -260,6 +260,8 @@ void MaskRCNNModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
case 4:
outputNames[2] = name_rank.name;
break;
case 0:
break;
default:
throw std::runtime_error("Unexpected output: " + name_rank.name);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MaskRCNNModel(ImageModel):

def __init__(self, inference_adapter, configuration, preload=False):
super().__init__(inference_adapter, configuration, preload)
self._check_io_number((1, 2), (3, 4, 5, 8))
self._check_io_number((1, 2), (3, 4, 5, 6, 8))
if self.path_to_labels:
self.labels = load_labels(self.path_to_labels)
self.is_segmentoly = len(self.inputs) == 2
Expand Down Expand Up @@ -228,9 +228,9 @@ def postprocess(self, outputs, meta):

def _average_and_normalize(saliency_maps):
aggregated = []
for per_class_maps in saliency_maps:
if per_class_maps:
saliency_map = np.array(per_class_maps).mean(0)
for per_object_maps in saliency_maps:
if per_object_maps:
saliency_map = np.max(np.array(per_object_maps), axis=0)
max_values = np.max(saliency_map)
saliency_map = 255 * (saliency_map) / (max_values + 1e-12)
aggregated.append(saliency_map.astype(np.uint8))
Expand Down
24 changes: 24 additions & 0 deletions model_api/python/openvino/model_api/tilers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Copyright (C) 2023 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


from .detection import DetectionTiler
from .instance_segmentation import InstanceSegmentationTiler

__all__ = [
"DetectionTiler",
"InstanceSegmentationTiler",
]
284 changes: 284 additions & 0 deletions model_api/python/openvino/model_api/tilers/detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
"""
Copyright (c) 2023 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import cv2 as cv
import numpy as np
from openvino.model_api.models.types import NumericalValue
from openvino.model_api.models.utils import Detection, DetectionResult, nms

from .tiler import Tiler


class DetectionTiler(Tiler):
    """
    Tiler for object detection models.

    The per-tile prediction returned by the wrapped model must expose its
    detections through an `objects` attribute (e.g. `DetectionResult`) or a
    `segmentedObjects` attribute, and must also provide `saliency_map` and
    `feature_vector` attributes.
    """

    def __init__(self, model, configuration=None, execution_mode="async"):
        super().__init__(model, configuration, execution_mode)

    @classmethod
    def parameters(cls):
        """Defines the description and type of configurable data parameters for the tiler.

        Returns:
            - the dictionary with defined wrapper tiler parameters
        """
        parameters = super().parameters()
        parameters.update(
            {
                "max_pred_number": NumericalValue(
                    value_type=int,
                    default_value=100,
                    min=1,
                    description="Maximum numbers of prediction per image",
                ),
            }
        )
        return parameters

    def _postprocess_tile(self, predictions, coord):
        """Converts predictions to a format convenient for further merging.

        Args:
            predictions: per-tile model output exposing `objects` or `segmentedObjects`,
                plus `saliency_map` and `feature_vector`
            coord: a list containing coordinates for the processed tile; the first two
                items are used as the (x, y) offset of the tile

        Returns:
            a dict with keys:
                - "bboxes": array of rows (label id, score, x1, y1, x2, y2) with boxes
                  shifted by the tile offset
                - "saliency_map": the tile saliency map, as returned by the model
                - "features": the tile feature vector, as returned by the model
                - "coords": the `coord` argument, unchanged

        Raises:
            RuntimeError: if `predictions` has neither `objects` nor `segmentedObjects`
        """

        output_dict = {}
        if hasattr(predictions, "objects"):
            detections = _detection2array(predictions.objects)
        elif hasattr(predictions, "segmentedObjects"):
            detections = _detection2array(predictions.segmentedObjects)
        else:
            raise RuntimeError("Unsupported model predictions format")

        output_dict["saliency_map"] = predictions.saliency_map
        output_dict["features"] = predictions.feature_vector

        # Shift (x1, y1, x2, y2) by the tile offset: the offset pair is repeated twice
        # so that both box corners are translated.
        offset_x, offset_y = coord[:2]
        detections[:, 2:] += np.tile((offset_x, offset_y), 2)
        output_dict["bboxes"] = detections
        output_dict["coords"] = coord

        return output_dict

    def _merge_results(self, results, shape):
        """Merge results from all tiles.

        To merge detections, per-class NMS is applied. Feature vectors are averaged
        and saliency maps are merged with `_merge_saliency_maps`.

        Args:
            results: list of per-tile results produced by `_postprocess_tile`
            shape: original full-res image shape

        Returns:
            merged prediction as a `DetectionResult`
        """

        detections_array = np.empty((0, 6), dtype=np.float32)
        feature_vectors = []
        saliency_maps = []
        tiles_coords = []
        for result in results:
            if len(result["bboxes"]):
                detections_array = np.concatenate((detections_array, result["bboxes"]))
            feature_vectors.append(result["features"])
            saliency_maps.append(result["saliency_map"])
            tiles_coords.append(result["coords"])

        # NMS is skipped when there is nothing to suppress.
        if np.prod(detections_array.shape):
            detections_array, _ = _multiclass_nms(
                detections_array, max_num=self.max_pred_number
            )

        merged_vector = (
            np.mean(feature_vectors, axis=0) if feature_vectors else np.empty(0)
        )
        saliency_map = (
            self._merge_saliency_maps(saliency_maps, shape, tiles_coords)
            if saliency_maps
            else np.empty(0)
        )

        detected_objects = []
        for i in range(detections_array.shape[0]):
            label = int(detections_array[i][0])
            score = float(detections_array[i][1])
            bbox = list(detections_array[i][2:])
            detected_objects.append(
                Detection(*bbox, score, label, self.model.labels[label])
            )

        return DetectionResult(
            detected_objects,
            saliency_map,
            merged_vector,
        )

    def _merge_saliency_maps(self, saliency_maps, shape, tiles_coords):
        """Merges saliency maps from each tile.

        The first map in `saliency_maps` is resized to the size of the whole merged
        map and added with weight 0.5 (presumably it is the map of the full image -
        confirm against the base `Tiler`). Every following map is pasted at the
        position given by the matching entry of `tiles_coords`; pixels that were
        already written by another tile are averaged with the new value.

        Args:
            saliency_maps: list of saliency maps, shape of each map is (Nc, H, W)
                or (1, Nc, H, W)
            shape: shape of the original image as (height, width, channels)
            tiles_coords: coordinates of tiles as (x1, y1, x2, y2), aligned with
                `saliency_maps`

        Returns:
            - None if `saliency_maps` is empty
            - the first map unchanged if it is one-dimensional
            - otherwise the merged uint8 saliency map with shape (Nc, H, W), or
              (1, Nc, H, W) when the first map was 4-dimensional

        Raises:
            IndexError: if a tile map does not cover the region it has to be pasted to
        """

        if not saliency_maps:
            return None

        image_saliency_map = saliency_maps[0]

        if len(image_saliency_map.shape) == 1:
            return image_saliency_map

        recover_shape = False
        if len(image_saliency_map.shape) == 4:
            recover_shape = True
            image_saliency_map = image_saliency_map.squeeze(0)

        num_classes = image_saliency_map.shape[0]
        map_h, map_w = image_saliency_map.shape[1:]

        image_h, image_w, _ = shape
        # Scale factors from image pixels to saliency map pixels.
        ratio = map_h / self.tile_size, map_w / self.tile_size

        image_map_h = int(image_h * ratio[0])
        image_map_w = int(image_w * ratio[1])
        merged_map = np.zeros((num_classes, image_map_h, image_map_w))

        for i, saliency_map in enumerate(saliency_maps[1:], 1):
            if len(saliency_map.shape) == 4:
                saliency_map = saliency_map.squeeze(0)

            # The target region depends on the tile only, so it is computed once
            # per tile instead of once per class.
            x_1, y_1, x_2, y_2 = tiles_coords[i]
            y_1, x_1 = int(y_1 * ratio[0]), int(x_1 * ratio[1])
            y_2, x_2 = int(y_2 * ratio[0]), int(x_2 * ratio[1])
            region_h, region_w = y_2 - y_1, x_2 - x_1
            if region_h <= 0 or region_w <= 0:
                # Degenerate region: nothing to paste.
                continue

            for class_idx in range(num_classes):
                cls_map = saliency_map[class_idx]
                map_h, map_w = cls_map.shape

                # The map is shrunk only when it is larger than the region in both
                # dimensions; otherwise its top-left part is used as is.
                if map_h > region_h and map_w > region_w:
                    cls_map = cv.resize(cls_map, (region_w, region_h))

                tile_patch = cls_map[:region_h, :region_w]
                region = merged_map[class_idx, y_1:y_2, x_1:x_2]
                if tile_patch.shape != region.shape:
                    raise IndexError(
                        "Tile saliency map does not match the target region of the merged map"
                    )

                # Already written (non-zero) pixels are averaged with the new tile,
                # untouched pixels just take the tile value.
                merged_map[class_idx, y_1:y_2, x_1:x_2] = np.where(
                    region != 0, 0.5 * (tile_patch + region), tile_patch
                )

        for class_idx in range(num_classes):
            image_map_cls = image_saliency_map[class_idx]
            image_map_cls = cv.resize(image_map_cls, (image_map_w, image_map_h))

            merged_map[class_idx] += 0.5 * image_map_cls
            merged_map[class_idx] = _non_linear_normalization(merged_map[class_idx])

        if recover_shape:
            merged_map = np.expand_dims(merged_map, 0)

        return merged_map.astype(np.uint8)


def _non_linear_normalization(saliency_map):
"""Use non-linear normalization y=x**1.5 for 2D saliency maps."""

min_soft_score = np.min(saliency_map)
# make merged_map distribution positive to perform non-linear normalization y=x**1.5
saliency_map = (saliency_map - min_soft_score) ** 1.5

max_soft_score = np.max(saliency_map)
saliency_map = 255.0 / (max_soft_score + 1e-12) * saliency_map

return np.floor(saliency_map)


def _multiclass_nms(
detections,
iou_threshold=0.45,
max_num=200,
):
"""Multi-class NMS.

strategy: in order to perform NMS independently per class,
we add an offset to all the boxes. The offset is dependent
only on the class idx, and is large enough so that boxes
from different classes do not overlap

Args:
detections (np.ndarray): labels, scores and boxes
iou_threshold (float, optional): IoU threshold. Defaults to 0.45.
max_num (int, optional): Max number of objects filter. Defaults to 200.

Returns:
tuple: (dets, indices), Dets are boxes with scores. Indices are indices of kept boxes.
"""
labels = detections[:, 0]
scores = detections[:, 1]
boxes = detections[:, 2:]
max_coordinate = boxes.max()
offsets = labels.astype(boxes.dtype) * (max_coordinate + 1)
boxes_for_nms = boxes + offsets[:, None]

keep = nms(*boxes_for_nms.T, scores, iou_threshold)
if max_num > 0:
keep = keep[:max_num]
keep = np.array(keep)
det = detections[keep]
return det, keep


def _detection2array(detections):
"""Convert list of OpenVINO Detection to a numpy array.

Args:
detections (List): List of OpenVINO Detection containing score, id, xmin, ymin, xmax, ymax

Returns:
np.ndarray: numpy array with [label, confidence, x1, y1, x2, y2]
"""
scores = np.empty((0, 1), dtype=np.float32)
labels = np.empty((0, 1), dtype=np.uint32)
boxes = np.empty((0, 4), dtype=np.float32)
for det in detections:
if (det.xmax - det.xmin) * (det.ymax - det.ymin) < 1.0:
continue
scores = np.append(scores, [[det.score]], axis=0)
labels = np.append(labels, [[det.id]], axis=0)
boxes = np.append(
boxes,
[[float(det.xmin), float(det.ymin), float(det.xmax), float(det.ymax)]],
axis=0,
)
detections = np.concatenate((labels, scores, boxes), -1)
return detections
Loading