Skip to content

Commit ae9226e

Browse files
authored
[None][feat] Add PyTorch reset_prefix_cache API (#14970)
Signed-off-by: milesial <milesial@users.noreply.github.com>
1 parent aef7d47 commit ae9226e

7 files changed

Lines changed: 211 additions & 0 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5156,6 +5156,9 @@ def _handle_speculative_decoding(
51565156
return new_target_inputs, num_accepted_tokens_device
51575157

51585158
def reset_prefix_cache(self):
5159+
if self.active_requests or self.waiting_queue:
5160+
raise RuntimeError(
5161+
"reset_prefix_cache() requires no active or queued requests.")
51595162
self.kv_cache_manager.reset_reuse_state()
51605163

51615164
def _handle_guided_decoder_errors(

tensorrt_llm/executor/base_worker.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,16 @@ def wakeup(self, wakeup_tags: List[str]) -> None:
788788
materialize_with_tag(*tags)
789789
torch.cuda.synchronize()
790790

791+
def reset_prefix_cache(self) -> None:
792+
"""Invalidate local KV prefix-cache reuse state on PyTorch engines."""
793+
engine = self.engine
794+
if engine is None or not hasattr(engine, "reset_prefix_cache"):
795+
raise NotImplementedError(
796+
"reset_prefix_cache() is only supported by the PyTorch backend."
797+
)
798+
with engine.control_action():
799+
engine.reset_prefix_cache()
800+
791801
def shutdown(self):
792802
if self.doing_shutdown:
793803
return

tensorrt_llm/llmapi/llm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,6 +1494,33 @@ def _collective_rpc(
14941494
f"Executor type {type(self._executor)} does not support collective RPC."
14951495
)
14961496

1497+
@set_api_status("beta")
1498+
def reset_prefix_cache(self) -> None:
1499+
"""Reset local KV prefix-cache reuse state.
1500+
1501+
This invalidates local prefix-cache metadata in the PyTorch backend. It
1502+
requires no active or queued requests, and it does not reset
1503+
connector-managed external or offloaded cache state. Callers should
1504+
quiesce traffic before invoking this method.
1505+
"""
1506+
if self._encode_only:
1507+
raise RuntimeError("reset_prefix_cache() is not available when "
1508+
"encode_only=True.")
1509+
if self._executor is None:
1510+
raise RuntimeError("reset_prefix_cache() requires an active "
1511+
"executor.")
1512+
1513+
if hasattr(self._executor, "collective_rpc"):
1514+
self._collective_rpc("reset_prefix_cache")
1515+
return
1516+
1517+
reset_prefix_cache = getattr(self._executor, "reset_prefix_cache", None)
1518+
if reset_prefix_cache is None:
1519+
raise NotImplementedError(
1520+
"reset_prefix_cache() is only supported by the PyTorch backend."
1521+
)
1522+
reset_prefix_cache()
1523+
14971524
def _build_model(self):
14981525
super()._build_model()
14991526
assert self._engine_dir is None

tensorrt_llm/serve/openai_server.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,9 @@ def register_routes(self):
666666
self.app.add_api_route("/kv_cache_events",
667667
self.get_kv_cache_events,
668668
methods=["POST"])
669+
self.app.add_api_route("/reset_prefix_cache",
670+
self.reset_prefix_cache,
671+
methods=["POST"])
669672
resource_governor_queue = self.generator._executor.resource_governor_queue
670673
if resource_governor_queue is not None:
671674
from .resource_governor import ResourceGovernor
@@ -1027,6 +1030,27 @@ async def get_kv_cache_events(self) -> JSONResponse:
10271030
pass
10281031
return JSONResponse(content=events)
10291032

1033+
async def reset_prefix_cache(self) -> Response:
1034+
reset_prefix_cache = getattr(self.generator, "reset_prefix_cache", None)
1035+
if reset_prefix_cache is None:
1036+
return self._create_not_supported_error(
1037+
"reset_prefix_cache() is only supported by the PyTorch backend."
1038+
)
1039+
1040+
try:
1041+
await asyncio.get_running_loop().run_in_executor(
1042+
None, reset_prefix_cache)
1043+
except NotImplementedError as e:
1044+
return self._create_not_supported_error(str(e))
1045+
except (RuntimeError, ValueError) as e:
1046+
return self.create_error_response(
1047+
message=str(e),
1048+
err_type="InvalidRequestError",
1049+
status_code=HTTPStatus.CONFLICT,
1050+
)
1051+
1052+
return Response(status_code=200)
1053+
10301054
async def _extract_metrics(self, res: RequestOutput, raw_request: Request):
10311055
if not res.finished:
10321056
return

tests/unittest/_torch/executor/test_py_executor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,41 @@ def test_handle_special_queue_items(mock_executor):
133133
assert 2 in mock_executor.canceled_req_ids
134134

135135

136+
def test_reset_prefix_cache_resets_when_idle():
137+
executor = object.__new__(PyExecutor)
138+
executor.active_requests = []
139+
executor.waiting_queue = []
140+
executor.kv_cache_manager = Mock()
141+
142+
executor.reset_prefix_cache()
143+
144+
executor.kv_cache_manager.reset_reuse_state.assert_called_once_with()
145+
146+
147+
def test_reset_prefix_cache_rejects_active_requests():
148+
executor = object.__new__(PyExecutor)
149+
executor.active_requests = [Mock()]
150+
executor.waiting_queue = []
151+
executor.kv_cache_manager = Mock()
152+
153+
with pytest.raises(RuntimeError, match="no active or queued requests"):
154+
executor.reset_prefix_cache()
155+
156+
executor.kv_cache_manager.reset_reuse_state.assert_not_called()
157+
158+
159+
def test_reset_prefix_cache_rejects_queued_requests():
160+
executor = object.__new__(PyExecutor)
161+
executor.active_requests = []
162+
executor.waiting_queue = [Mock()]
163+
executor.kv_cache_manager = Mock()
164+
165+
with pytest.raises(RuntimeError, match="no active or queued requests"):
166+
executor.reset_prefix_cache()
167+
168+
executor.kv_cache_manager.reset_reuse_state.assert_not_called()
169+
170+
136171
def test_clear_canceled_req_ids(mock_executor):
137172
"""Test clearing canceled request IDs."""
138173
mock_executor.canceled_req_ids = [1, 2, 3]

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,10 @@ methods:
387387
default: 2
388388
return_annotation: tensorrt_llm.executor.result.IterationResult
389389
status: beta
390+
reset_prefix_cache:
391+
parameters: {}
392+
return_annotation: None
393+
status: beta
390394
get_stats:
391395
parameters:
392396
timeout:

tests/unittest/llmapi/test_llm.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2708,6 +2708,114 @@ async def is_disconnected(self):
27082708
return True
27092709

27102710

2711+
class _FakeResetExecutor:
2712+
2713+
def __init__(self):
2714+
self.num_reset_calls = 0
2715+
2716+
def reset_prefix_cache(self):
2717+
self.num_reset_calls += 1
2718+
2719+
def shutdown(self):
2720+
pass
2721+
2722+
2723+
class _FakeCollectiveResetExecutor:
2724+
2725+
def __init__(self):
2726+
self.calls = []
2727+
2728+
def collective_rpc(self, method, args, kwargs, non_block, unique_reply_rank,
2729+
target_ranks):
2730+
self.calls.append(
2731+
(method, args, kwargs, non_block, unique_reply_rank, target_ranks))
2732+
return [None]
2733+
2734+
def shutdown(self):
2735+
pass
2736+
2737+
2738+
class _FakeUnsupportedResetExecutor:
2739+
2740+
def shutdown(self):
2741+
pass
2742+
2743+
2744+
class _FakeNotImplementedResetGenerator:
2745+
2746+
def reset_prefix_cache(self):
2747+
raise NotImplementedError("not supported")
2748+
2749+
2750+
def test_llm_reset_prefix_cache_dispatches_to_executor() -> None:
2751+
llm = object.__new__(LLM_torch)
2752+
llm._encode_only = False
2753+
llm._executor = _FakeResetExecutor()
2754+
2755+
llm.reset_prefix_cache()
2756+
2757+
assert llm._executor.num_reset_calls == 1
2758+
2759+
2760+
def test_llm_reset_prefix_cache_uses_collective_rpc() -> None:
2761+
llm = object.__new__(LLM_torch)
2762+
llm._encode_only = False
2763+
llm._executor = _FakeCollectiveResetExecutor()
2764+
2765+
llm.reset_prefix_cache()
2766+
2767+
assert llm._executor.calls == [("reset_prefix_cache", (), None, False, None,
2768+
None)]
2769+
2770+
2771+
def test_llm_reset_prefix_cache_rejects_encode_only() -> None:
2772+
llm = object.__new__(LLM_torch)
2773+
llm._encode_only = True
2774+
llm._executor = _FakeResetExecutor()
2775+
2776+
with pytest.raises(RuntimeError, match="encode_only=True"):
2777+
llm.reset_prefix_cache()
2778+
2779+
2780+
def test_llm_reset_prefix_cache_rejects_unsupported_executor() -> None:
2781+
llm = object.__new__(LLM_torch)
2782+
llm._encode_only = False
2783+
llm._executor = _FakeUnsupportedResetExecutor()
2784+
2785+
with pytest.raises(NotImplementedError,
2786+
match="only supported by the PyTorch backend"):
2787+
llm.reset_prefix_cache()
2788+
2789+
2790+
def test_openai_reset_prefix_cache_endpoint() -> None:
2791+
server = object.__new__(OpenAIServer)
2792+
server.generator = _FakeResetExecutor()
2793+
2794+
response = asyncio.run(server.reset_prefix_cache())
2795+
2796+
assert response.status_code == 200
2797+
assert server.generator.num_reset_calls == 1
2798+
2799+
2800+
def test_openai_reset_prefix_cache_endpoint_rejects_unsupported_generator(
2801+
) -> None:
2802+
server = object.__new__(OpenAIServer)
2803+
server.generator = object()
2804+
2805+
response = asyncio.run(server.reset_prefix_cache())
2806+
2807+
assert response.status_code == 501
2808+
2809+
2810+
def test_openai_reset_prefix_cache_endpoint_maps_not_implemented() -> None:
2811+
server = object.__new__(OpenAIServer)
2812+
server.generator = _FakeNotImplementedResetGenerator()
2813+
2814+
response = asyncio.run(server.reset_prefix_cache())
2815+
2816+
assert response.status_code == 501
2817+
2818+
27112819
def test_openai_completion_list_prompt_stream_reuses_stream_metadata() -> None:
27122820

27132821
async def run_request():

0 commit comments

Comments
 (0)