Skip to content

Commit

Permalink
Tiling for semantic segmentation (#201)
Browse files Browse the repository at this point in the history
* Add python sseg tiler

* Refactor get contours

* Update python tests

* Update cpp tests

* Fix isort

* Add cppp implementation

* Turn soft prediction on for sseg tiling

* Add checks of the input entities in cpp sseg tiler

* Fix isort
  • Loading branch information
sovrasov authored Sep 25, 2024
1 parent 02cede4 commit b8db281
Show file tree
Hide file tree
Showing 12 changed files with 302 additions and 25 deletions.
2 changes: 2 additions & 0 deletions model_api/cpp/models/include/models/segmentation_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,5 @@ class SegmentationModel : public ImageModel {
float soft_threshold = -std::numeric_limits<float>::infinity();
bool return_soft_prediction = true;
};

cv::Mat create_hard_prediction_from_soft_prediction(const cv::Mat& soft_prediction, float soft_threshold, int blur_strength);
22 changes: 11 additions & 11 deletions model_api/cpp/models/src/segmentation_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@
namespace {
constexpr char feature_vector_name[]{"feature_vector"};

cv::Mat get_activation_map(const cv::Mat& features) {
double min_soft_score, max_soft_score;
cv::minMaxLoc(features, &min_soft_score, &max_soft_score);
double factor = 255.0 / (max_soft_score - min_soft_score + 1e-12);

cv::Mat int_act_map;
features.convertTo(int_act_map, CV_8U, factor, -min_soft_score * factor);
return int_act_map;
}
}

cv::Mat create_hard_prediction_from_soft_prediction(const cv::Mat& soft_prediction, float soft_threshold, int blur_strength) {
if (soft_prediction.channels() == 1) {
return soft_prediction;
Expand Down Expand Up @@ -70,17 +81,6 @@ cv::Mat create_hard_prediction_from_soft_prediction(const cv::Mat& soft_predicti
return hard_prediction;
}

cv::Mat get_activation_map(const cv::Mat& features) {
double min_soft_score, max_soft_score;
cv::minMaxLoc(features, &min_soft_score, &max_soft_score);
double factor = 255.0 / (max_soft_score - min_soft_score + 1e-12);

cv::Mat int_act_map;
features.convertTo(int_act_map, CV_8U, factor, -min_soft_score * factor);
return int_act_map;
}
}

std::string SegmentationModel::ModelType = "Segmentation";

void SegmentationModel::init_from_config(const ov::AnyMap& top_priority, const ov::AnyMap& mid_priority) {
Expand Down
36 changes: 36 additions & 0 deletions model_api/cpp/tilers/include/tilers/semantic_segmentation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
// Copyright (C) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

#pragma once
#include <tilers/tiler_base.h>

struct ImageResult;
struct ImageResultWithSoftPrediction;

class SemanticSegmentationTiler : public TilerBase {
public:
SemanticSegmentationTiler(std::shared_ptr<ImageModel> model, const ov::AnyMap& configuration);
virtual std::unique_ptr<ImageResultWithSoftPrediction> run(const ImageInputData& inputData);
virtual ~SemanticSegmentationTiler() = default;

protected:
virtual std::unique_ptr<ResultBase> postprocess_tile(std::unique_ptr<ResultBase>, const cv::Rect&);
virtual std::unique_ptr<ResultBase> merge_results(const std::vector<std::unique_ptr<ResultBase>>&, const cv::Size&, const std::vector<cv::Rect>&);

int blur_strength = -1;
float soft_threshold = -std::numeric_limits<float>::infinity();
bool return_soft_prediction = true;
};
108 changes: 108 additions & 0 deletions model_api/cpp/tilers/src/semantic_segmentation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
// Copyright (C) 2024 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/


#include <vector>
#include <opencv2/core.hpp>

#include <tilers/semantic_segmentation.h>
#include <models/segmentation_model.h>
#include <models/results.h>
#include "utils/common.hpp"

namespace {
void normalize_soft_prediction(cv::Mat& soft_prediction, const cv::Mat& normalize_factor) {
float* data = soft_prediction.ptr<float>(0);
const int num_classes = soft_prediction.channels();
const size_t step_rows = soft_prediction.step[0] / sizeof(float);
const size_t step_cols = soft_prediction.step[1] / sizeof(float);

for (int y = 0; y < soft_prediction.rows; ++y) {
for (int x = 0; x < soft_prediction.cols; ++x) {
int weight = normalize_factor.at<int>(y, x);
if (weight > 0) {
for (int c = 0; c < num_classes; ++c) {
data[y * step_rows + x * step_cols + c] /= weight;
}
}
}
}
}
}

SemanticSegmentationTiler::SemanticSegmentationTiler(std::shared_ptr<ImageModel> _model, const ov::AnyMap& configuration) :
TilerBase(_model, configuration) {
ov::AnyMap extra_config;
try {
auto ov_model = model->getModel();
extra_config = ov_model->get_rt_info<ov::AnyMap>("model_info");
}
catch (const std::runtime_error&) {
extra_config = model->getInferenceAdapter()->getModelConfig();
}

blur_strength = get_from_any_maps("blur_strength", configuration, extra_config, blur_strength);
soft_threshold = get_from_any_maps("soft_threshold", configuration, extra_config, soft_threshold);
return_soft_prediction = get_from_any_maps("return_soft_prediction", configuration, extra_config, return_soft_prediction);
}

std::unique_ptr<ImageResultWithSoftPrediction> SemanticSegmentationTiler::run(const ImageInputData& inputData) {
auto result = this->run_impl(inputData);
return std::unique_ptr<ImageResultWithSoftPrediction>(static_cast<ImageResultWithSoftPrediction*>(result.release()));
}

std::unique_ptr<ResultBase> SemanticSegmentationTiler::postprocess_tile(std::unique_ptr<ResultBase> tile_result, const cv::Rect&) {
ImageResultWithSoftPrediction* soft = dynamic_cast<ImageResultWithSoftPrediction*>(tile_result.get());
if (!soft) {
throw std::runtime_error("SemanticSegmentationTiler requires the underlying model to return ImageResultWithSoftPrediction");
}
return tile_result;
}

std::unique_ptr<ResultBase> SemanticSegmentationTiler::merge_results(const std::vector<std::unique_ptr<ResultBase>>& tiles_results,
const cv::Size& image_size, const std::vector<cv::Rect>& tile_coords) {
if (tiles_results.empty()) {
return std::unique_ptr<ResultBase>(new ImageResultWithSoftPrediction());
}

cv::Mat voting_mask(cv::Size(image_size.width, image_size.height), CV_32SC1, cv::Scalar(0));
auto* sseg_res = static_cast<ImageResultWithSoftPrediction*>(tiles_results[0].get());
cv::Mat merged_soft_prediction(cv::Size(image_size.width, image_size.height), CV_32FC(sseg_res->soft_prediction.channels()), cv::Scalar(0));

for (size_t i = 0; i < tiles_results.size(); ++i) {
auto* sseg_res = static_cast<ImageResultWithSoftPrediction*>(tiles_results[i].get());
voting_mask(tile_coords[i]) += 1;
merged_soft_prediction(tile_coords[i]) += sseg_res->soft_prediction;
}

normalize_soft_prediction(merged_soft_prediction, voting_mask);

cv::Mat hard_prediction = create_hard_prediction_from_soft_prediction(merged_soft_prediction, soft_threshold, blur_strength);

std::unique_ptr<ResultBase> retVal;
if (return_soft_prediction) {
auto* result = new ImageResultWithSoftPrediction();
retVal = std::unique_ptr<ResultBase>(result);
result->soft_prediction = merged_soft_prediction;
result->resultImage = hard_prediction;
}
else {
auto* result = new ImageResult();
retVal = std::unique_ptr<ResultBase>(result);
result->resultImage = hard_prediction;
}
return retVal;
}
18 changes: 10 additions & 8 deletions model_api/python/model_api/models/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,30 +165,32 @@ def postprocess(self, outputs, meta):
return hard_prediction

def get_contours(
self, hard_prediction: np.ndarray, soft_prediction: np.ndarray
self,
prediction: ImageResultWithSoftPrediction,
) -> list:
height, width = hard_prediction.shape[:2]
n_layers = soft_prediction.shape[2]
n_layers = prediction.soft_prediction.shape[2]

if n_layers == 1:
raise RuntimeError("Cannot get contours from soft prediction with 1 layer")
combined_contours = []
for layer_index in range(1, n_layers): # ignoring background
label = self.get_label_name(layer_index - 1)
if len(soft_prediction.shape) == 3:
current_label_soft_prediction = soft_prediction[:, :, layer_index]
if len(prediction.soft_prediction.shape) == 3:
current_label_soft_prediction = prediction.soft_prediction[
:, :, layer_index
]
else:
current_label_soft_prediction = soft_prediction
current_label_soft_prediction = prediction.soft_prediction

obj_group = hard_prediction == layer_index
obj_group = prediction.resultImage == layer_index
label_index_map = obj_group.astype(np.uint8) * 255

contours, _hierarchy = cv2.findContours(
label_index_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
)

for contour in contours:
mask = np.zeros(hard_prediction.shape, dtype=np.uint8)
mask = np.zeros(prediction.resultImage.shape, dtype=np.uint8)
cv2.drawContours(
mask,
np.asarray([contour]),
Expand Down
2 changes: 2 additions & 0 deletions model_api/python/model_api/tilers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

from .detection import DetectionTiler
from .instance_segmentation import InstanceSegmentationTiler
from .semantic_segmentation import SemanticSegmentationTiler
from .tiler import Tiler

__all__ = [
"DetectionTiler",
"InstanceSegmentationTiler",
"Tiler",
"SemanticSegmentationTiler",
]
2 changes: 1 addition & 1 deletion model_api/python/model_api/tilers/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
class InstanceSegmentationTiler(DetectionTiler):
"""
Tiler for object instance segmentation models.
This tiler expects model to output a lsit of `SegmentedObject` objects.
This tiler expects model to output a list of `SegmentedObject` objects.
In addition, this tiler allows to use a tile classifier model,
which predicts objectness score for each tile. Later, tiles can
Expand Down
97 changes: 97 additions & 0 deletions model_api/python/model_api/tilers/semantic_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Copyright (C) 2024 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import annotations

from contextlib import contextmanager

import numpy as np
from model_api.models import SegmentationModel
from model_api.models.utils import ImageResultWithSoftPrediction

from .tiler import Tiler


class SemanticSegmentationTiler(Tiler):
"""
Tiler for segmentation models.
"""

def _postprocess_tile(
self,
predictions: ImageResultWithSoftPrediction,
coord: list[int],
) -> dict:
"""Converts predictions to a format convenient for further merging.
Args:
predictions (ImageResultWithSoftPrediction): predictions from SegmentationModel
coord (list[int]): coordinates of the tile
Returns:
dict: postprocessed predictions
"""
output_dict = {}
output_dict["coord"] = coord
output_dict["masks"] = predictions.soft_prediction
return output_dict

def _merge_results(
self, results: list[dict], shape: tuple[int, int, int]
) -> ImageResultWithSoftPrediction:
"""Merge the results from all tiles.
Args:
results (list[dict]): list of tile predictions
shape (tuple[int, int, int]): shape of the original image
Returns:
ImageResultWithSoftPrediction: merged predictions
"""
height, width = shape[:2]
num_classes = len(self.model.labels)
full_logits_mask = np.zeros((height, width, num_classes), dtype=np.float32)
vote_mask = np.zeros((height, width), dtype=np.int32)
for result in results:
x1, y1, x2, y2 = result["coord"]
mask = result["masks"]
vote_mask[y1:y2, x1:x2] += 1
full_logits_mask[y1:y2, x1:x2, :] += mask[: y2 - y1, : x2 - x1, :]

full_logits_mask = full_logits_mask / vote_mask[:, :, None]
index_mask = full_logits_mask.argmax(2)
return ImageResultWithSoftPrediction(
resultImage=index_mask,
soft_prediction=full_logits_mask,
feature_vector=np.array([]),
saliency_map=np.array([]),
)

def __call__(self, inputs):
@contextmanager
def setup_segm_model():
return_soft_prediction_state = None
if isinstance(self.model, SegmentationModel):
return_soft_prediction_state = self.model.return_soft_prediction
self.model.return_soft_prediction = True
try:
yield
finally:
if isinstance(self.model, SegmentationModel):
self.model.return_soft_prediction = return_soft_prediction_state

with setup_segm_model():
return super().__call__(inputs)
14 changes: 13 additions & 1 deletion tests/cpp/accuracy/test_accuracy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <adapters/openvino_adapter.h>
#include <tilers/detection.h>
#include <tilers/instance_segmentation.h>
#include <tilers/semantic_segmentation.h>

using json = nlohmann::json;

Expand Down Expand Up @@ -197,7 +198,18 @@ TEST_P(ModelParameterizedTest, AccuracyTest)
throw std::runtime_error{"Failed to read the image"};
}

std::unique_ptr<ImageResult> pred = model->infer(image);
std::unique_ptr<ImageResult> pred;
if (modelData.tiler == "SemanticSegmentationTiler") {
auto tiler = SemanticSegmentationTiler(std::move(model), {});
if (modelData.input_res.height > 0 && modelData.input_res.width > 0) {
cv::resize(image, image, modelData.input_res);
}
pred = tiler.run(image);
}
else {
pred = model->infer(image);
}

ImageResultWithSoftPrediction* soft = dynamic_cast<ImageResultWithSoftPrediction*>(pred.get());
if (soft) {
const std::vector<Contour>& contours = model->getContours(*soft);
Expand Down
1 change: 1 addition & 0 deletions tests/python/accuracy/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ async def main():
download_otx_model(client, otx_models_dir, "sam_vit_b_zsl_encoder"),
download_otx_model(client, otx_models_dir, "sam_vit_b_zsl_decoder"),
download_otx_model(client, otx_models_dir, "rtmpose_tiny"),
download_otx_model(client, otx_models_dir, "segnext_t_tiling"),
)


Expand Down
Loading

0 comments on commit b8db281

Please sign in to comment.