Skip to content

Commit 5f1026e

Browse files
Support optimal confidence from model-eval (#2206)
* squash commit * Show confidence hint 0.4 to match default * dev dockerfiles overlay inference_models source build * confidence filter debug logging * confidence filter lazy imports not needed * drop unneeded comment * kwargs.get(recommended_parameters) -> Optional[RecommendedParameters]=None * per-model default confidence, move to post_processing, inline per-image refinement into existing loops * concrete class post_process confidence optional default None * cleanup * simplify ConfidenceFilter and avoid double filtering * fix OOB bugs in yolov5/7 and rfdetr * confidencefilter readability * undo no-op diffs * revert workflow UI change for now * deeplabv3plus: drop double construction of SemanticSegmentationResult * Explicit 'best', 'default' or float confidence - easy opt-out * legacy inference ignore string valued confidence * move Confidence to entities, validation throws ModelInputError, hint confidence Tensor contents * use pydantic native ge/le validation instead of annotated_types * keep 0.5 request default, default to 'default' in SDK instead of 'best' * scalar threshold fast path * fix yolov10 * update inference_sdk * drop pydantic validation on confidence * bump version to 0.25.0, update changelog, revert Dockerfile and CI change * bump inference-models requirements to 0.25.0 --------- Co-authored-by: Paweł Pęczek <146137186+PawelPeczek-Roboflow@users.noreply.github.com>
1 parent 38b0fbf commit 5f1026e

78 files changed

Lines changed: 1902 additions & 259 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

inference/core/entities/requests/inference.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import BaseModel, ConfigDict, Field, validator
55

66
from inference.core.entities.common import ApiKey, ModelID, ModelType
7+
from inference_sdk.http.entities import Confidence
78

89

910
class BaseRequest(BaseModel):
@@ -145,10 +146,13 @@ class ObjectDetectionInferenceRequest(CVInferenceRequest):
145146
examples=[["class-1", "class-2", "class-n"]],
146147
description="If provided, only predictions for the listed classes will be returned",
147148
)
148-
confidence: Optional[float] = Field(
149+
confidence: Confidence = Field(
149150
default=0.4,
150-
examples=[0.5],
151-
description="The confidence threshold used to filter out predictions",
151+
examples=[0.5, "best", "default"],
152+
description=(
153+
'Confidence threshold. "best" uses model-eval thresholds, '
154+
'"default" uses the model built-in, or pass a float.'
155+
),
152156
)
153157
fix_batch_size: Optional[bool] = Field(
154158
default=False,
@@ -245,10 +249,13 @@ def __init__(self, **kwargs):
245249
kwargs["model_type"] = "classification"
246250
super().__init__(**kwargs)
247251

248-
confidence: Optional[float] = Field(
252+
confidence: Confidence = Field(
249253
default=0.4,
250-
examples=[0.5],
251-
description="The confidence threshold used to filter out predictions",
254+
examples=[0.5, "best", "default"],
255+
description=(
256+
'Confidence threshold. "best" uses model-eval thresholds, '
257+
'"default" uses the model built-in, or pass a float.'
258+
),
252259
)
253260
visualization_stroke_width: Optional[int] = Field(
254261
default=1,

inference/core/models/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,16 @@ def infer_from_request(
139139
is also included in the response.
140140
"""
141141
t1 = perf_counter()
142-
responses = self.infer(**request.dict(), return_image_dims=False)
142+
kwargs = request.dict()
143+
confidence = kwargs.get("confidence")
144+
if isinstance(confidence, str):
145+
logger.warning(
146+
"Legacy inference does not support confidence=%r, "
147+
"using model default",
148+
confidence,
149+
)
150+
kwargs.pop("confidence")
151+
responses = self.infer(**kwargs, return_image_dims=False)
143152
for response in responses:
144153
response.time = perf_counter() - t1
145154
logger.debug(f"model infer time: {response.time * 1000.0} ms")

inference/core/models/inference_models_adapters.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -677,25 +677,29 @@ def postprocess(
677677
List[ClassificationInferenceResponse],
678678
]:
679679
mapped_kwargs = self.map_inference_kwargs(kwargs)
680-
post_processed_predictions = self._model.post_process(
681-
predictions, **mapped_kwargs
682-
)
683-
if isinstance(post_processed_predictions, list):
684-
# multi-label classification
685-
return prepare_multi_label_classification_response(
686-
post_processed_predictions,
687-
image_sizes=returned_metadata,
688-
class_names=self.class_names,
689-
confidence_threshold=kwargs.get("confidence", 0.5),
680+
if isinstance(self._model, MultiLabelClassificationModel):
681+
post_processed_predictions = self._model.post_process(
682+
predictions, **mapped_kwargs
690683
)
691-
else:
692-
# single-label classification
693-
return prepare_classification_response(
684+
return prepare_multi_label_classification_response(
694685
post_processed_predictions,
695686
image_sizes=returned_metadata,
696687
class_names=self.class_names,
697-
confidence_threshold=kwargs.get("confidence", 0.5),
698688
)
689+
# Single-label classification: top-1 always wins regardless of
690+
# confidence, so per-class refinement isn't meaningful here. The base
691+
# class deliberately opts out of recommendedParameters entirely. The
692+
# response builder still uses kwargs.get("confidence") or 0.5 for the
693+
# cutoff that decides which alternative classes show up.
694+
post_processed_predictions = self._model.post_process(
695+
predictions, **mapped_kwargs
696+
)
697+
return prepare_classification_response(
698+
post_processed_predictions,
699+
image_sizes=returned_metadata,
700+
class_names=self.class_names,
701+
confidence_threshold=kwargs.get("confidence") or 0.5,
702+
)
699703

700704
def clear_cache(self, delete_from_disk: bool = True) -> None:
701705
"""Clears any cache if necessary. TODO: Implement this to delete the cache from the experimental model.
@@ -747,20 +751,27 @@ def prepare_multi_label_classification_response(
747751
post_processed_predictions: List[MultiLabelClassificationPrediction],
748752
image_sizes: List[Tuple[int, int]],
749753
class_names: List[str],
750-
confidence_threshold: float,
751754
) -> List[MultiLabelClassificationInferenceResponse]:
755+
"""Build the API response from a model's post-processed predictions.
756+
757+
`prediction.class_ids` is the authoritative list of "passed" classes —
758+
the model's `post_process` already applied the
759+
full priority chain (user → per-class → global → default), so the
760+
response builder doesn't re-threshold here. The full per-class score
761+
vector is still emitted in `image_predictions_dict` for UI display.
762+
"""
752763
results = []
753764
for prediction, image_size in zip(post_processed_predictions, image_sizes):
754-
image_predictions_dict = dict()
755-
predicted_classes = []
756-
for class_id, confidence in enumerate(prediction.confidence.cpu().tolist()):
757-
cls_name = class_names[class_id]
758-
image_predictions_dict[cls_name] = {
765+
image_predictions_dict = {
766+
class_names[class_id]: {
759767
"confidence": confidence,
760768
"class_id": class_id,
761769
}
762-
if confidence > confidence_threshold:
763-
predicted_classes.append(cls_name)
770+
for class_id, confidence in enumerate(prediction.confidence.cpu().tolist())
771+
}
772+
predicted_classes = [
773+
class_names[class_id] for class_id in prediction.class_ids.tolist()
774+
]
764775
results.append(
765776
MultiLabelClassificationInferenceResponse(
766777
predictions=image_predictions_dict,

inference_models/docs/changelog.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
# Changelog
22

3+
## `0.25.0`
4+
5+
### Added
6+
7+
- `post_process(...)` on object detection, instance segmentation, keypoint detection, classification, and semantic
8+
segmentation models now accepts `confidence` as `"best"` (use per-class or global thresholds from
9+
`RecommendedParameters` when available), `"default"` (model's built-in default), or a float override. Shared NMS
10+
helpers accept a per-class `torch.Tensor` for single-pass per-class filtering.
11+
12+
---
13+
314
## `0.24.4`
415

516
### Changed

inference_models/inference_models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
if os.environ.get("TOKENIZERS_PARALLELISM") is None:
2929
os.environ["TOKENIZERS_PARALLELISM"] = "false"
3030

31-
from inference_models.entities import ColorFormat
31+
from inference_models.entities import ColorFormat, Confidence
3232
from inference_models.model_pipelines.auto_loaders.core import AutoModelPipeline
3333
from inference_models.models.auto_loaders.core import AutoModel
3434
from inference_models.models.auto_loaders.entities import (
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import namedtuple
2-
from typing import Literal
2+
from typing import Literal, Union
33

44
ImageDimensions = namedtuple("ImageDimensions", ["height", "width"])
55
ColorFormat = Literal["rgb", "bgr"]
6+
Confidence = Union[float, Literal["best", "default"]]

inference_models/inference_models/models/auto_loaders/auto_resolution_cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
TaskType,
1818
)
1919
from inference_models.utils.file_system import dump_json, read_json
20-
from inference_models.weights_providers.entities import ModelDependency
20+
from inference_models.weights_providers.entities import (
21+
ModelDependency,
22+
RecommendedParameters,
23+
)
2124

2225

2326
class AutoResolutionCacheEntry(BaseModel):
@@ -30,6 +33,7 @@ class AutoResolutionCacheEntry(BaseModel):
3033
model_dependencies: Optional[List[ModelDependency]] = Field(default=None)
3134
created_at: datetime
3235
model_features: Optional[dict] = Field(default=None)
36+
recommended_parameters: Optional[RecommendedParameters] = Field(default=None)
3337

3438

3539
class AutoResolutionCache(ABC):

inference_models/inference_models/models/auto_loaders/core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
ModelDependency,
8282
ModelPackageMetadata,
8383
Quantization,
84+
RecommendedParameters,
8485
)
8586

8687
MODEL_TYPES_TO_LOAD_FROM_CHECKPOINT = {
@@ -926,6 +927,7 @@ def model_directory_pointer(model_dir: str) -> None:
926927
model_dependencies=model_metadata.model_dependencies,
927928
model_dependencies_instances=model_dependencies_instances,
928929
model_dependencies_directories=model_dependencies_directories,
930+
recommended_parameters=model_metadata.recommended_parameters,
929931
max_package_loading_attempts=max_package_loading_attempts,
930932
model_download_file_lock_acquire_timeout=model_download_file_lock_acquire_timeout,
931933
verify_hash_while_download=verify_hash_while_download,
@@ -1078,6 +1080,10 @@ def attempt_loading_model_with_auto_load_cache(
10781080
package_id=cache_entry.model_package_id,
10791081
)
10801082
model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
1083+
# Cache stores the already-resolved (package-vs-model) value written
1084+
# in initialize_model — no need to re-run resolve_recommended_parameters.
1085+
if cache_entry.recommended_parameters is not None:
1086+
model_init_kwargs["recommended_parameters"] = cache_entry.recommended_parameters
10811087
model = model_class.from_pretrained(
10821088
model_package_cache_dir, **model_init_kwargs
10831089
)
@@ -1113,6 +1119,7 @@ def attempt_loading_matching_model_packages(
11131119
model_dependencies: Optional[List[ModelDependency]],
11141120
model_dependencies_instances: Dict[str, AnyModel],
11151121
model_dependencies_directories: Dict[str, str],
1122+
recommended_parameters: Optional[RecommendedParameters] = None,
11161123
max_package_loading_attempts: Optional[int] = None,
11171124
model_download_file_lock_acquire_timeout: int = FILE_LOCK_ACQUIRE_TIMEOUT,
11181125
verbose: bool = True,
@@ -1153,6 +1160,7 @@ def attempt_loading_matching_model_packages(
11531160
model_dependencies=model_dependencies,
11541161
model_dependencies_instances=model_dependencies_instances,
11551162
model_dependencies_directories=model_dependencies_directories,
1163+
recommended_parameters=recommended_parameters,
11561164
verify_hash_while_download=verify_hash_while_download,
11571165
download_files_without_hash=download_files_without_hash,
11581166
on_file_created=partial(
@@ -1218,6 +1226,7 @@ def initialize_model(
12181226
model_dependencies: Optional[List[ModelDependency]],
12191227
model_dependencies_instances: Dict[str, AnyModel],
12201228
model_dependencies_directories: Dict[str, str],
1229+
recommended_parameters: Optional[RecommendedParameters] = None,
12211230
model_download_file_lock_acquire_timeout: int = FILE_LOCK_ACQUIRE_TIMEOUT,
12221231
verify_hash_while_download: bool = True,
12231232
download_files_without_hash: bool = False,
@@ -1307,6 +1316,12 @@ def initialize_model(
13071316
)
13081317
resolved_files.update(dependencies_resolved_files)
13091318
model_init_kwargs[MODEL_DEPENDENCIES_KEY] = model_dependencies_instances
1319+
resolved_recommended_parameters = resolve_recommended_parameters(
1320+
package_level=model_package.recommended_parameters,
1321+
model_level=recommended_parameters,
1322+
)
1323+
if resolved_recommended_parameters is not None:
1324+
model_init_kwargs["recommended_parameters"] = resolved_recommended_parameters
13101325
model = model_class.from_pretrained(model_package_cache_dir, **model_init_kwargs)
13111326
dump_auto_resolution_cache(
13121327
use_auto_resolution_cache=use_auto_resolution_cache,
@@ -1320,6 +1335,7 @@ def initialize_model(
13201335
resolved_files=resolved_files,
13211336
model_dependencies=model_dependencies,
13221337
model_features=model_package.model_features,
1338+
recommended_parameters=resolved_recommended_parameters,
13231339
)
13241340
return model, model_package_cache_dir
13251341

@@ -1484,6 +1500,7 @@ def dump_auto_resolution_cache(
14841500
resolved_files: Set[str],
14851501
model_dependencies: Optional[List[ModelDependency]],
14861502
model_features: Optional[dict],
1503+
recommended_parameters: Optional[RecommendedParameters] = None,
14871504
) -> None:
14881505
if not use_auto_resolution_cache:
14891506
return None
@@ -1497,6 +1514,7 @@ def dump_auto_resolution_cache(
14971514
created_at=datetime.now(),
14981515
model_dependencies=model_dependencies,
14991516
model_features=model_features,
1517+
recommended_parameters=recommended_parameters,
15001518
)
15011519
auto_resolution_cache.register(
15021520
auto_negotiation_hash=auto_negotiation_hash, cache_entry=cache_content
@@ -1812,3 +1830,11 @@ def load_class_from_path(module_path: str, class_name: str) -> AnyModel:
18121830
help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror",
18131831
)
18141832
return getattr(module, class_name)
1833+
1834+
1835+
def resolve_recommended_parameters(
1836+
package_level: Optional[RecommendedParameters],
1837+
model_level: Optional[RecommendedParameters],
1838+
) -> Optional[RecommendedParameters]:
1839+
"""Package-level recommended_parameters take priority over model-level."""
1840+
return package_level if package_level is not None else model_level

inference_models/inference_models/models/base/classification.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ class ClassificationPrediction:
1717

1818
class ClassificationModel(ABC, Generic[PreprocessedInputs, RawPrediction]):
1919

20+
# Single-label classification deliberately opts out of recommendedParameters.
21+
# Top-1 always wins regardless of confidence, so per-class refinement isn't
22+
# a meaningful semantic for this task type. (Multi-label classification opts
23+
# in below — that's where per-class thresholds actually filter the result.)
24+
2025
@classmethod
2126
@abstractmethod
2227
def from_pretrained(

0 commit comments

Comments
 (0)