|
12 | 12 | from collections import deque |
13 | 13 | from contextlib import asynccontextmanager |
14 | 14 | from datetime import datetime |
| 15 | +from functools import partial |
15 | 16 | from http import HTTPStatus |
16 | 17 | from pathlib import Path |
17 | 18 | from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List, |
@@ -666,6 +667,9 @@ def register_routes(self): |
666 | 667 | self.app.add_api_route("/kv_cache_events", |
667 | 668 | self.get_kv_cache_events, |
668 | 669 | methods=["POST"]) |
| 670 | + self.app.add_api_route("/reset_prefix_cache", |
| 671 | + self.reset_prefix_cache, |
| 672 | + methods=["POST"]) |
669 | 673 | resource_governor_queue = self.generator._executor.resource_governor_queue |
670 | 674 | if resource_governor_queue is not None: |
671 | 675 | from .resource_governor import ResourceGovernor |
@@ -1027,6 +1031,28 @@ async def get_kv_cache_events(self) -> JSONResponse: |
1027 | 1031 | pass |
1028 | 1032 | return JSONResponse(content=events) |
1029 | 1033 |
|
| 1034 | + async def reset_prefix_cache(self) -> Response: |
| 1035 | + """Reset local PyTorch KV prefix-cache reuse state.""" |
| 1036 | + reset_fn = getattr(self.generator, "reset_prefix_cache", None) |
| 1037 | + if reset_fn is None: |
| 1038 | + return self._create_not_supported_error( |
| 1039 | + "reset_prefix_cache() is only supported by the PyTorch backend." |
| 1040 | + ) |
| 1041 | + |
| 1042 | + try: |
| 1043 | + await asyncio.get_running_loop().run_in_executor( |
| 1044 | + None, reset_fn) |
| 1045 | + except NotImplementedError as e: |
| 1046 | + return self._create_not_supported_error(str(e)) |
| 1047 | + except (RuntimeError, ValueError) as e: |
| 1048 | + return self.create_error_response( |
| 1049 | + message=str(e), |
| 1050 | + err_type="InvalidRequestError", |
| 1051 | + status_code=HTTPStatus.CONFLICT, |
| 1052 | + ) |
| 1053 | + |
| 1054 | + return JSONResponse(content={"status": "success"}) |
| 1055 | + |
1030 | 1056 | async def _extract_metrics(self, res: RequestOutput, raw_request: Request): |
1031 | 1057 | if not res.finished: |
1032 | 1058 | return |
@@ -1952,26 +1978,58 @@ async def openai_responses_delete_response( |
1952 | 1978 | "deleted": True |
1953 | 1979 | }) |
1954 | 1980 |
|
1955 | | - async def release_memory(self, |
1956 | | - request: MemoryUpdateRequest) -> JSONResponse: |
1957 | | - assert isinstance( |
1958 | | - self.generator, AsyncLLM |
1959 | | - ), "/release_memory endpoint is only supported with AsyncLLM()" |
1960 | | - await self.generator.collective_rpc('sleep', args=(request.tags, )) |
1961 | | - return JSONResponse(content={"status": "success"}) |
| 1981 | + async def _run_worker_control_rpc(self, method: str, |
| 1982 | + args: tuple[Any, ...]) -> None: |
| 1983 | + """Dispatch a worker control method through direct or collective RPC.""" |
| 1984 | + executor = getattr(self.generator, "_executor", None) |
| 1985 | + direct_method = getattr(executor, method, None) |
| 1986 | + if direct_method is not None: |
| 1987 | + await asyncio.get_running_loop().run_in_executor( |
| 1988 | + None, partial(direct_method, *args)) |
| 1989 | + return |
| 1990 | + |
| 1991 | + if isinstance(self.generator, AsyncLLM): |
| 1992 | + await self.generator.collective_rpc(method, args=args) |
| 1993 | + return |
| 1994 | + |
| 1995 | + collective_rpc = getattr(self.generator, "_collective_rpc", None) |
| 1996 | + if collective_rpc is not None: |
| 1997 | + await asyncio.get_running_loop().run_in_executor( |
| 1998 | + None, partial(collective_rpc, method, args)) |
| 1999 | + return |
| 2000 | + |
| 2001 | + raise NotImplementedError( |
| 2002 | + f"{method}() is only supported by the PyTorch backend.") |
| 2003 | + |
| 2004 | + async def _handle_worker_control_rpc(self, method: str, |
| 2005 | + args: tuple[Any, ...]) -> Response: |
| 2006 | + """Run a worker control method and map failures to HTTP responses.""" |
| 2007 | + try: |
| 2008 | + await self._run_worker_control_rpc(method, args) |
| 2009 | + except NotImplementedError as e: |
| 2010 | + return self._create_not_supported_error(str(e)) |
| 2011 | + except (RuntimeError, ValueError) as e: |
| 2012 | + return self.create_error_response( |
| 2013 | + message=str(e), |
| 2014 | + err_type="InvalidRequestError", |
| 2015 | + status_code=HTTPStatus.CONFLICT, |
| 2016 | + ) |
1962 | 2017 |
|
1963 | | - async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse: |
1964 | | - assert isinstance( |
1965 | | - self.generator, AsyncLLM |
1966 | | - ), "/resume_memory endpoint is only supported with AsyncLLM()" |
1967 | | - await self.generator.collective_rpc('wakeup', args=(request.tags, )) |
1968 | 2018 | return JSONResponse(content={"status": "success"}) |
1969 | 2019 |
|
| 2020 | + async def release_memory(self, request: MemoryUpdateRequest) -> Response: |
| 2021 | + return await self._handle_worker_control_rpc('sleep', |
| 2022 | + (request.tags, )) |
| 2023 | + |
| 2024 | + async def resume_memory(self, request: MemoryUpdateRequest) -> Response: |
| 2025 | + return await self._handle_worker_control_rpc('wakeup', |
| 2026 | + (request.tags, )) |
| 2027 | + |
1970 | 2028 | async def update_weights(self, |
1971 | | - request: UpdateWeightsRequest) -> JSONResponse: |
1972 | | - assert isinstance( |
1973 | | - self.generator, AsyncLLM |
1974 | | - ), "/update_weights endpoint is only supported with AsyncLLM()" |
| 2029 | + request: UpdateWeightsRequest) -> Response: |
| 2030 | + if not isinstance(self.generator, AsyncLLM): |
| 2031 | + return self._create_not_supported_error( |
| 2032 | + "/update_weights endpoint is only supported with AsyncLLM()") |
1975 | 2033 | await self.generator.collective_rpc('update_weights', |
1976 | 2034 | args=(request.weights, )) |
1977 | 2035 | return JSONResponse(content={"status": "success"}) |
|
0 commit comments