diff --git a/model_api/cpp/models/src/instance_segmentation.cpp b/model_api/cpp/models/src/instance_segmentation.cpp
index 82834a8a..59612d64 100644
--- a/model_api/cpp/models/src/instance_segmentation.cpp
+++ b/model_api/cpp/models/src/instance_segmentation.cpp
@@ -77,27 +77,27 @@ cv::Mat segm_postprocess(const SegmentedObject& box, const cv::Mat& unpadded, in
 std::vector<cv::Mat_<std::uint8_t>> average_and_normalize(const std::vector<std::vector<cv::Mat>>& saliency_maps) {
     std::vector<cv::Mat_<std::uint8_t>> aggregated;
     aggregated.reserve(saliency_maps.size());
-    for (const std::vector<cv::Mat>& per_class_maps : saliency_maps) {
-        if (per_class_maps.empty()) {
+    for (const std::vector<cv::Mat>& per_object_maps : saliency_maps) {
+        if (per_object_maps.empty()) {
             aggregated.emplace_back();
         } else {
-            cv::Mat_<std::uint8_t> saliency_map{per_class_maps.front().size()};
-            for (const cv::Mat& per_class_map : per_class_maps) {
-                if (saliency_map.size != per_class_map.size) {
+            cv::Mat_<std::uint8_t> saliency_map{per_object_maps.front().size()};
+            for (const cv::Mat& per_object_map : per_object_maps) {
+                if (saliency_map.size != per_object_map.size) {
                     throw std::runtime_error("saliency_maps must have same size");
-                } if (per_class_map.channels() != 1) {
+                } if (per_object_map.channels() != 1) {
                     throw std::runtime_error("saliency_maps must have one channel");
-                } if (per_class_map.type() != CV_8U) {
+                } if (per_object_map.type() != CV_8U) {
                     throw std::runtime_error("saliency_maps must have type CV_8U");
                 }
             }
             for (int row = 0; row < saliency_map.rows; ++row) {
                 for (int col = 0; col < saliency_map.cols; ++col) {
-                    double sum = 0.0;
-                    for (const cv::Mat& per_class_map : per_class_maps) {
-                        sum += per_class_map.at<std::uint8_t>(row, col);
+                    std::uint8_t max_val = 0;
+                    for (const cv::Mat& per_object_map : per_object_maps) {
+                        max_val = std::max(max_val, per_object_map.at<std::uint8_t>(row, col));
                     }
-                    saliency_map.at<std::uint8_t>(row, col) = sum / per_class_maps.size();
+                    saliency_map.at<std::uint8_t>(row, col) = max_val;
                 }
             }
             double min, max;
@@ -245,8 +245,8 @@ void MaskRCNNModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
             filtered.push_back({output.get_any_name(), output.get_partial_shape().get_max_shape().size()});
         }
     }
-    if (filtered.size() != 3) {
-        throw std::logic_error(std::string{"MaskRCNNModel model wrapper supports topologies with "} + saliency_map_name + ", " + feature_vector_name + " and 3 other outputs");
+    if (filtered.size() != 3 && filtered.size() != 4) {
+        throw std::logic_error(std::string{"MaskRCNNModel model wrapper supports topologies with "} + saliency_map_name + ", " + feature_vector_name + " and 3 or 4 other outputs");
     }
     outputNames.resize(3);
     for (const NameRank& name_rank : filtered) {
@@ -260,6 +260,8 @@ void MaskRCNNModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
             case 4:
                 outputNames[2] = name_rank.name;
                 break;
+            case 0:
+                break;
             default:
                 throw std::runtime_error("Unexpected output: " + name_rank.name);
         }
diff --git a/model_api/python/openvino/model_api/models/instance_segmentation.py b/model_api/python/openvino/model_api/models/instance_segmentation.py
index d5fb4426..6685edc4 100644
--- a/model_api/python/openvino/model_api/models/instance_segmentation.py
+++ b/model_api/python/openvino/model_api/models/instance_segmentation.py
@@ -27,7 +27,7 @@ class MaskRCNNModel(ImageModel):
     def __init__(self, inference_adapter, configuration, preload=False):
         super().__init__(inference_adapter, configuration, preload)
-        self._check_io_number((1, 2), (3, 4, 5, 8))
+        self._check_io_number((1, 2), (3, 4, 5, 6, 8))
         if self.path_to_labels:
             self.labels = load_labels(self.path_to_labels)
         self.is_segmentoly = len(self.inputs) == 2
@@ -228,9 +228,9 @@ def postprocess(self, outputs, meta):
         def _average_and_normalize(saliency_maps):
             aggregated = []
-            for per_class_maps in saliency_maps:
-                if per_class_maps:
-                    saliency_map = np.array(per_class_maps).mean(0)
+            for per_object_maps in saliency_maps:
+                if per_object_maps:
+                    saliency_map = np.max(np.array(per_object_maps), axis=0)
                     max_values = np.max(saliency_map)
                     saliency_map = 255 * (saliency_map) / (max_values + 1e-12)
                     aggregated.append(saliency_map.astype(np.uint8))
diff --git a/model_api/python/openvino/model_api/tilers/__init__.py b/model_api/python/openvino/model_api/tilers/__init__.py
new file mode 100644
index 00000000..06be1a4d
--- /dev/null
+++ b/model_api/python/openvino/model_api/tilers/__init__.py
@@ -0,0 +1,24 @@
+"""
+ Copyright (C) 2023 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+
+from .detection import DetectionTiler
+from .instance_segmentation import InstanceSegmentationTiler
+
+__all__ = [
+    "DetectionTiler",
+    "InstanceSegmentationTiler",
+]
diff --git a/model_api/python/openvino/model_api/tilers/detection.py b/model_api/python/openvino/model_api/tilers/detection.py
new file mode 100644
index 00000000..73967feb
--- /dev/null
+++ b/model_api/python/openvino/model_api/tilers/detection.py
@@ -0,0 +1,284 @@
+"""
+ Copyright (c) 2023 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import cv2 as cv
+import numpy as np
+from openvino.model_api.models.types import NumericalValue
+from openvino.model_api.models.utils import Detection, DetectionResult, nms
+
+from .tiler import Tiler
+
+
+class DetectionTiler(Tiler):
+    """
+    Tiler for object detection models.
+    This tiler expects the model to output a list of `Detection` objects
+    or one `DetectionResult` object.
+    """
+
+    def __init__(self, model, configuration=None, execution_mode="async"):
+        super().__init__(model, configuration, execution_mode)
+
+    @classmethod
+    def parameters(cls):
+        """Defines the description and type of configurable data parameters for the tiler.
+
+        Returns:
+            - the dictionary with defined wrapper tiler parameters
+        """
+        parameters = super().parameters()
+        parameters.update(
+            {
+                "max_pred_number": NumericalValue(
+                    value_type=int,
+                    default_value=100,
+                    min=1,
+                    description="Maximum number of predictions per image",
+                ),
+            }
+        )
+        return parameters
+
+    def _postprocess_tile(self, predictions, coord):
+        """Converts predictions to a format convenient for further merging.
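+
+        Bounding boxes are shifted from tile-local to full-image coordinates here.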
+
+        Args:
+            predictions: predictions from a detection model: a list of `Detection` objects
+                or one `DetectionResult`
+            coord: a list containing coordinates for the processed tile
+
+        Returns:
+            a dict with postprocessed predictions in 6-item format: (label id, score, bbox)
+        """
+
+        output_dict = {}
+        if hasattr(predictions, "objects"):
+            detections = _detection2array(predictions.objects)
+        elif hasattr(predictions, "segmentedObjects"):
+            detections = _detection2array(predictions.segmentedObjects)
+        else:
+            raise RuntimeError("Unsupported model predictions format")
+
+        output_dict["saliency_map"] = predictions.saliency_map
+        output_dict["features"] = predictions.feature_vector
+
+        offset_x, offset_y = coord[:2]
+        detections[:, 2:] += np.tile((offset_x, offset_y), 2)
+        output_dict["bboxes"] = detections
+        output_dict["coords"] = coord
+
+        return output_dict
+
+    def _merge_results(self, results, shape):
+        """Merge results from all tiles.
+
+        To merge detections, per-class NMS is applied.
+
+        Args:
+            results: list of per-tile results
+            shape: original full-res image shape
+        Returns:
+            merged prediction
+        """
+
+        detections_array = np.empty((0, 6), dtype=np.float32)
+        feature_vectors = []
+        saliency_maps = []
+        tiles_coords = []
+        for result in results:
+            if len(result["bboxes"]):
+                detections_array = np.concatenate((detections_array, result["bboxes"]))
+            feature_vectors.append(result["features"])
+            saliency_maps.append(result["saliency_map"])
+            tiles_coords.append(result["coords"])
+
+        if np.prod(detections_array.shape):
+            detections_array, _ = _multiclass_nms(
+                detections_array, max_num=self.max_pred_number
+            )
+
+        merged_vector = (
+            np.mean(feature_vectors, axis=0) if feature_vectors else np.ndarray(0)
+        )
+        saliency_map = (
+            self._merge_saliency_maps(saliency_maps, shape, tiles_coords)
+            if saliency_maps
+            else np.ndarray(0)
+        )
+
+        detected_objects = []
+        for i in range(detections_array.shape[0]):
+            label = int(detections_array[i][0])
+            score = float(detections_array[i][1])
+            bbox = list(detections_array[i][2:])
+            detected_objects.append(
+                Detection(*bbox, score, label, self.model.labels[label])
+            )
+
+        return DetectionResult(
+            detected_objects,
+            saliency_map,
+            merged_vector,
+        )
+
+    def _merge_saliency_maps(self, saliency_maps, shape, tiles_coords):
+        """Merge saliency maps from each tile.
+
+        Args:
+            saliency_maps: list of saliency maps, shape of each map is (Nc, H, W)
+            shape: shape of the original image
+            tiles_coords: coordinates of tiles
+
+        Returns:
+            Merged saliency map with shape (Nc, H, W)
+        """
+
+        if not saliency_maps:
+            return None
+
+        image_saliency_map = saliency_maps[0]
+
+        if len(image_saliency_map.shape) == 1:
+            return image_saliency_map
+
+        recover_shape = False
+        if len(image_saliency_map.shape) == 4:
+            recover_shape = True
+            image_saliency_map = image_saliency_map.squeeze(0)
+
+        num_classes = image_saliency_map.shape[0]
+        map_h, map_w = image_saliency_map.shape[1:]
+
+        image_h, image_w, _ = shape
+        ratio = map_h / self.tile_size, map_w / self.tile_size
+
+        image_map_h = int(image_h * ratio[0])
+        image_map_w = int(image_w * ratio[1])
+        merged_map = np.zeros((num_classes, image_map_h, image_map_w))
+
+        for i, saliency_map in enumerate(saliency_maps[1:], 1):
+            for class_idx in range(num_classes):
+                if len(saliency_map.shape) == 4:
+                    saliency_map = saliency_map.squeeze(0)
+
+                cls_map = saliency_map[class_idx]
+
+                x_1, y_1, x_2, y_2 = tiles_coords[i]
+                y_1, x_1 = int(y_1 * ratio[0]), int(x_1 * ratio[1])
+                y_2, x_2 = int(y_2 * ratio[0]), int(x_2 * ratio[1])
+
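+                # The tile's saliency map is resized below to the tile's
+                # footprint in the merged map; overlapping pixels are averaged.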
+                map_h, map_w = cls_map.shape
+
+                if (map_h > y_2 - y_1 > 0) and (map_w > x_2 - x_1 > 0):
+                    cls_map = cv.resize(cls_map, (x_2 - x_1, y_2 - y_1))
+
+                map_h, map_w = y_2 - y_1, x_2 - x_1
+
+                for hi, wi in [(h_, w_) for h_ in range(map_h) for w_ in range(map_w)]:
+                    map_pixel = cls_map[hi, wi]
+                    merged_pixel = merged_map[class_idx][y_1 + hi, x_1 + wi]
+                    if merged_pixel != 0:
+                        merged_map[class_idx][y_1 + hi, x_1 + wi] = 0.5 * (
+                            map_pixel + merged_pixel
+                        )
+                    else:
+                        merged_map[class_idx][y_1 + hi, x_1 + wi] = map_pixel
+
+        for class_idx in range(num_classes):
+            image_map_cls = image_saliency_map[class_idx]
+            image_map_cls = cv.resize(image_map_cls, (image_map_w, image_map_h))
+
+            merged_map[class_idx] += 0.5 * image_map_cls
+            merged_map[class_idx] = _non_linear_normalization(merged_map[class_idx])
+
+        if recover_shape:
+            merged_map = np.expand_dims(merged_map, 0)
+
+        return merged_map.astype(np.uint8)
+
+
+def _non_linear_normalization(saliency_map):
+    """Use non-linear normalization y=x**1.5 for 2D saliency maps."""
+
+    min_soft_score = np.min(saliency_map)
+    # make merged_map distribution positive to perform non-linear normalization y=x**1.5
+    saliency_map = (saliency_map - min_soft_score) ** 1.5
+
+    max_soft_score = np.max(saliency_map)
+    saliency_map = 255.0 / (max_soft_score + 1e-12) * saliency_map
+
+    return np.floor(saliency_map)
+
+
+def _multiclass_nms(
+    detections,
+    iou_threshold=0.45,
+    max_num=200,
+):
+    """Multi-class NMS.
+
+    Strategy: in order to perform NMS independently per class,
+    we add an offset to all the boxes. The offset depends only
+    on the class idx, and is large enough so that boxes
+    from different classes do not overlap.
+
+    Args:
+        detections (np.ndarray): labels, scores and boxes
+        iou_threshold (float, optional): IoU threshold. Defaults to 0.45.
+        max_num (int, optional): Maximum number of boxes to keep. Defaults to 200.
+
+    Returns:
+        tuple: (dets, indices), Dets are boxes with scores. Indices are indices of kept boxes.
+    """
+    labels = detections[:, 0]
+    scores = detections[:, 1]
+    boxes = detections[:, 2:]
+    max_coordinate = boxes.max()
+    offsets = labels.astype(boxes.dtype) * (max_coordinate + 1)
+    boxes_for_nms = boxes + offsets[:, None]
+
+    keep = nms(*boxes_for_nms.T, scores, iou_threshold)
+    if max_num > 0:
+        keep = keep[:max_num]
+    keep = np.array(keep)
+    det = detections[keep]
+    return det, keep
+
+
+def _detection2array(detections):
+    """Convert list of OpenVINO Detection to a numpy array.
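+
+    Detections with a bounding-box area smaller than one pixel are skipped.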
+
+    Args:
+        detections (List): List of OpenVINO Detection containing score, id, xmin, ymin, xmax, ymax
+
+    Returns:
+        np.ndarray: numpy array with [label, confidence, x1, y1, x2, y2]
+    """
+    scores = np.empty((0, 1), dtype=np.float32)
+    labels = np.empty((0, 1), dtype=np.uint32)
+    boxes = np.empty((0, 4), dtype=np.float32)
+    for det in detections:
+        if (det.xmax - det.xmin) * (det.ymax - det.ymin) < 1.0:
+            continue
+        scores = np.append(scores, [[det.score]], axis=0)
+        labels = np.append(labels, [[det.id]], axis=0)
+        boxes = np.append(
+            boxes,
+            [[float(det.xmin), float(det.ymin), float(det.xmax), float(det.ymax)]],
+            axis=0,
+        )
+    detections = np.concatenate((labels, scores, boxes), -1)
+    return detections
diff --git a/model_api/python/openvino/model_api/tilers/instance_segmentation.py b/model_api/python/openvino/model_api/tilers/instance_segmentation.py
new file mode 100644
index 00000000..bd6bd451
--- /dev/null
+++ b/model_api/python/openvino/model_api/tilers/instance_segmentation.py
@@ -0,0 +1,215 @@
+"""
+ Copyright (c) 2023 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import cv2 as cv
+import numpy as np
+from openvino.model_api.models.instance_segmentation import _segm_postprocess
+from openvino.model_api.models.utils import InstanceSegmentationResult, SegmentedObject
+
+from .detection import DetectionTiler, _multiclass_nms
+
+
+class InstanceSegmentationTiler(DetectionTiler):
+    """
+    Tiler for object instance segmentation models.
+    This tiler expects the model to output a list of `SegmentedObject` objects.
+
+    In addition, this tiler supports an optional tile classifier model,
+    which predicts an objectness score for each tile. Tiles can then
+    be filtered by this score.
+    """
+
+    def __init__(
+        self,
+        model,
+        configuration=None,
+        execution_mode="async",
+        tile_classifier_model=None,
+    ):
+        """
+        Constructor for creating an instance segmentation tiling pipeline
+
+        Args:
+            model: underlying model
+            configuration: it contains values for parameters accepted by specific
+                tiler (`tile_size`, `tiles_overlap` etc.) which are set as data attributes.
+            execution_mode: Controls inference mode of the tiler (`async` or `sync`).
+            tile_classifier_model: an `ImageModel`, which has a "tile_prob" output.
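+                If None, tile filtering is skipped and all tiles are processed.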
+ """ + super().__init__(model, configuration, execution_mode) + self.tile_classifier_model = tile_classifier_model + + def _filter_tiles(self, image, tile_coords, confidence_threshold=0.35): + """Filter tiles by objectness score provided by tile classifier + + Args: + image: full size image + tile_coords: tile coordinates + + Returns: + tile coordinates to keep + """ + if self.tile_classifier_model is not None: + keep_coords = [] + for i, coord in enumerate(tile_coords): + tile_img = self._crop_tile(image, coord) + tile_dict, _ = self.model.preprocess(tile_img) + cls_outputs = self.tile_classifier_model.infer_sync(tile_dict) + if i == 0 or cls_outputs["tile_prob"] > confidence_threshold: + keep_coords.append(coord) + return keep_coords + + return tile_coords + + def _postprocess_tile(self, predictions, coord): + """Converts predictions to a format convinient for further merging. + + Args: + predictions: predictions from an instance segmentation model: a list of `SegmentedObject` objects + coord: a list containing coordinates for the processed tile + + Returns: + a dict with postprocessed detections in 6-items format: (label id, score, bbox) and masks + """ + + output_dict = super()._postprocess_tile(predictions, coord) + output_dict["masks"] = [] + for segm_res in predictions.segmentedObjects: + output_dict["masks"].append(segm_res.mask) + + return output_dict + + def _merge_results(self, results, shape): + """Merge results from all tiles. + + To merge detections, per-class NMS is applied. + + Args: + results: list of per-tile results + shape: original full-res image shape + Returns: + merged prediciton + """ + + detections_array = np.empty((0, 6), dtype=np.float32) + feature_vectors = [] + saliency_maps = [] + tiles_coords = [] + masks = [] + for result in results: + if len(result["bboxes"]): + detections_array = np.concatenate((detections_array, result["bboxes"])) + feature_vectors.append(result["features"]) + saliency_maps.append(result["saliency_map"]) + tiles_coords.append(result["coords"]) + if len(result["masks"]): + masks.extend(result["masks"]) + + keep_idxs = [] + if np.prod(detections_array.shape): + detections_array, keep_idxs = _multiclass_nms( + detections_array, max_num=self.max_pred_number + ) + masks = [masks[keep_idx] for keep_idx in keep_idxs] + + merged_vector = ( + np.mean(feature_vectors, axis=0) if feature_vectors else np.ndarray(0) + ) + saliency_map = ( + self._merge_saliency_maps(saliency_maps, shape, tiles_coords) + if saliency_maps + else [] + ) + + detected_objects = [] + for i in range(detections_array.shape[0]): + label = int(detections_array[i][0]) + score = float(detections_array[i][1]) + bbox = list(detections_array[i][2:]) + detected_objects.append( + SegmentedObject(*bbox, score, label, self.model.labels[label], masks[i]) + ) + + for i, (det, mask) in enumerate(zip(detected_objects, masks)): + box = np.array([det.xmin, det.ymin, det.xmax, det.ymax]) + masks[i] = _segm_postprocess(box, mask, *shape[:-1]) + + return InstanceSegmentationResult( + detected_objects, + saliency_map, + merged_vector, + ) + + def _merge_saliency_maps(self, saliency_maps, shape, tiles_coords): + """Merged saliency maps from each tile + + Args: + saliency_maps: list of saliency maps, shape of each map is (Nc, H, W) + shape: shape of the original image + tiles_coords: coordinates of tiles + + Returns: + Merged saliency map with shape (Nc, H, W) + """ + + if not saliency_maps: + return None + + image_saliency_map = saliency_maps[0] + + if not image_saliency_map: + return 
+
+        num_classes = len(image_saliency_map)
+        map_h, map_w = image_saliency_map[0].shape
+        image_h, image_w, _ = shape
+
+        ratio = map_h / self.tile_size, map_w / self.tile_size
+        image_map_h = int(image_h * ratio[0])
+        image_map_w = int(image_w * ratio[1])
+
+        merged_map = [np.zeros((image_map_h, image_map_w)) for _ in range(num_classes)]
+
+        for i, saliency_map in enumerate(saliency_maps[1:], 1):
+            for class_idx in range(num_classes):
+                cls_map = saliency_map[class_idx]
+                if len(cls_map.shape) < 2:
+                    continue
+
+                x_1, y_1, x_2, y_2 = tiles_coords[i]
+                y_1, x_1 = int(y_1 * ratio[0]), int(x_1 * ratio[1])
+                y_2, x_2 = int(y_2 * ratio[0]), int(x_2 * ratio[1])
+
+                map_h, map_w = cls_map.shape
+
+                cls_map = cv.resize(cls_map, (x_2 - x_1, y_2 - y_1))
+
+                map_h, map_w = y_2 - y_1, x_2 - x_1
+
+                tile_map = merged_map[class_idx][y_1 : y_1 + map_h, x_1 : x_1 + map_w]
+                merged_map[class_idx][
+                    y_1 : y_1 + map_h, x_1 : x_1 + map_w
+                ] = np.maximum(tile_map, cls_map)
+
+        for class_idx in range(num_classes):
+            image_map_cls = image_saliency_map[class_idx]
+            if len(image_map_cls.shape) < 2:
+                continue
+            image_map_cls = cv.resize(image_map_cls, (image_map_w, image_map_h))
+            merged_map[class_idx] += 0.5 * image_map_cls
+            merged_map[class_idx] = merged_map[class_idx].astype(np.uint8)
+
+        return merged_map
diff --git a/model_api/python/openvino/model_api/tilers/tiler.py b/model_api/python/openvino/model_api/tilers/tiler.py
new file mode 100644
index 00000000..5d314585
--- /dev/null
+++ b/model_api/python/openvino/model_api/tilers/tiler.py
@@ -0,0 +1,286 @@
+"""
+ Copyright (c) 2023 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import abc
+import logging as log
+from itertools import product
+
+from openvino.model_api.models.types import NumericalValue
+from openvino.model_api.pipelines import AsyncPipeline
+
+
+class Tiler(metaclass=abc.ABCMeta):
+    """
+    An abstract tiler.
+
+    The abstract tiler is free from any executor dependencies.
+    It stores the provided `Model` instance, applies it to tiles
+    of the input image, and then merges results from all tiles.
+
+    The `_postprocess_tile` and `_merge_results` methods must be
+    implemented in a specific inherited tiler.
+
+    Attributes:
+        logger (Logger): instance of the Logger
+        model (Model): model being executed
+        model_loaded (bool): a flag whether the model is loaded to device
+        async_pipeline (AsyncPipeline): a pipeline for asynchronous execution mode
+        execution_mode: Controls inference mode of the tiler (`async` or `sync`).
+    """
+
+    EXECUTION_MODES = ["async", "sync"]
+
+    def __init__(self, model, configuration=None, execution_mode="async"):
+        """
+        Base constructor for creating a tiling pipeline
+
+        Args:
+            model: underlying model
+            configuration: it contains values for parameters accepted by specific
+                tiler (`tile_size`, `tiles_overlap` etc.) which are set as data attributes.
+            execution_mode: Controls inference mode of the tiler (`async` or `sync`).
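+
+        Raises:
+            ValueError: if `execution_mode` is not one of `async` or `sync`.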
+ """ + + self.logger = log.getLogger() + self.model = model + for name, parameter in self.parameters().items(): + self.__setattr__(name, parameter.default_value) + self._load_config(configuration) + self.async_pipeline = AsyncPipeline(self.model) + if execution_mode not in Tiler.EXECUTION_MODES: + raise ValueError( + f"Wrong execution mode. The following modes are supported {Tiler.EXECUTION_MODES}" + ) + self.execution_mode = execution_mode + + def get_model(self): + """Getter for underlying model""" + return self.model + + @classmethod + def parameters(cls): + """Defines the description and type of configurable data parameters for the tiler. + + The structure is similar to model wrapper parameters. + + Returns: + - the dictionary with defined wrapper tiler parameters + """ + parameters = {} + parameters.update( + { + "tile_size": NumericalValue( + value_type=int, + default_value=400, + min=1, + description="Size of one tile", + ), + "tiles_overlap": NumericalValue( + value_type=float, + default_value=0.5, + min=0.0, + max=1.0, + description="Overlap of tiles", + ), + } + ) + return parameters + + def _load_config(self, config): + """Reads the configuration and creates data attributes + by setting the wrapper parameters with values from configuration. + + Args: + config (dict): the dictionary with keys to be set as data attributes + and its values. The example of the config is the following: + { + 'confidence_threshold': 0.5, + 'resize_type': 'fit_to_window', + } + + Note: + The config keys should be provided in `parameters` method for each wrapper, + then the default value of the parameter will be updated. If some key presented + in the config is not introduced in `parameters`, it will be omitted. + + Raises: + RuntimeError: if the configuration is incorrect + """ + parameters = self.parameters() + + for name, param in parameters.items(): + try: + value = param.from_str( + self.model.inference_adapter.get_rt_info( + ["model_info", name] + ).astype(str) + ) + self.__setattr__(name, value) + except RuntimeError as error: + missing_rt_info = ( + "Cannot get runtime attribute. Path to runtime attribute is incorrect." + in str(error) + ) + is_OVMSAdapter = ( + str(error) == "OVMSAdapter does not support RT info getting" + ) + if not missing_rt_info and not is_OVMSAdapter: + raise + + for name, value in config.items(): + if value is None: + continue + if name in parameters: + errors = parameters[name].validate(value) + if errors: + self.logger.error(f'Error with "{name}" parameter:') + for error in errors: + self.logger.error(f"\t{error}") + raise RuntimeError("Incorrect user configuration") + value = parameters[name].get_value(value) + self.__setattr__(name, value) + else: + self.logger.warning( + f'The parameter "{name}" not found in tiler, will be omitted' + ) + + def __call__(self, inputs): + """ + Applies full pipeline of tiling inference in one call. + + Args: + inputs: raw input data, the data type is defined by underlying model wrapper + + Returns: + - postprocessed data in the format defined by underlying model wrapper + """ + + tile_coords = self._tile(inputs) + tile_coords = self._filter_tiles(inputs, tile_coords) + + if self.execution_mode == "sync": + return self._predict_sync(inputs, tile_coords) + return self._predict_async(inputs, tile_coords) + + def _tile(self, image): + """Tiles an input image to overlapping or non-overlapping patches. + + This method implementation also adds the full image as the first tile to process. + + Args: + image: Input image to tile. 
+
+        Returns:
+            Tiles coordinates
+        """
+        height, width = image.shape[:2]
+
+        coords = [[0, 0, width, height]]
+        for loc_j, loc_i in product(
+            range(0, width, int(self.tile_size * (1 - self.tiles_overlap))),
+            range(0, height, int(self.tile_size * (1 - self.tiles_overlap))),
+        ):
+            x2 = min(loc_j + self.tile_size, width)
+            y2 = min(loc_i + self.tile_size, height)
+            coords.append([loc_j, loc_i, x2, y2])
+
+        return coords
+
+    def _filter_tiles(self, image, tile_coords):
+        """Filter tiles by some criterion.
+
+        Args:
+            image: full size image
+            tile_coords: tile coordinates
+
+        Returns:
+            keep_coords: tile coordinates to keep
+        """
+        return tile_coords
+
+    def _predict_sync(self, image, tile_coords):
+        """Makes predictions by splitting the input image into tiles in synchronous mode.
+
+        Args:
+            image: full size image
+            tile_coords: list of tile coordinates
+
+        Returns:
+            Inference results aggregated from all tiles
+        """
+        tile_results = []
+        for coord in tile_coords:
+            tile_img = self._crop_tile(image, coord)
+            tile_predictions = self.model(tile_img)
+            tile_result = self._postprocess_tile(tile_predictions, coord)
+            tile_results.append(tile_result)
+
+        return self._merge_results(tile_results, image.shape)
+
+    def _predict_async(self, image, tile_coords):
+        """Makes predictions by splitting the input image into tiles in asynchronous mode.
+
+        Args:
+            image: full size image
+            tile_coords: tile coordinates
+
+        Returns:
+            Inference results aggregated from all tiles
+        """
+        for i, coord in enumerate(tile_coords):
+            self.async_pipeline.submit_data(self._crop_tile(image, coord), i)
+        self.async_pipeline.await_all()
+
+        num_tiles = len(tile_coords)
+        tile_results = []
+        for j in range(num_tiles):
+            tile_prediction, _ = self.async_pipeline.get_result(j)
+            tile_result = self._postprocess_tile(tile_prediction, tile_coords[j])
+            tile_results.append(tile_result)
+
+        return self._merge_results(tile_results, image.shape)
+
+    @abc.abstractmethod
+    def _postprocess_tile(self, predictions, coord):
+        """Postprocesses predictions made by a model on one tile.
+
+        Args:
+            predictions: model-dependent set of predictions or one prediction
+            coord: a list containing coordinates for the processed tile
+
+        Returns:
+            Postprocessed predictions
+        """
+
+    @abc.abstractmethod
+    def _merge_results(self, results, shape):
+        """Merge results from all tiles.
+
+        Args:
+            results: list of tile results
+            shape: original full-res image shape
+        """
+
+    def _crop_tile(self, image, coord):
+        """Crop a tile from the full image.
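+
+        Tile coordinates are expected in (x1, y1, x2, y2) pixel format.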
+
+        Args:
+            image: full-res image
+            coord: tile coordinates
+
+        Returns:
+            cropped tile
+        """
+        x1, y1, x2, y2 = coord
+        return image[y1:y2, x1:x2]
diff --git a/tests/cpp/accuracy/test_accuracy.cpp b/tests/cpp/accuracy/test_accuracy.cpp
index 8f9e6a2d..03481cec 100644
--- a/tests/cpp/accuracy/test_accuracy.cpp
+++ b/tests/cpp/accuracy/test_accuracy.cpp
@@ -37,6 +37,7 @@ struct ModelData {
     std::string name;
    std::string type;
     std::vector<TestData> testData;
+    std::string tiler;
 };
 
 class ModelParameterizedTest : public testing::TestWithParam<ModelData> {
@@ -65,6 +66,9 @@ inline void from_json(const nlohmann::json& j, ModelData& test)
         }
         test.testData.push_back(data);
     }
+    if (j.contains("tiler")) {
+        test.tiler = j.at("tiler").get<std::string>();
+    }
 }
 
 namespace {
@@ -80,6 +84,10 @@ std::vector<ModelData> GetTestData(const std::string& path)
 TEST_P(ModelParameterizedTest, AccuracyTest) {
     auto modelData = GetParam();
+    if (modelData.tiler.size()) {
+        return;
+    }
+
     std::string modelPath;
     const std::string& name = modelData.name;
     if (name.substr(name.size() - 4) == ".xml") {
diff --git a/tests/python/accuracy/prepare_data.py b/tests/python/accuracy/prepare_data.py
index ba290f8a..7bfb32a1 100644
--- a/tests/python/accuracy/prepare_data.py
+++ b/tests/python/accuracy/prepare_data.py
@@ -62,3 +62,5 @@ def prepare_data(data_dir="./data"):
     retrieve_otx_model(args.data_dir, "detection_model_with_xai_head")
     retrieve_otx_model(args.data_dir, "segmentation_model_with_xai_head")
     retrieve_otx_model(args.data_dir, "maskrcnn_model_with_xai_head")
+    retrieve_otx_model(args.data_dir, "maskrcnn_xai_tiling")
+    retrieve_otx_model(args.data_dir, "tile_classifier")
diff --git a/tests/python/accuracy/public_scope.json b/tests/python/accuracy/public_scope.json
index e16ef937..b7f1e552 100644
--- a/tests/python/accuracy/public_scope.json
+++ b/tests/python/accuracy/public_scope.json
@@ -248,5 +248,31 @@
             "reference": ["61, 277, 358, 382, 17 (horse): 0.998, 18312, RotatedRect: 212.000 327.000 290.000 100.000 0.000; 1, 14, 162, 321, 2 (car): 0.994, 25867, RotatedRect: 58.915 191.505 255.448 156.624 55.105, RotatedRect: 47.600 24.800 18.783 23.255 26.565; 327, 96, 341, 134, 1 (bicycle): 0.930, 279, RotatedRect: 333.500 114.000 36.000 13.000 90.000; 460, 106, 493, 148, 1 (bicycle): 0.898, 786, RotatedRect: 476.284 126.621 27.308 45.993 19.179; 294, 93, 315, 153, 1 (bicycle): 0.869, 789, RotatedRect: 304.000 124.000 58.000 18.000 90.000; 278, 109, 290, 152, 1 (bicycle): 0.817, 355, RotatedRect: 283.500 130.000 42.000 11.000 90.000; 4, 4, 102, 191, 2 (car): 0.701, 9658, RotatedRect: 51.806 97.259 184.445 95.281 89.246; 270, 93, 290, 152, 1 (bicycle): 0.660, 723, RotatedRect: 280.500 122.500 17.000 59.000 0.000; 322, 114, 343, 152, 18 (sheep): 0.520, 298, RotatedRect: 332.000 133.000 34.000 14.000 90.000; 6; [1,1280,1,1]"]
         }
         ]
-    }
+    },
+    {
+        "name": "otx_models/maskrcnn_xai_tiling.xml",
+        "type": "MaskRCNNModel",
+        "tiler": "InstanceSegmentationTiler",
+        "extra_name": "otx_models/tile_classifier.xml",
+        "extra_type": "ImageModel",
+        "input_res": "(3500,3500)",
+        "test_data": [
+            {
+                "image": "coco128/images/train2017/000000000074.jpg",
+                "reference": ["1535.0, 585.0, 1662.0, 697.0, 2 (ellipse): 0.643, 9822, RotatedRect: 1598.500 641.500 111.000 109.000 90.000; 3091.0, 3097.0, 3105.0, 3112.0, 1 (rectangle): 0.483, 197, RotatedRect: 159.500 166.000 14.000 13.000 90.000; 2734.0, 60.0, 2867.0, 324.0, 1 (rectangle): 0.401, 30622, RotatedRect: 2800.000 188.500 255.000 132.000 90.000; 4; [1,1280,1,1]"]
+            }
+        ]
+    },
+    {
+        "name": "otx_models/detection_model_with_xai_head.xml",
"type": "DetectionModel", + "tiler": "DetectionTiler", + "input_res": "(3500,3500)", + "test_data": [ + { + "image": "coco128/images/train2017/000000000074.jpg", + "reference": ["336.0, 2275.0, 1944.0, 3114.0, 1 (person): 0.361; 2523.0, 862.0, 2709.0, 1224.0, 1 (person): 0.313; [1,2,35,46]; [1,320,1,1]"] + } + ] } ] diff --git a/tests/python/accuracy/test_accuracy.py b/tests/python/accuracy/test_accuracy.py index 2e3b0094..99361460 100644 --- a/tests/python/accuracy/test_accuracy.py +++ b/tests/python/accuracy/test_accuracy.py @@ -15,6 +15,7 @@ SegmentationModel, add_rotated_rects, ) +from openvino.model_api.tilers import DetectionTiler, InstanceSegmentationTiler def read_config(path: Path): @@ -44,7 +45,18 @@ def test_image_models(data, dump, result, model_data): name = model_data["name"] if name.endswith(".xml"): name = f"{data}/{name}" + model = eval(model_data["type"]).create_model(name, device="CPU", download_dir=data) + if "tiler" in model_data: + if "extra_model" in model_data: + extra_model = eval(model_data["extra_type"]).create_model( + model_data["extra_model"], device="CPU", download_dir=data + ) + model = eval(model_data["tiler"])( + model, configuration={}, tile_classifier=extra_model + ) + else: + model = eval(model_data["tiler"])(model, configuration={}) if dump: result.append(model_data) @@ -55,6 +67,8 @@ def test_image_models(data, dump, result, model_data): image = cv2.imread(str(image_path)) if image is None: raise RuntimeError("Failed to read the image") + if "input_res" in model_data: + image = cv2.resize(image, eval(model_data["input_res"])) outputs = model(image) if isinstance(outputs, ClassificationResult): assert 1 == len(test_data["reference"]) @@ -98,6 +112,11 @@ def test_image_models(data, dump, result, model_data): save_name = os.path.basename(name) else: save_name = name + ".xml" - model.save(data + "/serialized/" + save_name) + + if "tiler" in model_data: + model.get_model().save(data + "/serialized/" + save_name) + else: + model.save(data + "/serialized/" + save_name) + if dump: result[-1]["test_data"] = inference_results