Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GuardTrackedDetectionsBlock #705

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 62 additions & 58 deletions inference/core/workflows/core_steps/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@
from inference.core.workflows.core_steps.transformations.dynamic_zones.v1 import (
DynamicZonesBlockV1,
)
from inference.core.workflows.core_steps.transformations.guard_tracked_detections.v1 import (
GuardTrackedDetectionsBlockV1,
)
from inference.core.workflows.core_steps.transformations.image_slicer.v1 import (
ImageSlicerBlockV1,
)
Expand Down Expand Up @@ -284,85 +287,86 @@

def load_blocks() -> List[Type[WorkflowBlock]]:
return [
TimeInZoneBlockV1,
BoundingRectBlockV1,
SegmentAnything2BlockV1,
DetectionsConsensusBlockV1,
ClipComparisonBlockV1,
LMMBlockV1,
LMMForClassificationBlockV1,
OpenAIBlockV1,
CogVLMBlockV1,
OCRModelBlockV1,
YoloWorldModelBlockV1,
RoboflowInstanceSegmentationModelBlockV1,
RoboflowKeypointDetectionModelBlockV1,
RoboflowClassificationModelBlockV1,
RoboflowMultiLabelClassificationModelBlockV1,
RoboflowObjectDetectionModelBlockV1,
BarcodeDetectorBlockV1,
QRCodeDetectorBlockV1,
AbsoluteStaticCropBlockV1,
DynamicCropBlockV1,
DetectionsFilterBlockV1,
DetectionOffsetBlockV1,
ByteTrackerBlockV1,
RelativeStaticCropBlockV1,
DetectionsTransformationBlockV1,
RoboflowDatasetUploadBlockV1,
ContinueIfBlockV1,
PerspectiveCorrectionBlockV1,
DynamicZonesBlockV1,
DetectionsClassesReplacementBlockV1,
ExpressionBlockV1,
PropertyDefinitionBlockV1,
DimensionCollapseBlockV1,
FirstNonEmptyOrDefaultBlockV1,
AntropicClaudeBlockV1,
BackgroundColorVisualizationBlockV1,
BarcodeDetectorBlockV1,
BlurVisualizationBlockV1,
BoundingBoxVisualizationBlockV1,
BoundingRectBlockV1,
ByteTrackerBlockV1,
CameraFocusBlockV1,
CircleVisualizationBlockV1,
ClipComparisonBlockV1,
ClipComparisonBlockV2,
CogVLMBlockV1,
ColorVisualizationBlockV1,
ContinueIfBlockV1,
ConvertGrayscaleBlockV1,
CornerVisualizationBlockV1,
CropVisualizationBlockV1,
DetectionOffsetBlockV1,
DetectionsClassesReplacementBlockV1,
DetectionsConsensusBlockV1,
DetectionsFilterBlockV1,
DetectionsStitchBlockV1,
DetectionsTransformationBlockV1,
DimensionCollapseBlockV1,
DominantColorBlockV1,
DotVisualizationBlockV1,
DynamicCropBlockV1,
DynamicZonesBlockV1,
EllipseVisualizationBlockV1,
ExpressionBlockV1,
FirstNonEmptyOrDefaultBlockV1,
Florence2BlockV1,
GoogleGeminiBlockV1,
GuardTrackedDetectionsBlockV1,
HaloVisualizationBlockV1,
ImageBlurBlockV1,
ImageContoursDetectionBlockV1,
ImagePreprocessingBlockV1,
ImageSlicerBlockV1,
ImageThresholdBlockV1,
JSONParserBlockV1,
LMMBlockV1,
LMMForClassificationBlockV1,
LabelVisualizationBlockV1,
LineCounterBlockV1,
LineCounterZoneVisualizationBlockV1,
MaskVisualizationBlockV1,
OCRModelBlockV1,
OpenAIBlockV1,
OpenAIBlockV2,
PathDeviationAnalyticsBlockV1,
PerspectiveCorrectionBlockV1,
PixelateVisualizationBlockV1,
PixelationCountBlockV1,
PolygonVisualizationBlockV1,
LineCounterZoneVisualizationBlockV1,
TriangleVisualizationBlockV1,
PolygonZoneVisualizationBlockV1,
PropertyDefinitionBlockV1,
QRCodeDetectorBlockV1,
RelativeStaticCropBlockV1,
RoboflowClassificationModelBlockV1,
RoboflowCustomMetadataBlockV1,
DetectionsStitchBlockV1,
ImageSlicerBlockV1,
DominantColorBlockV1,
PixelationCountBlockV1,
RoboflowDatasetUploadBlockV1,
RoboflowDatasetUploadBlockV2,
RoboflowInstanceSegmentationModelBlockV1,
RoboflowKeypointDetectionModelBlockV1,
RoboflowMultiLabelClassificationModelBlockV1,
RoboflowObjectDetectionModelBlockV1,
SIFTBlockV1,
SIFTComparisonBlockV1,
SIFTComparisonBlockV2,
SIFTBlockV1,
TemplateMatchingBlockV1,
ImageBlurBlockV1,
ConvertGrayscaleBlockV1,
ImageThresholdBlockV1,
ImageContoursDetectionBlockV1,
ClipComparisonBlockV2,
CameraFocusBlockV1,
RoboflowDatasetUploadBlockV2,
SegmentAnything2BlockV1,
StabilityAIInpaintingBlockV1,
StitchImagesBlockV1,
OpenAIBlockV2,
JSONParserBlockV1,
TemplateMatchingBlockV1,
TimeInZoneBlockV1,
TriangleVisualizationBlockV1,
VLMAsClassifierBlockV1,
GoogleGeminiBlockV1,
VLMAsDetectorBlockV1,
AntropicClaudeBlockV1,
LineCounterBlockV1,
PolygonZoneVisualizationBlockV1,
Florence2BlockV1,
StabilityAIInpaintingBlockV1,
ImagePreprocessingBlockV1,
PathDeviationAnalyticsBlockV1,
YoloWorldModelBlockV1,
]


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import supervision as sv
from pydantic import ConfigDict, Field

from inference.core.workflows.execution_engine.entities.base import (
OutputDefinition,
VideoMetadata,
)
from inference.core.workflows.execution_engine.entities.types import (
FLOAT_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
OBJECT_DETECTION_PREDICTION_KIND,
StepOutputSelector,
WorkflowParameterSelector,
WorkflowVideoMetadataSelector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

OUTPUT_KEY: str = "tracked_detections"
LONG_DESCRIPTION = """
This block stores last known position for each bounding box
If box disappears then this block will bring it back so short gaps are filled with last known box position
The block requires detections to be tracked (i.e. each object must have unique tracker_id assigned,
which persists between frames)
"""


class BlockManifest(WorkflowBlockManifest):
model_config = ConfigDict(
json_schema_extra={
"name": "Guard Tracked Detections",
"version": "v1",
"short_description": "Restore detections that randomly disappear",
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "transformation",
}
)
type: Literal["roboflow_core/guard_tracked_detections@v1"]
metadata: WorkflowVideoMetadataSelector
detections: StepOutputSelector(
kind=[
OBJECT_DETECTION_PREDICTION_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
]
) = Field( # type: ignore
description="Tracked detections",
examples=["$steps.object_detection_model.predictions"],
)
consider_detection_gone_timeout: Union[Optional[int], WorkflowParameterSelector(kind=[FLOAT_KIND])] = Field( # type: ignore
default=2,
description="Drop detections that had not been seen for longer than this timeout (in seconds)",
examples=[2, "$inputs.disappeared_detections_timeout"],
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(
name=OUTPUT_KEY,
kind=[
OBJECT_DETECTION_PREDICTION_KIND,
INSTANCE_SEGMENTATION_PREDICTION_KIND,
],
),
]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.0.0,<2.0.0"


class GuardTrackedDetectionsBlockV1(WorkflowBlock):
def __init__(self):
self._batch_of_last_known_detections: Dict[
str, Dict[Union[int, str], Tuple[float, sv.Detections]]
] = {}

@classmethod
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
return BlockManifest

def run(
self,
detections: sv.Detections,
metadata: VideoMetadata,
consider_detection_gone_timeout: float,
) -> BlockResult:
if metadata.comes_from_video_file and metadata.fps != 0:
ts = metadata.frame_number / metadata.fps
else:
ts = metadata.frame_timestamp.timestamp()
if detections.tracker_id is None:
raise ValueError(
f"tracker_id not initialized, {self.__class__.__name__} requires detections to be tracked"
)
cached_detections = self._batch_of_last_known_detections.setdefault(
metadata.video_identifier, {}
)
this_frame_tracked_ids = set()
for i, tracked_id in zip(range(len(detections)), detections.tracker_id):
this_frame_tracked_ids.add(tracked_id)
cached_detections[tracked_id] = (ts, detections[i])
for tracked_id in list(cached_detections.keys()):
last_seen_ts = cached_detections[tracked_id][0]
if ts - last_seen_ts > consider_detection_gone_timeout:
del cached_detections[tracked_id]
return {
OUTPUT_KEY: sv.Detections.merge(
cached_detection[1] for cached_detection in cached_detections.values()
)
}
Loading