Skip to content
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5516,7 +5516,11 @@ def _handle_speculative_decoding(

return new_target_inputs, num_accepted_tokens_device

def reset_prefix_cache(self):
def reset_prefix_cache(self) -> None:
"""Invalidate local KV prefix-cache reuse state."""
if self.active_requests or self.waiting_queue:
raise RuntimeError(
"reset_prefix_cache() requires no active or queued requests.")
self.kv_cache_manager.reset_reuse_state()
Comment thread
milesial marked this conversation as resolved.

def _handle_guided_decoder_errors(
Expand Down
10 changes: 10 additions & 0 deletions tensorrt_llm/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,16 @@ def wakeup(self, wakeup_tags: List[str]) -> None:
materialize_with_tag(*tags)
torch.cuda.synchronize()

def reset_prefix_cache(self) -> None:
"""Invalidate local KV prefix-cache reuse state on PyTorch engines."""
engine = self.engine
if engine is None or not hasattr(engine, "reset_prefix_cache"):
raise NotImplementedError(
"reset_prefix_cache() is only supported by the PyTorch backend."
)
with engine.control_action():
engine.reset_prefix_cache()

def shutdown(self):
if self.doing_shutdown:
return
Expand Down
32 changes: 32 additions & 0 deletions tensorrt_llm/llmapi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,38 @@ def _collective_rpc(
f"Executor type {type(self._executor)} does not support collective RPC."
)

@set_api_status("beta")
def reset_prefix_cache(self) -> None:
"""Reset local KV prefix-cache reuse state.

This invalidates local prefix-cache metadata in the PyTorch backend. It
requires no active or queued requests, and it does not reset
connector-managed external or offloaded cache state. Callers should
quiesce traffic before invoking this method.

Raises:
RuntimeError: If the LLM is encode-only or has no active executor.
NotImplementedError: If the executor does not support prefix-cache
reset.
"""
if self._encode_only:
raise RuntimeError("reset_prefix_cache() is not available when "
"encode_only=True.")
if self._executor is None:
raise RuntimeError("reset_prefix_cache() requires an active "
"executor.")

if hasattr(self._executor, "collective_rpc"):
self._collective_rpc("reset_prefix_cache")
return

reset_prefix_cache = getattr(self._executor, "reset_prefix_cache", None)
if reset_prefix_cache is None:
raise NotImplementedError(
"reset_prefix_cache() is only supported by the PyTorch backend."
)
reset_prefix_cache()

def _build_model(self):
super()._build_model()
assert self._engine_dir is None
Expand Down
4 changes: 0 additions & 4 deletions tensorrt_llm/llmapi/rlhf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,6 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
logger.error("Encountered an error in update_weights")
raise e

def reset_prefix_cache(self) -> None:
"""Invalidate the KV cache prefix reuse state after weight updates."""
self.engine.reset_prefix_cache()

@control_action_decorator
def wait_for_engine_idle(self) -> None:
"""Block until the engine has no active or queued requests."""
Expand Down
88 changes: 71 additions & 17 deletions tensorrt_llm/serve/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import deque
from contextlib import asynccontextmanager
from datetime import datetime
from functools import partial
from http import HTTPStatus
from pathlib import Path
from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List,
Expand Down Expand Up @@ -686,6 +687,9 @@ def register_routes(self):
self.app.add_api_route("/kv_cache_events",
self.get_kv_cache_events,
methods=["POST"])
self.app.add_api_route("/reset_prefix_cache",
self.reset_prefix_cache,
methods=["POST"])
resource_governor_queue = self.generator._executor.resource_governor_queue
if resource_governor_queue is not None:
from .resource_governor import ResourceGovernor
Expand Down Expand Up @@ -1047,6 +1051,27 @@ async def get_kv_cache_events(self) -> JSONResponse:
pass
return JSONResponse(content=events)

async def reset_prefix_cache(self) -> Response:
"""Reset local PyTorch KV prefix-cache reuse state."""
reset_fn = getattr(self.generator, "reset_prefix_cache", None)
if reset_fn is None:
return self._create_not_supported_error(
"reset_prefix_cache() is only supported by the PyTorch backend."
)

try:
await asyncio.get_running_loop().run_in_executor(None, reset_fn)
except NotImplementedError as e:
return self._create_not_supported_error(str(e))
except (RuntimeError, ValueError) as e:
return self.create_error_response(
message=str(e),
err_type="InvalidRequestError",
status_code=HTTPStatus.CONFLICT,
)

return JSONResponse(content={"status": "success"})

async def _extract_metrics(self, res: RequestOutput, raw_request: Request):
if not res.finished:
return
Expand Down Expand Up @@ -1975,26 +2000,55 @@ async def openai_responses_delete_response(
"deleted": True
})

async def release_memory(self,
request: MemoryUpdateRequest) -> JSONResponse:
assert isinstance(
self.generator, AsyncLLM
), "/release_memory endpoint is only supported with AsyncLLM()"
await self.generator.collective_rpc('sleep', args=(request.tags, ))
return JSONResponse(content={"status": "success"})
async def _run_worker_control_rpc(self, method: str,
args: tuple[Any, ...]) -> None:
"""Dispatch a worker control method through direct or collective RPC."""
executor = getattr(self.generator, "_executor", None)
direct_method = getattr(executor, method, None)
if direct_method is not None:
await asyncio.get_running_loop().run_in_executor(
None, partial(direct_method, *args))
return

if isinstance(self.generator, AsyncLLM):
await self.generator.collective_rpc(method, args=args)
return

collective_rpc = getattr(self.generator, "_collective_rpc", None)
if collective_rpc is not None:
await asyncio.get_running_loop().run_in_executor(
None, partial(collective_rpc, method, args))
return

raise NotImplementedError(
f"{method}() is only supported by the PyTorch backend.")

async def _handle_worker_control_rpc(self, method: str,
args: tuple[Any, ...]) -> Response:
"""Run a worker control method and map failures to HTTP responses."""
try:
await self._run_worker_control_rpc(method, args)
except NotImplementedError as e:
return self._create_not_supported_error(str(e))
except (RuntimeError, ValueError) as e:
return self.create_error_response(
message=str(e),
err_type="InvalidRequestError",
status_code=HTTPStatus.CONFLICT,
)

async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
assert isinstance(
self.generator, AsyncLLM
), "/resume_memory endpoint is only supported with AsyncLLM()"
await self.generator.collective_rpc('wakeup', args=(request.tags, ))
return JSONResponse(content={"status": "success"})

async def update_weights(self,
request: UpdateWeightsRequest) -> JSONResponse:
assert isinstance(
self.generator, AsyncLLM
), "/update_weights endpoint is only supported with AsyncLLM()"
async def release_memory(self, request: MemoryUpdateRequest) -> Response:
return await self._handle_worker_control_rpc('sleep', (request.tags, ))

async def resume_memory(self, request: MemoryUpdateRequest) -> Response:
return await self._handle_worker_control_rpc('wakeup', (request.tags, ))

async def update_weights(self, request: UpdateWeightsRequest) -> Response:
if not isinstance(self.generator, AsyncLLM):
return self._create_not_supported_error(
"/update_weights endpoint is only supported with AsyncLLM()")
await self.generator.collective_rpc('update_weights',
args=(request.weights, ))
return JSONResponse(content={"status": "success"})
Expand Down
35 changes: 35 additions & 0 deletions tests/unittest/_torch/executor/test_py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,41 @@ def test_handle_special_queue_items(mock_executor):
assert 2 in mock_executor.canceled_req_ids


def test_reset_prefix_cache_resets_when_idle():
executor = object.__new__(PyExecutor)
executor.active_requests = []
executor.waiting_queue = []
executor.kv_cache_manager = Mock()

executor.reset_prefix_cache()

executor.kv_cache_manager.reset_reuse_state.assert_called_once_with()


def test_reset_prefix_cache_rejects_active_requests():
executor = object.__new__(PyExecutor)
executor.active_requests = [Mock()]
executor.waiting_queue = []
executor.kv_cache_manager = Mock()

with pytest.raises(RuntimeError, match="no active or queued requests"):
executor.reset_prefix_cache()

executor.kv_cache_manager.reset_reuse_state.assert_not_called()


def test_reset_prefix_cache_rejects_queued_requests():
executor = object.__new__(PyExecutor)
executor.active_requests = []
executor.waiting_queue = [Mock()]
executor.kv_cache_manager = Mock()

with pytest.raises(RuntimeError, match="no active or queued requests"):
executor.reset_prefix_cache()

executor.kv_cache_manager.reset_reuse_state.assert_not_called()


def test_clear_canceled_req_ids(mock_executor):
"""Test clearing canceled request IDs."""
mock_executor.canceled_req_ids = [1, 2, 3]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.executor.ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, SamplingParams
from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension

# Ray-backed LLM teardown spawns the executor main-loop, GC and log/error
# listener threads in ray-core. These are torn down only when ``ray.shutdown()``
Expand All @@ -26,6 +28,16 @@
pytestmark = pytest.mark.threadleak(enabled=False)


@pytest.mark.part0
def test_rlhf_worker_extension_uses_base_reset_prefix_cache():
worker_cls = RayWorkerWrapper._inject_worker_extension(
RayGPUWorker, "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension"
)

assert "reset_prefix_cache" not in WorkerExtension.__dict__
assert worker_cls.reset_prefix_cache is RayGPUWorker.reset_prefix_cache


class RefHFModelWithIPCHandles(RefHFModel):
def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4):
self.device_id = device_id
Expand Down
4 changes: 4 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ methods:
default: 2
return_annotation: tensorrt_llm.executor.result.IterationResult
status: beta
reset_prefix_cache:
parameters: {}
return_annotation: None
status: beta
get_stats:
parameters:
timeout:
Expand Down
Loading
Loading