|
| 1 | +from typing import Any, Dict, List, Literal, Optional, Type, Union |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import supervision as sv |
| 5 | +from pydantic import ConfigDict, Field |
| 6 | + |
| 7 | +from inference.core.workflows.execution_engine.entities.base import ( |
| 8 | + Batch, |
| 9 | + OutputDefinition, |
| 10 | +) |
| 11 | +from inference.core.workflows.execution_engine.entities.types import ( |
| 12 | + DICTIONARY_KIND, |
| 13 | + FLOAT_ZERO_TO_ONE_KIND, |
| 14 | + INSTANCE_SEGMENTATION_PREDICTION_KIND, |
| 15 | + KEYPOINT_DETECTION_PREDICTION_KIND, |
| 16 | + OBJECT_DETECTION_PREDICTION_KIND, |
| 17 | + Selector, |
| 18 | +) |
| 19 | +from inference.core.workflows.prototypes.block import ( |
| 20 | + BlockResult, |
| 21 | + WorkflowBlock, |
| 22 | + WorkflowBlockManifest, |
| 23 | +) |
| 24 | + |
| 25 | +SHORT_DESCRIPTION = "Filter detections by applying a per-class confidence threshold." |
| 26 | + |
| 27 | +LONG_DESCRIPTION = """ |
| 28 | +Filter detection predictions by applying a different confidence threshold to each class, keeping only detections whose confidence meets or exceeds the threshold configured for their class (with a configurable fallback threshold for classes that are not listed). |
| 29 | +
|
| 30 | +## How This Block Works |
| 31 | +
|
| 32 | +This block applies class-aware confidence filtering to detection predictions, enabling precise control over which detections are retained based on per-class quality requirements. The block: |
| 33 | +
|
| 34 | +1. Takes detection predictions (object detection, instance segmentation, or keypoint detection) and a dictionary mapping class names to confidence thresholds |
| 35 | +2. Iterates through each detection, looking up the threshold associated with the detection's class name |
| 36 | +3. If the class is not present in the dictionary, falls back to the configurable `default_threshold` value |
| 37 | +4. Keeps only the detections whose confidence is greater than or equal to the resolved threshold |
| 38 | +5. Returns the filtered detections while preserving all original metadata (class ids, masks, keypoints, tracker ids, etc.) |
| 39 | +
|
| 40 | +Unlike a single global confidence threshold, this block lets you demand high-confidence predictions for classes that are prone to false positives while keeping a more permissive threshold for classes that are harder to detect. Unlike the generic detections filter, it exposes a purpose-built dictionary input that maps cleanly to a simple `{"class_name": threshold}` JSON object. |
| 41 | +
|
| 42 | +## Common Use Cases |
| 43 | +
|
| 44 | +- **Noise-prone classes**: Demand very high confidence (e.g. 0.9) for classes that frequently produce false positives, while accepting lower confidence for well-behaved classes |
| 45 | +- **Hard-to-detect classes**: Lower the threshold for classes that the model rarely detects with high confidence so that they are not filtered out entirely |
| 46 | +- **Production-grade filtering**: Apply domain-specific thresholds tuned during evaluation so that downstream analytics, alerts, or counting blocks only see detections that meet the project's quality bar |
| 47 | +- **Multi-class pipelines**: Combine with object detection models that predict many classes at once when a single global confidence threshold is too coarse |
| 48 | +
|
| 49 | +## Connecting to Other Blocks |
| 50 | +
|
| 51 | +The filtered predictions from this block can be connected to: |
| 52 | +
|
| 53 | +- **Visualization blocks** (Bounding Box Visualization, Label Visualization, Polygon Visualization) to render only detections that cleared their per-class threshold |
| 54 | +- **Counting and analytics blocks** (Line Counter, Time in Zone, Velocity) so that metrics reflect only high-quality detections |
| 55 | +- **Tracking blocks** (Byte Tracker) so that tracker associations are not polluted by low-confidence noise |
| 56 | +- **Storage or sink blocks** (Roboflow Dataset Upload, Webhook Sink, CSV Formatter) so that only detections meeting the quality bar are persisted or transmitted |
| 57 | +- **Downstream transformation blocks** (Dynamic Crop, Detection Offset) for subsequent processing on the filtered subset |
| 58 | +""" |
| 59 | + |
| 60 | + |
| 61 | +class BlockManifest(WorkflowBlockManifest): |
| 62 | + model_config = ConfigDict( |
| 63 | + json_schema_extra={ |
| 64 | + "name": "Per-Class Confidence Filter", |
| 65 | + "version": "v1", |
| 66 | + "short_description": SHORT_DESCRIPTION, |
| 67 | + "long_description": LONG_DESCRIPTION, |
| 68 | + "license": "Apache-2.0", |
| 69 | + "block_type": "transformation", |
| 70 | + "ui_manifest": { |
| 71 | + "section": "flow_control", |
| 72 | + "icon": "far fa-filter", |
| 73 | + "blockPriority": 2, |
| 74 | + }, |
| 75 | + } |
| 76 | + ) |
| 77 | + type: Literal["roboflow_core/per_class_confidence_filter@v1"] |
| 78 | + predictions: Selector( |
| 79 | + kind=[ |
| 80 | + OBJECT_DETECTION_PREDICTION_KIND, |
| 81 | + INSTANCE_SEGMENTATION_PREDICTION_KIND, |
| 82 | + KEYPOINT_DETECTION_PREDICTION_KIND, |
| 83 | + ] |
| 84 | + ) = Field( |
| 85 | + description="Detection predictions to filter. Each detection is kept only if its confidence is greater than or equal to the threshold configured for its class (with a fallback to default_threshold for classes that are not listed in class_thresholds).", |
| 86 | + examples=["$steps.object_detection_model.predictions"], |
| 87 | + ) |
| 88 | + class_thresholds: Union[ |
| 89 | + Dict[str, float], |
| 90 | + Selector(kind=[DICTIONARY_KIND]), |
| 91 | + ] = Field( |
| 92 | + description="Mapping of class name to minimum confidence threshold. Detections whose class name is present in this dictionary are kept only if their confidence is at least the corresponding threshold. Classes not present fall back to default_threshold. Thresholds should be in the [0.0, 1.0] range.", |
| 93 | + examples=[{"person": 0.98, "car": 0.5}, "$inputs.class_thresholds"], |
| 94 | + ) |
| 95 | + default_threshold: Union[float, Selector(kind=[FLOAT_ZERO_TO_ONE_KIND])] = Field( |
| 96 | + default=0.3, |
| 97 | + description="Confidence threshold applied to detections whose class name is not listed in class_thresholds. Must be in the [0.0, 1.0] range.", |
| 98 | + examples=[0.3, "$inputs.default_threshold"], |
| 99 | + ) |
| 100 | + |
| 101 | + @classmethod |
| 102 | + def get_parameters_accepting_batches(cls) -> List[str]: |
| 103 | + return ["predictions"] |
| 104 | + |
| 105 | + @classmethod |
| 106 | + def describe_outputs(cls) -> List[OutputDefinition]: |
| 107 | + return [ |
| 108 | + OutputDefinition( |
| 109 | + name="predictions", |
| 110 | + kind=[ |
| 111 | + OBJECT_DETECTION_PREDICTION_KIND, |
| 112 | + INSTANCE_SEGMENTATION_PREDICTION_KIND, |
| 113 | + KEYPOINT_DETECTION_PREDICTION_KIND, |
| 114 | + ], |
| 115 | + ), |
| 116 | + ] |
| 117 | + |
| 118 | + @classmethod |
| 119 | + def get_execution_engine_compatibility(cls) -> Optional[str]: |
| 120 | + return ">=1.3.0,<2.0.0" |
| 121 | + |
| 122 | + |
| 123 | +class PerClassConfidenceFilterBlockV1(WorkflowBlock): |
| 124 | + |
| 125 | + @classmethod |
| 126 | + def get_manifest(cls) -> Type[WorkflowBlockManifest]: |
| 127 | + return BlockManifest |
| 128 | + |
| 129 | + def run( |
| 130 | + self, |
| 131 | + predictions: Batch[sv.Detections], |
| 132 | + class_thresholds: Dict[str, Any], |
| 133 | + default_threshold: float = 0.3, |
| 134 | + ) -> BlockResult: |
| 135 | + return [ |
| 136 | + { |
| 137 | + "predictions": filter_detections_by_class_confidence( |
| 138 | + detections=detections, |
| 139 | + class_thresholds=class_thresholds, |
| 140 | + default_threshold=default_threshold, |
| 141 | + ) |
| 142 | + } |
| 143 | + for detections in predictions |
| 144 | + ] |
| 145 | + |
| 146 | + |
| 147 | +def filter_detections_by_class_confidence( |
| 148 | + detections: sv.Detections, |
| 149 | + class_thresholds: Dict[str, Any], |
| 150 | + default_threshold: float = 0.3, |
| 151 | +) -> sv.Detections: |
| 152 | + if detections is None or len(detections) == 0: |
| 153 | + return detections |
| 154 | + confidences = detections.confidence |
| 155 | + if confidences is None: |
| 156 | + return detections |
| 157 | + class_names = detections.data.get("class_name", []) |
| 158 | + thresholds = {str(k): float(v) for k, v in (class_thresholds or {}).items()} |
| 159 | + default = float(default_threshold) |
| 160 | + keep: List[int] = [] |
| 161 | + for i, confidence in enumerate(confidences): |
| 162 | + class_name = class_names[i] if i < len(class_names) else None |
| 163 | + threshold = thresholds.get(str(class_name), default) |
| 164 | + if float(confidence) >= threshold: |
| 165 | + keep.append(i) |
| 166 | + if len(keep) == len(detections): |
| 167 | + return detections |
| 168 | + return detections[np.array(keep, dtype=int)] |
0 commit comments