Skip to content

Commit c9a58d5

Browse files
committed
[None][fix] Restore PyTorch reset_prefix_cache API
Signed-off-by: Alexandre Milesi <milesial@users.noreply.github.com>
1 parent db7161b commit c9a58d5

9 files changed

Lines changed: 349 additions & 22 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5156,7 +5156,11 @@ def _handle_speculative_decoding(
51565156

51575157
return new_target_inputs, num_accepted_tokens_device
51585158

5159-
def reset_prefix_cache(self):
5159+
def reset_prefix_cache(self) -> None:
    """Invalidate local KV prefix-cache reuse state.

    Delegates to ``self.kv_cache_manager.reset_reuse_state()`` after
    confirming that nothing is in flight.

    Raises:
        RuntimeError: If ``self.active_requests`` or ``self.waiting_queue``
            is truthy.
    """
    has_pending_work = bool(self.active_requests) or bool(self.waiting_queue)
    if has_pending_work:
        raise RuntimeError(
            "reset_prefix_cache() requires no active or queued requests.")

    self.kv_cache_manager.reset_reuse_state()
51615165

51625166
def _handle_guided_decoder_errors(

tensorrt_llm/executor/base_worker.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -788,6 +788,15 @@ def wakeup(self, wakeup_tags: List[str]) -> None:
788788
materialize_with_tag(*tags)
789789
torch.cuda.synchronize()
790790

791+
def reset_prefix_cache(self) -> None:
    """Invalidate local KV prefix-cache reuse state on PyTorch engines.

    The engine's own ``reset_prefix_cache()`` is invoked inside its
    ``control_action()`` context manager.

    Raises:
        NotImplementedError: If ``self.engine`` is ``None`` or exposes no
            ``reset_prefix_cache`` attribute.
    """
    target = self.engine
    supported = target is not None and hasattr(target, "reset_prefix_cache")
    if not supported:
        raise NotImplementedError(
            "reset_prefix_cache() is only supported by the PyTorch backend.")

    with target.control_action():
        target.reset_prefix_cache()
799+
791800
def shutdown(self):
792801
if self.doing_shutdown:
793802
return

tensorrt_llm/llmapi/llm.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1494,6 +1494,38 @@ def _collective_rpc(
14941494
f"Executor type {type(self._executor)} does not support collective RPC."
14951495
)
14961496

1497+
@set_api_status("beta")
def reset_prefix_cache(self) -> None:
    """Reset local KV prefix-cache reuse state.

    Invalidates local prefix-cache metadata in the PyTorch backend. The
    reset requires that there are no active or queued requests, and it does
    not touch connector-managed external or offloaded cache state, so
    callers should quiesce traffic before invoking this method.

    The call is sent through ``self._collective_rpc`` when the executor has
    a ``collective_rpc`` attribute; otherwise the executor's own
    ``reset_prefix_cache`` is called directly.

    Raises:
        RuntimeError: If the LLM is encode-only or has no active executor.
        NotImplementedError: If the executor does not support prefix-cache
            reset.
    """
    if self._encode_only:
        raise RuntimeError(
            "reset_prefix_cache() is not available when encode_only=True.")

    executor = self._executor
    if executor is None:
        raise RuntimeError("reset_prefix_cache() requires an active executor.")

    if hasattr(executor, "collective_rpc"):
        self._collective_rpc("reset_prefix_cache")
    else:
        direct_reset = getattr(executor, "reset_prefix_cache", None)
        if direct_reset is None:
            raise NotImplementedError(
                "reset_prefix_cache() is only supported by the PyTorch backend.")
        direct_reset()
1528+
14971529
def _build_model(self):
14981530
super()._build_model()
14991531
assert self._engine_dir is None

tensorrt_llm/llmapi/rlhf_utils.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -155,10 +155,6 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
155155
logger.error("Encountered an error in update_weights")
156156
raise e
157157

158-
def reset_prefix_cache(self) -> None:
159-
"""Invalidate the KV cache prefix reuse state after weight updates."""
160-
self.engine.reset_prefix_cache()
161-
162158
@control_action_decorator
163159
def wait_for_engine_idle(self) -> None:
164160
"""Block until the engine has no active or queued requests."""

tensorrt_llm/serve/openai_server.py

Lines changed: 74 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from collections import deque
1313
from contextlib import asynccontextmanager
1414
from datetime import datetime
15+
from functools import partial
1516
from http import HTTPStatus
1617
from pathlib import Path
1718
from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
@@ -666,6 +667,9 @@ def register_routes(self):
666667
self.app.add_api_route("/kv_cache_events",
667668
self.get_kv_cache_events,
668669
methods=["POST"])
670+
self.app.add_api_route("/reset_prefix_cache",
671+
self.reset_prefix_cache,
672+
methods=["POST"])
669673
resource_governor_queue = self.generator._executor.resource_governor_queue
670674
if resource_governor_queue is not None:
671675
from .resource_governor import ResourceGovernor
@@ -1027,6 +1031,28 @@ async def get_kv_cache_events(self) -> JSONResponse:
10271031
pass
10281032
return JSONResponse(content=events)
10291033

1034+
async def reset_prefix_cache(self) -> Response:
    """Reset local PyTorch KV prefix-cache reuse state.

    Returns:
        A ``{"status": "success"}`` JSON response when the reset completes,
        the not-supported error response when the generator has no
        ``reset_prefix_cache`` or raises ``NotImplementedError``, or an
        ``InvalidRequestError`` response with HTTP 409 (CONFLICT) when the
        reset raises ``RuntimeError`` or ``ValueError``.
    """
    reset_fn = getattr(self.generator, "reset_prefix_cache", None)
    if reset_fn is None:
        return self._create_not_supported_error(
            "reset_prefix_cache() is only supported by the PyTorch backend.")

    try:
        loop = asyncio.get_running_loop()
        # Passing None selects the loop's default executor for the callable.
        await loop.run_in_executor(None, reset_fn)
    except NotImplementedError as err:
        # Listed before RuntimeError because NotImplementedError is one of
        # its subclasses and would otherwise be reported as a conflict.
        return self._create_not_supported_error(str(err))
    except (RuntimeError, ValueError) as err:
        return self.create_error_response(
            message=str(err),
            err_type="InvalidRequestError",
            status_code=HTTPStatus.CONFLICT,
        )
    else:
        return JSONResponse(content={"status": "success"})
1055+
10301056
async def _extract_metrics(self, res: RequestOutput, raw_request: Request):
10311057
if not res.finished:
10321058
return
@@ -1952,26 +1978,58 @@ async def openai_responses_delete_response(
19521978
"deleted": True
19531979
})
19541980

1955-
async def release_memory(self,
1956-
request: MemoryUpdateRequest) -> JSONResponse:
1957-
assert isinstance(
1958-
self.generator, AsyncLLM
1959-
), "/release_memory endpoint is only supported with AsyncLLM()"
1960-
await self.generator.collective_rpc('sleep', args=(request.tags, ))
1961-
return JSONResponse(content={"status": "success"})
1981+
async def _run_worker_control_rpc(self, method: str,
1982+
args: tuple[Any, ...]) -> None:
1983+
"""Dispatch a worker control method through direct or collective RPC."""
1984+
executor = getattr(self.generator, "_executor", None)
1985+
direct_method = getattr(executor, method, None)
1986+
if direct_method is not None:
1987+
await asyncio.get_running_loop().run_in_executor(
1988+
None, partial(direct_method, *args))
1989+
return
1990+
1991+
if isinstance(self.generator, AsyncLLM):
1992+
await self.generator.collective_rpc(method, args=args)
1993+
return
1994+
1995+
collective_rpc = getattr(self.generator, "_collective_rpc", None)
1996+
if collective_rpc is not None:
1997+
await asyncio.get_running_loop().run_in_executor(
1998+
None, partial(collective_rpc, method, args))
1999+
return
2000+
2001+
raise NotImplementedError(
2002+
f"{method}() is only supported by the PyTorch backend.")
2003+
2004+
async def _handle_worker_control_rpc(self, method: str,
                                     args: tuple[Any, ...]) -> Response:
    """Run a worker control method and map failures to HTTP responses.

    Args:
        method: Name of the worker control method to invoke.
        args: Positional arguments forwarded to that method.

    Returns:
        A ``{"status": "success"}`` JSON response when the call completes,
        the not-supported error response on ``NotImplementedError``, or an
        ``InvalidRequestError`` response with HTTP 409 (CONFLICT) on
        ``RuntimeError`` or ``ValueError``.
    """
    try:
        await self._run_worker_control_rpc(method, args)
    except NotImplementedError as err:
        # Listed before RuntimeError because NotImplementedError is one of
        # its subclasses and would otherwise be reported as a conflict.
        return self._create_not_supported_error(str(err))
    except (RuntimeError, ValueError) as err:
        return self.create_error_response(
            message=str(err),
            err_type="InvalidRequestError",
            status_code=HTTPStatus.CONFLICT,
        )
    else:
        return JSONResponse(content={"status": "success"})
19692019

2020+
async def release_memory(self, request: MemoryUpdateRequest) -> Response:
    """Run the worker ``sleep`` control method with the request's tags.

    Returns whatever ``_handle_worker_control_rpc`` produces for the call.
    """
    sleep_args = (request.tags, )
    response = await self._handle_worker_control_rpc('sleep', sleep_args)
    return response
2023+
2024+
async def resume_memory(self, request: MemoryUpdateRequest) -> Response:
    """Run the worker ``wakeup`` control method with the request's tags.

    Returns whatever ``_handle_worker_control_rpc`` produces for the call.
    """
    wakeup_args = (request.tags, )
    response = await self._handle_worker_control_rpc('wakeup', wakeup_args)
    return response
2027+
19702028
async def update_weights(self,
                         request: UpdateWeightsRequest) -> Response:
    """Forward ``request.weights`` to the workers' ``update_weights`` method.

    Returns:
        A ``{"status": "success"}`` JSON response once the collective RPC
        completes, or the not-supported error response when the generator
        is not an ``AsyncLLM``.
    """
    generator = self.generator
    if isinstance(generator, AsyncLLM):
        await generator.collective_rpc('update_weights',
                                       args=(request.weights, ))
        return JSONResponse(content={"status": "success"})

    return self._create_not_supported_error(
        "/update_weights endpoint is only supported with AsyncLLM()")

tests/unittest/_torch/executor/test_py_executor.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,41 @@ def test_handle_special_queue_items(mock_executor):
133133
assert 2 in mock_executor.canceled_req_ids
134134

135135

136+
def test_reset_prefix_cache_resets_when_idle():
    """With no active or queued requests, the reuse state is reset once."""
    kv_cache_manager = Mock()
    # object.__new__ skips PyExecutor.__init__; only the attributes that
    # reset_prefix_cache reads are populated.
    executor = object.__new__(PyExecutor)
    executor.kv_cache_manager = kv_cache_manager
    executor.waiting_queue = []
    executor.active_requests = []

    executor.reset_prefix_cache()

    kv_cache_manager.reset_reuse_state.assert_called_once_with()
145+
146+
147+
def test_reset_prefix_cache_rejects_active_requests():
    """An active request makes the reset raise and leaves the cache alone."""
    kv_cache_manager = Mock()
    # object.__new__ skips PyExecutor.__init__; only the attributes that
    # reset_prefix_cache reads are populated.
    executor = object.__new__(PyExecutor)
    executor.kv_cache_manager = kv_cache_manager
    executor.waiting_queue = []
    executor.active_requests = [Mock()]

    expected_message = "no active or queued requests"
    with pytest.raises(RuntimeError, match=expected_message):
        executor.reset_prefix_cache()

    kv_cache_manager.reset_reuse_state.assert_not_called()
157+
158+
159+
def test_reset_prefix_cache_rejects_queued_requests():
    """A queued request makes the reset raise and leaves the cache alone."""
    kv_cache_manager = Mock()
    # object.__new__ skips PyExecutor.__init__; only the attributes that
    # reset_prefix_cache reads are populated.
    executor = object.__new__(PyExecutor)
    executor.kv_cache_manager = kv_cache_manager
    executor.waiting_queue = [Mock()]
    executor.active_requests = []

    expected_message = "no active or queued requests"
    with pytest.raises(RuntimeError, match=expected_message):
        executor.reset_prefix_cache()

    kv_cache_manager.reset_reuse_state.assert_not_called()
169+
170+
136171
def test_clear_canceled_req_ids(mock_executor):
137172
"""Test clearing canceled request IDs."""
138173
mock_executor.canceled_req_ids = [1, 2, 3]

tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,9 @@
1313

1414
from tensorrt_llm import LLM
1515
from tensorrt_llm._torch.utils import get_device_uuid
16+
from tensorrt_llm.executor.ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
1617
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, SamplingParams
18+
from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension
1719

1820
# Ray-backed LLM teardown spawns the executor main-loop, GC and log/error
1921
# listener threads in ray-core. These are torn down only when ``ray.shutdown()``
@@ -26,6 +28,15 @@
2628
pytestmark = pytest.mark.threadleak(enabled=False)
2729

2830

31+
@pytest.mark.part0
def test_rlhf_worker_extension_uses_base_reset_prefix_cache():
    """The RLHF extension must not define its own reset_prefix_cache.

    After injecting the extension into ``RayGPUWorker``, the resulting
    class should still resolve ``reset_prefix_cache`` to the
    ``RayGPUWorker`` attribute.
    """
    extension_path = "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension"
    worker_cls = RayWorkerWrapper._inject_worker_extension(
        RayGPUWorker, extension_path)

    assert "reset_prefix_cache" not in vars(WorkerExtension)
    assert worker_cls.reset_prefix_cache is RayGPUWorker.reset_prefix_cache
38+
39+
2940
class RefHFModelWithIPCHandles(RefHFModel):
3041
def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4):
3142
self.device_id = device_id

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -387,6 +387,10 @@ methods:
387387
default: 2
388388
return_annotation: tensorrt_llm.executor.result.IterationResult
389389
status: beta
390+
reset_prefix_cache:
391+
parameters: {}
392+
return_annotation: None
393+
status: beta
390394
get_stats:
391395
parameters:
392396
timeout:

0 commit comments

Comments
 (0)