diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 57a6f57ddac9..0ddc01392772 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -5516,7 +5516,11 @@ def _handle_speculative_decoding( return new_target_inputs, num_accepted_tokens_device - def reset_prefix_cache(self): + def reset_prefix_cache(self) -> None: + """Invalidate local KV prefix-cache reuse state.""" + if self.active_requests or self.waiting_queue: + raise RuntimeError( + "reset_prefix_cache() requires no active or queued requests.") self.kv_cache_manager.reset_reuse_state() def _handle_guided_decoder_errors( diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 99bd28282e41..8ef6dadd8289 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -789,6 +789,16 @@ def wakeup(self, wakeup_tags: List[str]) -> None: materialize_with_tag(*tags) torch.cuda.synchronize() + def reset_prefix_cache(self) -> None: + """Invalidate local KV prefix-cache reuse state on PyTorch engines.""" + engine = self.engine + if engine is None or not hasattr(engine, "reset_prefix_cache"): + raise NotImplementedError( + "reset_prefix_cache() is only supported by the PyTorch backend." + ) + with engine.control_action(): + engine.reset_prefix_cache() + def shutdown(self): if self.doing_shutdown: return diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index aac9d8d7cf99..9beea48607ab 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -1777,6 +1777,38 @@ def _collective_rpc( f"Executor type {type(self._executor)} does not support collective RPC." ) + @set_api_status("beta") + def reset_prefix_cache(self) -> None: + """Reset local KV prefix-cache reuse state. + + This invalidates local prefix-cache metadata in the PyTorch backend. It + requires no active or queued requests, and it does not reset + connector-managed external or offloaded cache state. Callers should + quiesce traffic before invoking this method. + + Raises: + RuntimeError: If the LLM is encode-only or has no active executor. + NotImplementedError: If the executor does not support prefix-cache + reset. + """ + if self._encode_only: + raise RuntimeError("reset_prefix_cache() is not available when " + "encode_only=True.") + if self._executor is None: + raise RuntimeError("reset_prefix_cache() requires an active " + "executor.") + + if hasattr(self._executor, "collective_rpc"): + self._collective_rpc("reset_prefix_cache") + return + + reset_prefix_cache = getattr(self._executor, "reset_prefix_cache", None) + if reset_prefix_cache is None: + raise NotImplementedError( + "reset_prefix_cache() is only supported by the PyTorch backend." + ) + reset_prefix_cache() + def _build_model(self): super()._build_model() assert self._engine_dir is None diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py index 47135c765678..a43fb2043877 100644 --- a/tensorrt_llm/llmapi/rlhf_utils.py +++ b/tensorrt_llm/llmapi/rlhf_utils.py @@ -155,10 +155,6 @@ def update_weights(self, ipc_handles: Optional[dict] = None): logger.error("Encountered an error in update_weights") raise e - def reset_prefix_cache(self) -> None: - """Invalidate the KV cache prefix reuse state after weight updates.""" - self.engine.reset_prefix_cache() - @control_action_decorator def wait_for_engine_idle(self) -> None: """Block until the engine has no active or queued requests.""" diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 94de0a424f01..94cf8e76867e 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -12,6 +12,7 @@ from collections import deque from contextlib import asynccontextmanager from datetime import datetime +from functools import partial from http import HTTPStatus from pathlib import Path from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List, @@ -686,6 +687,9 @@ def register_routes(self): self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"]) + self.app.add_api_route("/reset_prefix_cache", + self.reset_prefix_cache, + methods=["POST"]) resource_governor_queue = self.generator._executor.resource_governor_queue if resource_governor_queue is not None: from .resource_governor import ResourceGovernor @@ -1047,6 +1051,27 @@ async def get_kv_cache_events(self) -> JSONResponse: pass return JSONResponse(content=events) + async def reset_prefix_cache(self) -> Response: + """Reset local PyTorch KV prefix-cache reuse state.""" + reset_fn = getattr(self.generator, "reset_prefix_cache", None) + if reset_fn is None: + return self._create_not_supported_error( + "reset_prefix_cache() is only supported by the PyTorch backend." + ) + + try: + await asyncio.get_running_loop().run_in_executor(None, reset_fn) + except NotImplementedError as e: + return self._create_not_supported_error(str(e)) + except (RuntimeError, ValueError) as e: + return self.create_error_response( + message=str(e), + err_type="InvalidRequestError", + status_code=HTTPStatus.CONFLICT, + ) + + return JSONResponse(content={"status": "success"}) + async def _extract_metrics(self, res: RequestOutput, raw_request: Request): if not res.finished: return @@ -1975,26 +2000,55 @@ async def openai_responses_delete_response( "deleted": True }) - async def release_memory(self, - request: MemoryUpdateRequest) -> JSONResponse: - assert isinstance( - self.generator, AsyncLLM - ), "/release_memory endpoint is only supported with AsyncLLM()" - await self.generator.collective_rpc('sleep', args=(request.tags, )) - return JSONResponse(content={"status": "success"}) + async def _run_worker_control_rpc(self, method: str, + args: tuple[Any, ...]) -> None: + """Dispatch a worker control method through direct or collective RPC.""" + executor = getattr(self.generator, "_executor", None) + direct_method = getattr(executor, method, None) + if direct_method is not None: + await asyncio.get_running_loop().run_in_executor( + None, partial(direct_method, *args)) + return + + if isinstance(self.generator, AsyncLLM): + await self.generator.collective_rpc(method, args=args) + return + + collective_rpc = getattr(self.generator, "_collective_rpc", None) + if collective_rpc is not None: + await asyncio.get_running_loop().run_in_executor( + None, partial(collective_rpc, method, args)) + return + + raise NotImplementedError( + f"{method}() is only supported by the PyTorch backend.") + + async def _handle_worker_control_rpc(self, method: str, + args: tuple[Any, ...]) -> Response: + """Run a worker control method and map failures to HTTP responses.""" + try: + await self._run_worker_control_rpc(method, args) + except NotImplementedError as e: + return self._create_not_supported_error(str(e)) + except (RuntimeError, ValueError) as e: + return self.create_error_response( + message=str(e), + err_type="InvalidRequestError", + status_code=HTTPStatus.CONFLICT, + ) - async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse: - assert isinstance( - self.generator, AsyncLLM - ), "/resume_memory endpoint is only supported with AsyncLLM()" - await self.generator.collective_rpc('wakeup', args=(request.tags, )) return JSONResponse(content={"status": "success"}) - async def update_weights(self, - request: UpdateWeightsRequest) -> JSONResponse: - assert isinstance( - self.generator, AsyncLLM - ), "/update_weights endpoint is only supported with AsyncLLM()" + async def release_memory(self, request: MemoryUpdateRequest) -> Response: + return await self._handle_worker_control_rpc('sleep', (request.tags, )) + + async def resume_memory(self, request: MemoryUpdateRequest) -> Response: + return await self._handle_worker_control_rpc('wakeup', (request.tags, )) + + async def update_weights(self, request: UpdateWeightsRequest) -> Response: + if not isinstance(self.generator, AsyncLLM): + return self._create_not_supported_error( + "/update_weights endpoint is only supported with AsyncLLM()") await self.generator.collective_rpc('update_weights', args=(request.weights, )) return JSONResponse(content={"status": "success"}) diff --git a/tests/unittest/_torch/executor/test_py_executor.py b/tests/unittest/_torch/executor/test_py_executor.py index 2e70be239cc6..4b101db6100d 100644 --- a/tests/unittest/_torch/executor/test_py_executor.py +++ b/tests/unittest/_torch/executor/test_py_executor.py @@ -133,6 +133,41 @@ def test_handle_special_queue_items(mock_executor): assert 2 in mock_executor.canceled_req_ids +def test_reset_prefix_cache_resets_when_idle(): + executor = object.__new__(PyExecutor) + executor.active_requests = [] + executor.waiting_queue = [] + executor.kv_cache_manager = Mock() + + executor.reset_prefix_cache() + + executor.kv_cache_manager.reset_reuse_state.assert_called_once_with() + + +def test_reset_prefix_cache_rejects_active_requests(): + executor = object.__new__(PyExecutor) + executor.active_requests = [Mock()] + executor.waiting_queue = [] + executor.kv_cache_manager = Mock() + + with pytest.raises(RuntimeError, match="no active or queued requests"): + executor.reset_prefix_cache() + + executor.kv_cache_manager.reset_reuse_state.assert_not_called() + + +def test_reset_prefix_cache_rejects_queued_requests(): + executor = object.__new__(PyExecutor) + executor.active_requests = [] + executor.waiting_queue = [Mock()] + executor.kv_cache_manager = Mock() + + with pytest.raises(RuntimeError, match="no active or queued requests"): + executor.reset_prefix_cache() + + executor.kv_cache_manager.reset_reuse_state.assert_not_called() + + def test_clear_canceled_req_ids(mock_executor): """Test clearing canceled request IDs.""" mock_executor.canceled_req_ids = [1, 2, 3] diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py index 09703d40cf52..fe656fb04470 100644 --- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py +++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py @@ -13,7 +13,9 @@ from tensorrt_llm import LLM from tensorrt_llm._torch.utils import get_device_uuid +from tensorrt_llm.executor.ray_gpu_worker import RayGPUWorker, RayWorkerWrapper from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, SamplingParams +from tensorrt_llm.llmapi.rlhf_utils import WorkerExtension # Ray-backed LLM teardown spawns the executor main-loop, GC and log/error # listener threads in ray-core. These are torn down only when ``ray.shutdown()`` @@ -26,6 +28,16 @@ pytestmark = pytest.mark.threadleak(enabled=False) +@pytest.mark.part0 +def test_rlhf_worker_extension_uses_base_reset_prefix_cache(): + worker_cls = RayWorkerWrapper._inject_worker_extension( + RayGPUWorker, "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension" + ) + + assert "reset_prefix_cache" not in WorkerExtension.__dict__ + assert worker_cls.reset_prefix_cache is RayGPUWorker.reset_prefix_cache + + class RefHFModelWithIPCHandles(RefHFModel): def __init__(self, model_dir: str, device_id: int = 0, num_hidden_layers: int = 4): self.device_id = device_id diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index cea76984c235..e3dabffe635a 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -401,6 +401,10 @@ methods: default: 2 return_annotation: tensorrt_llm.executor.result.IterationResult status: beta + reset_prefix_cache: + parameters: {} + return_annotation: None + status: beta get_stats: parameters: timeout: diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index fb174a99152d..fa4619f5ad71 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -51,7 +51,9 @@ from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode from tensorrt_llm.sampling_params import (BatchedLogitsProcessor, LogitsProcessor, SamplingParams) -from tensorrt_llm.serve.openai_protocol import CompletionRequest +from tensorrt_llm.serve.openai_protocol import (CompletionRequest, + MemoryUpdateRequest, + UpdateWeightsRequest) from tensorrt_llm.serve.openai_server import OpenAIServer from tensorrt_llm.serve.postprocess_handlers import (ChatPostprocArgs, chat_stream_post_processor) @@ -2708,6 +2710,182 @@ async def is_disconnected(self): return True +class _FakeResetExecutor: + + def __init__(self): + self.num_reset_calls = 0 + + def reset_prefix_cache(self): + self.num_reset_calls += 1 + + def shutdown(self): + pass + + +class _FakeCollectiveResetExecutor: + + def __init__(self): + self.calls = [] + + def collective_rpc(self, method, args, kwargs, non_block, unique_reply_rank, + target_ranks): + self.calls.append( + (method, args, kwargs, non_block, unique_reply_rank, target_ranks)) + return [None] + + def shutdown(self): + pass + + +class _FakeUnsupportedResetExecutor: + + def shutdown(self): + pass + + +class _FakeWorkerControlExecutor: + + def __init__(self): + self.calls = [] + + def sleep(self, tags): + self.calls.append(("sleep", tags)) + + def wakeup(self, tags): + self.calls.append(("wakeup", tags)) + + +class _FakeWorkerControlGenerator: + + def __init__(self): + self._executor = _FakeWorkerControlExecutor() + + +class _FakeCollectiveWorkerControlGenerator: + + def __init__(self): + self._executor = object() + self.calls = [] + + def _collective_rpc(self, method, args): + self.calls.append((method, args)) + + +class _FakeNotImplementedResetGenerator: + + def reset_prefix_cache(self): + raise NotImplementedError("not supported") + + +def test_llm_reset_prefix_cache_dispatches_to_executor() -> None: + llm = object.__new__(LLM_torch) + llm._encode_only = False + llm._executor = _FakeResetExecutor() + + llm.reset_prefix_cache() + + assert llm._executor.num_reset_calls == 1 + + +def test_llm_reset_prefix_cache_uses_collective_rpc() -> None: + llm = object.__new__(LLM_torch) + llm._encode_only = False + llm._executor = _FakeCollectiveResetExecutor() + + llm.reset_prefix_cache() + + assert llm._executor.calls == [("reset_prefix_cache", (), None, False, None, + None)] + + +def test_llm_reset_prefix_cache_rejects_encode_only() -> None: + llm = object.__new__(LLM_torch) + llm._encode_only = True + llm._executor = _FakeResetExecutor() + + with pytest.raises(RuntimeError, match="encode_only=True"): + llm.reset_prefix_cache() + + +def test_llm_reset_prefix_cache_rejects_unsupported_executor() -> None: + llm = object.__new__(LLM_torch) + llm._encode_only = False + llm._executor = _FakeUnsupportedResetExecutor() + + with pytest.raises(NotImplementedError, + match="only supported by the PyTorch backend"): + llm.reset_prefix_cache() + + +def test_openai_reset_prefix_cache_endpoint() -> None: + server = object.__new__(OpenAIServer) + server.generator = _FakeResetExecutor() + + response = asyncio.run(server.reset_prefix_cache()) + + assert response.status_code == 200 + assert server.generator.num_reset_calls == 1 + + +def test_openai_reset_prefix_cache_endpoint_rejects_unsupported_generator( +) -> None: + server = object.__new__(OpenAIServer) + server.generator = object() + + response = asyncio.run(server.reset_prefix_cache()) + + assert response.status_code == 501 + + +def test_openai_reset_prefix_cache_endpoint_maps_not_implemented() -> None: + server = object.__new__(OpenAIServer) + server.generator = _FakeNotImplementedResetGenerator() + + response = asyncio.run(server.reset_prefix_cache()) + + assert response.status_code == 501 + + +def test_openai_memory_endpoints_dispatch_to_non_ray_executor() -> None: + server = object.__new__(OpenAIServer) + server.generator = _FakeWorkerControlGenerator() + + release_response = asyncio.run( + server.release_memory(MemoryUpdateRequest(tags=["kv_cache"]))) + resume_response = asyncio.run( + server.resume_memory(MemoryUpdateRequest(tags=["kv_cache"]))) + + assert release_response.status_code == 200 + assert resume_response.status_code == 200 + assert server.generator._executor.calls == [("sleep", ["kv_cache"]), + ("wakeup", ["kv_cache"])] + + +def test_openai_memory_endpoints_dispatch_to_non_ray_collective_rpc() -> None: + server = object.__new__(OpenAIServer) + server.generator = _FakeCollectiveWorkerControlGenerator() + + release_response = asyncio.run( + server.release_memory(MemoryUpdateRequest(tags=["kv_cache"]))) + resume_response = asyncio.run( + server.resume_memory(MemoryUpdateRequest(tags=["kv_cache"]))) + + assert release_response.status_code == 200 + assert resume_response.status_code == 200 + assert server.generator.calls == [("sleep", (["kv_cache"], )), + ("wakeup", (["kv_cache"], ))] + + +def test_openai_update_weights_rejects_non_ray_generator() -> None: + server = object.__new__(OpenAIServer) + server.generator = _FakeWorkerControlGenerator() + + response = asyncio.run( + server.update_weights(UpdateWeightsRequest(weights=None))) + + assert response.status_code == 501 + + def test_openai_completion_list_prompt_stream_reuses_stream_metadata() -> None: async def run_request():