Skip to content

Commit 5d7e0e6

Browse files
Add Per-Class Confidence Filter workflow block (#2283)
Adds a new transformation block roboflow_core/per_class_confidence_filter@v1 that filters detection predictions using a different confidence threshold per class. Detections are kept only if their confidence is at least the threshold configured for their class, with a configurable default_threshold fallback for classes that are not listed. The existing detections_filter block can express this via its query language, but requires a verbose StatementGroup/BinaryStatement configuration. This block exposes a simple {"class_name": threshold} dictionary for the common per-class case. Supports object detection, instance segmentation, and keypoint detection prediction kinds. Includes a unit test suite covering manifest validation, boundary conditions, unknown-class fallback, empty inputs, and batch processing, plus an integration test exercising the full ExecutionEngine path with synthetic detections. Co-authored-by: Paweł Pęczek <146137186+PawelPeczek-Roboflow@users.noreply.github.com>
1 parent 5cff9b5 commit 5d7e0e6

5 files changed

Lines changed: 576 additions & 0 deletions

File tree

inference/core/workflows/core_steps/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,9 @@
476476
from inference.core.workflows.core_steps.transformations.image_slicer.v2 import (
477477
ImageSlicerBlockV2,
478478
)
479+
from inference.core.workflows.core_steps.transformations.per_class_confidence_filter.v1 import (
480+
PerClassConfidenceFilterBlockV1,
481+
)
479482
from inference.core.workflows.core_steps.transformations.perspective_correction.v1 import (
480483
PerspectiveCorrectionBlockV1,
481484
)
@@ -744,6 +747,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
744747
DynamicCropBlockV1,
745748
DetectionsFilterBlockV1,
746749
DetectionOffsetBlockV1,
750+
PerClassConfidenceFilterBlockV1,
747751
DepthEstimationBlockV1,
748752
ByteTrackerBlockV1,
749753
RelativeStaticCropBlockV1,

inference/core/workflows/core_steps/transformations/per_class_confidence_filter/__init__.py

Whitespace-only changes.
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from typing import Any, Dict, List, Literal, Optional, Type, Union
2+
3+
import numpy as np
4+
import supervision as sv
5+
from pydantic import ConfigDict, Field
6+
7+
from inference.core.workflows.execution_engine.entities.base import (
8+
Batch,
9+
OutputDefinition,
10+
)
11+
from inference.core.workflows.execution_engine.entities.types import (
12+
DICTIONARY_KIND,
13+
FLOAT_ZERO_TO_ONE_KIND,
14+
INSTANCE_SEGMENTATION_PREDICTION_KIND,
15+
KEYPOINT_DETECTION_PREDICTION_KIND,
16+
OBJECT_DETECTION_PREDICTION_KIND,
17+
Selector,
18+
)
19+
from inference.core.workflows.prototypes.block import (
20+
BlockResult,
21+
WorkflowBlock,
22+
WorkflowBlockManifest,
23+
)
24+
25+
SHORT_DESCRIPTION = "Filter detections by applying a per-class confidence threshold."
26+
27+
LONG_DESCRIPTION = """
28+
Filter detection predictions by applying a different confidence threshold to each class, keeping only detections whose confidence meets or exceeds the threshold configured for their class (with a configurable fallback threshold for classes that are not listed).
29+
30+
## How This Block Works
31+
32+
This block applies class-aware confidence filtering to detection predictions, enabling precise control over which detections are retained based on per-class quality requirements. The block:
33+
34+
1. Takes detection predictions (object detection, instance segmentation, or keypoint detection) and a dictionary mapping class names to confidence thresholds
35+
2. Iterates through each detection, looking up the threshold associated with the detection's class name
36+
3. If the class is not present in the dictionary, falls back to the configurable `default_threshold` value
37+
4. Keeps only the detections whose confidence is greater than or equal to the resolved threshold
38+
5. Returns the filtered detections while preserving all original metadata (class ids, masks, keypoints, tracker ids, etc.)
39+
40+
Unlike a single global confidence threshold, this block lets you demand high-confidence predictions for classes that are prone to false positives while keeping a more permissive threshold for classes that are harder to detect. Unlike the generic detections filter, it exposes a purpose-built dictionary input that maps cleanly to a simple `{"class_name": threshold}` JSON object.
41+
42+
## Common Use Cases
43+
44+
- **Noise-prone classes**: Demand very high confidence (e.g. 0.9) for classes that frequently produce false positives, while accepting lower confidence for well-behaved classes
45+
- **Hard-to-detect classes**: Lower the threshold for classes that the model rarely detects with high confidence so that they are not filtered out entirely
46+
- **Production-grade filtering**: Apply domain-specific thresholds tuned during evaluation so that downstream analytics, alerts, or counting blocks only see detections that meet the project's quality bar
47+
- **Multi-class pipelines**: Combine with object detection models that predict many classes at once when a single global confidence threshold is too coarse
48+
49+
## Connecting to Other Blocks
50+
51+
The filtered predictions from this block can be connected to:
52+
53+
- **Visualization blocks** (Bounding Box Visualization, Label Visualization, Polygon Visualization) to render only detections that cleared their per-class threshold
54+
- **Counting and analytics blocks** (Line Counter, Time in Zone, Velocity) so that metrics reflect only high-quality detections
55+
- **Tracking blocks** (Byte Tracker) so that tracker associations are not polluted by low-confidence noise
56+
- **Storage or sink blocks** (Roboflow Dataset Upload, Webhook Sink, CSV Formatter) so that only detections meeting the quality bar are persisted or transmitted
57+
- **Downstream transformation blocks** (Dynamic Crop, Detection Offset) for subsequent processing on the filtered subset
58+
"""
59+
60+
61+
class BlockManifest(WorkflowBlockManifest):
62+
model_config = ConfigDict(
63+
json_schema_extra={
64+
"name": "Per-Class Confidence Filter",
65+
"version": "v1",
66+
"short_description": SHORT_DESCRIPTION,
67+
"long_description": LONG_DESCRIPTION,
68+
"license": "Apache-2.0",
69+
"block_type": "transformation",
70+
"ui_manifest": {
71+
"section": "flow_control",
72+
"icon": "far fa-filter",
73+
"blockPriority": 2,
74+
},
75+
}
76+
)
77+
type: Literal["roboflow_core/per_class_confidence_filter@v1"]
78+
predictions: Selector(
79+
kind=[
80+
OBJECT_DETECTION_PREDICTION_KIND,
81+
INSTANCE_SEGMENTATION_PREDICTION_KIND,
82+
KEYPOINT_DETECTION_PREDICTION_KIND,
83+
]
84+
) = Field(
85+
description="Detection predictions to filter. Each detection is kept only if its confidence is greater than or equal to the threshold configured for its class (with a fallback to default_threshold for classes that are not listed in class_thresholds).",
86+
examples=["$steps.object_detection_model.predictions"],
87+
)
88+
class_thresholds: Union[
89+
Dict[str, float],
90+
Selector(kind=[DICTIONARY_KIND]),
91+
] = Field(
92+
description="Mapping of class name to minimum confidence threshold. Detections whose class name is present in this dictionary are kept only if their confidence is at least the corresponding threshold. Classes not present fall back to default_threshold. Thresholds should be in the [0.0, 1.0] range.",
93+
examples=[{"person": 0.98, "car": 0.5}, "$inputs.class_thresholds"],
94+
)
95+
default_threshold: Union[float, Selector(kind=[FLOAT_ZERO_TO_ONE_KIND])] = Field(
96+
default=0.3,
97+
description="Confidence threshold applied to detections whose class name is not listed in class_thresholds. Must be in the [0.0, 1.0] range.",
98+
examples=[0.3, "$inputs.default_threshold"],
99+
)
100+
101+
@classmethod
102+
def get_parameters_accepting_batches(cls) -> List[str]:
103+
return ["predictions"]
104+
105+
@classmethod
106+
def describe_outputs(cls) -> List[OutputDefinition]:
107+
return [
108+
OutputDefinition(
109+
name="predictions",
110+
kind=[
111+
OBJECT_DETECTION_PREDICTION_KIND,
112+
INSTANCE_SEGMENTATION_PREDICTION_KIND,
113+
KEYPOINT_DETECTION_PREDICTION_KIND,
114+
],
115+
),
116+
]
117+
118+
@classmethod
119+
def get_execution_engine_compatibility(cls) -> Optional[str]:
120+
return ">=1.3.0,<2.0.0"
121+
122+
123+
class PerClassConfidenceFilterBlockV1(WorkflowBlock):
124+
125+
@classmethod
126+
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
127+
return BlockManifest
128+
129+
def run(
130+
self,
131+
predictions: Batch[sv.Detections],
132+
class_thresholds: Dict[str, Any],
133+
default_threshold: float = 0.3,
134+
) -> BlockResult:
135+
return [
136+
{
137+
"predictions": filter_detections_by_class_confidence(
138+
detections=detections,
139+
class_thresholds=class_thresholds,
140+
default_threshold=default_threshold,
141+
)
142+
}
143+
for detections in predictions
144+
]
145+
146+
147+
def filter_detections_by_class_confidence(
148+
detections: sv.Detections,
149+
class_thresholds: Dict[str, Any],
150+
default_threshold: float = 0.3,
151+
) -> sv.Detections:
152+
if detections is None or len(detections) == 0:
153+
return detections
154+
confidences = detections.confidence
155+
if confidences is None:
156+
return detections
157+
class_names = detections.data.get("class_name", [])
158+
thresholds = {str(k): float(v) for k, v in (class_thresholds or {}).items()}
159+
default = float(default_threshold)
160+
keep: List[int] = []
161+
for i, confidence in enumerate(confidences):
162+
class_name = class_names[i] if i < len(class_names) else None
163+
threshold = thresholds.get(str(class_name), default)
164+
if float(confidence) >= threshold:
165+
keep.append(i)
166+
if len(keep) == len(detections):
167+
return detections
168+
return detections[np.array(keep, dtype=int)]
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import numpy as np
2+
import supervision as sv
3+
4+
from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
5+
from inference.core.managers.base import ModelManager
6+
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
7+
from inference.core.workflows.execution_engine.core import ExecutionEngine
8+
9+
10+
PER_CLASS_CONFIDENCE_FILTER_WORKFLOW = {
11+
"version": "1.3.0",
12+
"inputs": [
13+
{
14+
"type": "WorkflowBatchInput",
15+
"name": "predictions",
16+
"kind": ["object_detection_prediction"],
17+
},
18+
{"type": "WorkflowParameter", "name": "class_thresholds"},
19+
{"type": "WorkflowParameter", "name": "default_threshold", "default_value": 0.3},
20+
],
21+
"steps": [
22+
{
23+
"type": "roboflow_core/per_class_confidence_filter@v1",
24+
"name": "filter",
25+
"predictions": "$inputs.predictions",
26+
"class_thresholds": "$inputs.class_thresholds",
27+
"default_threshold": "$inputs.default_threshold",
28+
}
29+
],
30+
"outputs": [
31+
{
32+
"type": "JsonField",
33+
"name": "filtered",
34+
"selector": "$steps.filter.predictions",
35+
}
36+
],
37+
}
38+
39+
40+
def _make_detections(
41+
class_names: list[str], confidences: list[float]
42+
) -> sv.Detections:
43+
n = len(class_names)
44+
return sv.Detections(
45+
xyxy=np.array([[0, 0, 10, 10]] * n, dtype=np.float64),
46+
class_id=np.arange(n),
47+
confidence=np.array(confidences, dtype=np.float64),
48+
data={
49+
"class_name": np.array(class_names),
50+
"detection_id": np.array([f"d{i}" for i in range(n)]),
51+
},
52+
)
53+
54+
55+
def test_per_class_confidence_filter_end_to_end(
56+
model_manager: ModelManager,
57+
) -> None:
58+
# given
59+
workflow_init_parameters = {
60+
"workflows_core.model_manager": model_manager,
61+
"workflows_core.api_key": None,
62+
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
63+
}
64+
execution_engine = ExecutionEngine.init(
65+
workflow_definition=PER_CLASS_CONFIDENCE_FILTER_WORKFLOW,
66+
init_parameters=workflow_init_parameters,
67+
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
68+
)
69+
predictions = _make_detections(
70+
class_names=["person", "person", "car", "dog"],
71+
confidences=[0.99, 0.7, 0.6, 0.4],
72+
)
73+
74+
# when
75+
result = execution_engine.run(
76+
runtime_parameters={
77+
"predictions": [predictions],
78+
"class_thresholds": {"person": 0.98, "car": 0.5},
79+
"default_threshold": 0.5,
80+
}
81+
)
82+
83+
# then
84+
assert isinstance(result, list)
85+
assert len(result) == 1
86+
filtered: sv.Detections = result[0]["filtered"]
87+
assert list(filtered.data["class_name"]) == ["person", "car"]
88+
assert list(filtered.confidence) == [0.99, 0.6]
89+
90+
91+
def test_per_class_confidence_filter_default_threshold_filters_unknown_class(
92+
model_manager: ModelManager,
93+
) -> None:
94+
# given
95+
workflow_init_parameters = {
96+
"workflows_core.model_manager": model_manager,
97+
"workflows_core.api_key": None,
98+
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
99+
}
100+
execution_engine = ExecutionEngine.init(
101+
workflow_definition=PER_CLASS_CONFIDENCE_FILTER_WORKFLOW,
102+
init_parameters=workflow_init_parameters,
103+
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
104+
)
105+
predictions = _make_detections(
106+
class_names=["cat", "cat"],
107+
confidences=[0.2, 0.8],
108+
)
109+
110+
# when
111+
result = execution_engine.run(
112+
runtime_parameters={
113+
"predictions": [predictions],
114+
"class_thresholds": {"person": 0.98},
115+
"default_threshold": 0.5,
116+
}
117+
)
118+
119+
# then
120+
filtered: sv.Detections = result[0]["filtered"]
121+
assert list(filtered.data["class_name"]) == ["cat"]
122+
assert list(filtered.confidence) == [0.8]
123+
124+
125+
def test_per_class_confidence_filter_handles_batch_of_images(
126+
model_manager: ModelManager,
127+
) -> None:
128+
# given
129+
workflow_init_parameters = {
130+
"workflows_core.model_manager": model_manager,
131+
"workflows_core.api_key": None,
132+
"workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
133+
}
134+
execution_engine = ExecutionEngine.init(
135+
workflow_definition=PER_CLASS_CONFIDENCE_FILTER_WORKFLOW,
136+
init_parameters=workflow_init_parameters,
137+
max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
138+
)
139+
predictions_image_1 = _make_detections(
140+
class_names=["person", "car"], confidences=[0.99, 0.4]
141+
)
142+
predictions_image_2 = _make_detections(
143+
class_names=["car"], confidences=[0.55]
144+
)
145+
146+
# when
147+
result = execution_engine.run(
148+
runtime_parameters={
149+
"predictions": [predictions_image_1, predictions_image_2],
150+
"class_thresholds": {"person": 0.98, "car": 0.5},
151+
"default_threshold": 0.3,
152+
}
153+
)
154+
155+
# then
156+
assert len(result) == 2
157+
assert list(result[0]["filtered"].data["class_name"]) == ["person"]
158+
assert list(result[1]["filtered"].data["class_name"]) == ["car"]

0 commit comments

Comments
 (0)