Skip to content

Commit 70e2c4a

Browse files
authored
Forward request metadata through decorators (#2198)
1 parent 267a6ac commit 70e2c4a

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

inference/core/managers/decorators/base.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,18 @@ def add_model(
8686
service_secret=service_secret,
8787
)
8888

89+
def record_request_metadata(
    self,
    model_id: str,
    original_model_id: Optional[str] = None,
    model_id_alias: Optional[str] = None,
) -> None:
    """Forward request metadata to the wrapped model manager.

    Delegates to ``self.model_manager.record_request_metadata`` and passes
    all three values by keyword, unchanged.

    Args:
        model_id: Identifier of the model the request refers to.
        original_model_id: Optional original identifier, forwarded as-is.
        model_id_alias: Optional alias, forwarded as-is.

    Returns:
        None. Whatever the wrapped manager returns is discarded.
    """
    forwarded = {
        "model_id": model_id,
        "original_model_id": original_model_id,
        "model_id_alias": model_id_alias,
    }
    # Keyword-only forwarding keeps this independent of the wrapped
    # manager's positional parameter order.
    self.model_manager.record_request_metadata(**forwarded)
100+
89101
async def infer_from_request(
90102
self, model_id: str, request: InferenceRequest, **kwargs
91103
) -> InferenceResponse:

tests/inference/unit_tests/core/managers/test_decorators.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33
from inference.core.managers.base import ModelManager
44
from inference.core.managers.decorators.base import ModelManagerDecorator
55
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
6+
from inference.core.managers.decorators.locked_load import (
7+
LockedLoadModelManagerDecorator,
8+
)
69
from inference.core.managers.model_load_collector import (
710
RequestModelIds,
811
current_request_path,
@@ -59,3 +62,26 @@ def test_fixed_size_cache_records_request_metadata_for_warm_model() -> None:
5962
assert description.model_id == "sam3/sam3_interactive"
6063
assert description.request_aliases == ["sam3/sam3_final"]
6164
assert description.request_paths == ["/sam3/embed_image"]
65+
66+
67+
def test_nested_decorators_record_request_metadata_for_warm_model() -> None:
    """Request metadata is recorded through two stacked decorators.

    The base manager already holds ``some/1``; ``add_model`` is invoked on a
    ``WithFixedSizeCache`` wrapping a ``LockedLoadModelManagerDecorator``, and
    the test then checks what the base manager and the request-scoped id
    collector report.
    """
    inner_manager = ModelManager(model_registry=MagicMock())
    # Pre-populate the registry so the model is already present before add_model.
    inner_manager._models = {"some/1": MagicMock()}
    stacked_manager = WithFixedSizeCache(
        LockedLoadModelManagerDecorator(inner_manager),
        max_size=8,
    )

    request_path_token = current_request_path.set("/infer/object_detection")
    collected_ids = RequestModelIds()
    collected_ids_token = request_model_ids.set(collected_ids)

    try:
        stacked_manager.add_model(model_id="some/1", api_key="key")
    finally:
        # Restore the context variables in reverse order of how they were set.
        request_model_ids.reset(collected_ids_token)
        current_request_path.reset(request_path_token)

    # Unpacking into a one-element tuple also requires exactly one description.
    (model_description,) = inner_manager.describe_models()
    assert model_description.model_id == "some/1"
    assert model_description.request_aliases == []
    assert model_description.request_paths == ["/infer/object_detection"]
    assert collected_ids.get_ids() == {"some/1"}

0 commit comments

Comments
 (0)