diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 27848de026..e2a7495624 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -443,9 +443,9 @@ def update_params(self, request: Any):
         """Update params."""
         self.executor.update_params(request)
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
-        self.executor.sleep(level)
+        await self.executor.sleep(level)
 
     def wakeup(self, tags: list[str] | None = None):
         """Wakeup."""
diff --git a/lmdeploy/pytorch/engine/executor/mp_executor.py b/lmdeploy/pytorch/engine/executor/mp_executor.py
index 9f457eec3c..4a3d60531c 100644
--- a/lmdeploy/pytorch/engine/executor/mp_executor.py
+++ b/lmdeploy/pytorch/engine/executor/mp_executor.py
@@ -373,6 +373,14 @@ def warmup(self):
         """Build cache engine."""
         self.collective_rpc('warmup')
 
+    async def sleep(self, level: int = 1):
+        """Sleep."""
+        await self.collective_rpc_async('sleep', args=(level, ), return_mask=0)
+
+    def wakeup(self, tags: list[str] | None = None):
+        """Wakeup."""
+        self.collective_rpc('wakeup', args=(tags, ), return_mask=0)
+
     async def _prefetch_outputs(self):
         while True:
             out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]
diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py
index d585e78c15..c20adbfd8d 100644
--- a/lmdeploy/pytorch/engine/executor/ray_executor.py
+++ b/lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -321,6 +321,18 @@ def collective_rpc(self,
             kwargs = dict()
         return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)
 
+    async def collective_rpc_async(self,
+                                   method: str,
+                                   args: tuple[Any] = None,
+                                   kwargs: dict[str, Any] = None):
+        """Collective async rpc."""
+        if args is None:
+            args = list()
+        if kwargs is None:
+            kwargs = dict()
+        tasks = [getattr(worker, method).remote(*args, **kwargs) for worker in self.workers]
+        return await asyncio.gather(*tasks)
+
     def build_model(self):
         """Build model."""
         self.collective_rpc('build_model')
@@ -353,9 +365,9 @@ def warmup(self):
         """Build cache engine."""
         self.collective_rpc('warmup')
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
-        self.collective_rpc('sleep', (level, ))
+        await self.collective_rpc_async('sleep', (level, ))
 
     def wakeup(self, tags: list[str] | None = None):
         """Wakeup."""
diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py
index 34c7412ee6..e4d292bd9c 100644
--- a/lmdeploy/pytorch/engine/executor/uni_executor.py
+++ b/lmdeploy/pytorch/engine/executor/uni_executor.py
@@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0):
         assert dp_rank == 0
         return await self.model_agent.get_output_async()
 
+    async def sleep(self, level: int = 1):
+        """Sleep."""
+        await self.model_agent.sleep(level)
+
+    def wakeup(self, tags: list[str] | None = None):
+        """Wakeup."""
+        self.model_agent.wakeup(tags)
+
     def get_input_processor(self):
         """Get input processor."""
         return self.model_agent.get_input_processor()
diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py
index c0352787b1..2dfe423ed7 100644
--- a/lmdeploy/pytorch/engine/mp_engine/base.py
+++ b/lmdeploy/pytorch/engine/mp_engine/base.py
@@ -53,9 +53,9 @@ def end_session(self, session_id: int):
         """End session."""
         return self._collective_rpc('end_session', session_id)
 
-    def sleep(self, level: int):
+    async def sleep(self, level: int):
         """sleep."""
-        return self._collective_rpc('sleep', level)
+        return await self._collective_rpc_async('sleep', level)
 
     def wakeup(self, tags: list[str] | None = None):
         """Wakeup."""
diff --git a/lmdeploy/pytorch/engine/mp_engine/base_worker.py b/lmdeploy/pytorch/engine/mp_engine/base_worker.py
index 0e0fa0fa82..bc2076863a 100644
--- a/lmdeploy/pytorch/engine/mp_engine/base_worker.py
+++ b/lmdeploy/pytorch/engine/mp_engine/base_worker.py
@@ -100,9 +100,9 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
         """
         return self.engine.p2p_drop_connect(drop_conn_request)
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """sleep."""
-        return self.engine.sleep(level)
+        return await self.engine.sleep(level)
 
     def wakeup(self, tags: list[str] | None = None):
         """Wakeup."""
diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py
index c5dfcd0364..a259fbdd90 100644
--- a/lmdeploy/serve/core/async_engine.py
+++ b/lmdeploy/serve/core/async_engine.py
@@ -46,7 +46,7 @@ class GenOut:
     history_token_len: int
     input_token_len: int
     generate_token_len: int
-    finish_reason: Literal['stop', 'length', 'error'] | None = None
+    finish_reason: Literal['stop', 'length', 'error', 'abort'] | None = None
     token_ids: list[int] | None = None
     logprobs: list[dict[int, float]] | None = None
     logits: Any = None
@@ -201,6 +201,23 @@ def _build_stat_loggers(self):
         # set stats loggers of metrics processor
         metrics_processor.stat_loggers = self.stat_loggers
 
+    def _if_session_stale(self, session: Session,
+                          input_token_len: int) -> GenOut | None:
+        """If ``session.epoch`` was stamped by api_server and
+        ``stop_all_session`` ran since then (the engine epoch changed), drop
+        the session."""
+        epoch = session.epoch
+        if epoch is None or epoch == self.epoch:
+            return None
+        logger.info(f'[generate] drop stale session {session.session_id} '
+                    f'(session.epoch={epoch}, async_engine.epoch={self.epoch})')
+        return GenOut(response='',
+                      history_token_len=session.step,
+                      input_token_len=input_token_len,
+                      generate_token_len=0,
+                      finish_reason='abort',
+                      token_ids=[])
+
     async def get_schedule_metrics(self):
         result = self.engine.get_schedule_metrics()
         if asyncio.iscoroutine(result):
@@ -215,11 +232,16 @@ async def do_log_stats(self):
 
     async def stop_all_session(self):
         """Stop all running sessions."""
-        logger.info('stop all sessions')
+        logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
         self.epoch += 1
         await self.session_mgr.async_abort_all()
 
-    def sleep(self, level: int = 1):
+    def prepare_sleep(self):
+        """Reject new inference requests before backend sleep starts."""
+        self.sleeping_tags = {'weights', 'kv_cache'}
+        self.is_sleeping = True
+
+    async def sleep(self, level: int = 1):
         """Sleep the model.
 
         Args:
@@ -227,7 +249,7 @@
                 weights and discard the kv cache. Level 2 sleep will discard both
                 the model weights and the kv cache.
         """
-        self.engine.sleep(level)
+        await self.engine.sleep(level)
         self.sleeping_tags = {'weights', 'kv_cache'}
         self.is_sleeping = True
 
@@ -342,7 +364,8 @@ async def generate(
             do_preprocess (bool): whether pre-process the messages. Default to
                 True, which means chat_template will be applied.
""" - epoch = self.epoch + metrics_processor.increase_total_requests() + if (messages is not None) ^ (input_ids is None): raise ValueError('You must specify exactly one of messages or input_ids') if isinstance(session_id, Session): @@ -389,6 +412,7 @@ async def generate( if gen_config.max_new_tokens == 0: logger.info(f'run out of tokens. session={session_id}.') + metrics_processor.increase_failed_requests('error') yield GenOut(response='', history_token_len=session.step, input_token_len=len(input_ids), @@ -403,6 +427,7 @@ async def generate( or gen_config.output_logits == 'all'): errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state ' 'when prefix caching is ON') + metrics_processor.increase_failed_requests('error') yield GenOut(response=errmsg, history_token_len=session.step, input_token_len=len(input_ids), @@ -424,10 +449,18 @@ def is_error(status): if not gen_config.ignore_eos: stop_ids = gen_config.stop_token_ids or [] - metrics_processor.increase_total_requests() + + stale = self._if_session_stale(session, len(prompt_input['input_ids'])) + if stale is not None: + metrics_processor.increase_failed_requests('abort') + yield stale + if sequence_end: + self.session_mgr.remove(session) + return async with session.request_handle() as handle: - if epoch != self.epoch: - logger.info(f'[generate] session {session_id} got aborted before starting inference') + if session.epoch is not None and session.epoch != self.epoch: + logger.info(f'[generate] session {session_id} got aborted before starting inference, ' + f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}') metrics_processor.increase_failed_requests('abort') yield GenOut(response='', history_token_len=0, diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py index 0ac7e1465f..685631091f 100644 --- a/lmdeploy/serve/managers/session_manager.py +++ b/lmdeploy/serve/managers/session_manager.py @@ -24,6 +24,9 @@ def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs): self.history: list[tuple[Any, str]] = [] self.gen_config: GenerationConfig | None = None self.step: int = 0 + # Set by api_server to AsyncEngine.epoch when a request binds a session; + # generate() drops work if stop_all_session() bumped epoch after bind. + self.epoch: int | None = None # event to wait for the session to be active self._active: asyncio.Event | None = None self._handle = None # inference instance @@ -64,6 +67,7 @@ def reset(self): self.history = [] self.gen_config = None self.step = 0 + self.epoch = None self._active = None self._handle = None self._session_mgr = None @@ -101,7 +105,7 @@ async def request_handle(self): async def async_abort(self): """Abort the session.""" - logger.info(f'[session] Aborting session {self.session_id}') + logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}') if self._handle is not None: await self._handle.async_cancel(self.session_id) @@ -205,13 +209,14 @@ def get(self, session_id: int | None = None, **kwargs) -> Session: session.update(**kwargs) return session else: - logger.info(f'[SessionManager] session {session_id} not found. Creating...') + logger.debug(f'[SessionManager] session {session_id} not found. 
             session = Session(session_id, self, **kwargs)
             self.sessions[session_id] = session
             return session
 
     async def async_abort_all(self):
         """Abort all sessions."""
+        logger.info(f'[SessionManager] aborting all {len(self.sessions)} sessions')
         tasks = []
         for session in list(self.sessions.values()):
             tasks.append(session.async_abort())
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 2c552febd0..4e93acc9f3 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import annotations
+
 # yapf: disable
 import asyncio
 import copy
@@ -10,7 +12,7 @@
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
 import uvicorn
 from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
@@ -76,10 +78,13 @@
 )
 from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
 from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
-from lmdeploy.serve.utils.server_utils import validate_json_request
+from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware, EngineSleepingMiddleware, validate_json_request
 from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 from lmdeploy.utils import get_logger
 
+if TYPE_CHECKING:
+    from lmdeploy.serve.managers import Session
+
 # yapf: enable
 logger = get_logger('lmdeploy')
 
@@ -100,12 +105,15 @@ class VariableInterface:
     enable_abort_handling: bool = False
 
     @staticmethod
-    def get_session(session_id: int) -> int:
+    def get_session(session_id: int) -> Session:
         session_mgr = VariableInterface.get_session_manager()
         if session_id == -1:
-            return session_mgr.get()
+            session = session_mgr.get()
         else:
-            return session_mgr.get(session_id)
+            session = session_mgr.get(session_id)
+        # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``.
+        session.epoch = VariableInterface.async_engine.epoch
+        return session
 
     @staticmethod
     def get_session_manager():
@@ -769,7 +777,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
     error_check_ret = check_request(request)
     if error_check_ret is not None:
         return error_check_ret
-    json_request = await raw_request.json()
     migration_request = json_request.pop('migration_request', None)
     with_cache = json_request.pop('with_cache', False)
 
@@ -963,6 +970,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
     error_check_ret = check_request(request)
     if error_check_ret is not None:
         return error_check_ret
+    session = VariableInterface.get_session(request.session_id)
 
     prompt = request.prompt
 
@@ -1175,7 +1183,16 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
 @router.post('/sleep', dependencies=[Depends(validate_json_request)])
 async def sleep(raw_request: Request = None):
     level = raw_request.query_params.get('level', '1')
-    VariableInterface.async_engine.sleep(int(level))
+    try:
+        level = int(level)
+    except (TypeError, ValueError):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be an integer.')
+    if level not in (1, 2):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.')
+    async_engine = VariableInterface.async_engine
+    async_engine.prepare_sleep()
+    await async_engine.stop_all_session()
+    await async_engine.sleep(level)
     return Response(status_code=200)
 
 
@@ -1526,10 +1543,13 @@ def serve(model_path: str,
     )
 
     if api_keys is not None and (tokens := [key for key in api_keys if key]):
-        from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware
-
         app.add_middleware(AuthenticationMiddleware, tokens=tokens)
 
+    def is_engine_sleeping() -> bool:
+        eng = VariableInterface.async_engine
+        return eng is not None and eng.is_sleeping
+    app.add_middleware(EngineSleepingMiddleware, is_sleeping=is_engine_sleeping)
+
     # set the maximum number of concurrent requests
     if max_concurrent_requests is not None:
         app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)
diff --git a/lmdeploy/serve/utils/server_utils.py b/lmdeploy/serve/utils/server_utils.py
index f7bcbdfa49..f032b2bc14 100644
--- a/lmdeploy/serve/utils/server_utils.py
+++ b/lmdeploy/serve/utils/server_utils.py
@@ -2,7 +2,8 @@
 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/server_utils.py
 import hashlib
 import secrets
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
+from http import HTTPStatus
 
 from fastapi import Request
 from fastapi.exceptions import RequestValidationError
@@ -18,6 +19,58 @@ def validate_json_request(raw_request: Request):
         raise RequestValidationError(errors=["Unsupported Media Type: Only 'application/json' is allowed"])
 
 
+class EngineSleepingMiddleware:
+    """Pure ASGI middleware that returns 503 for configured inference routes
+    when ``is_sleeping()`` is true (after ``POST /sleep``, until ``POST
+    /wakeup``).
+
+    Notes
+    -----
+    - Non-``http`` scopes (``websocket`` included) are passed through to the
+      app untouched; only ``http`` requests are gated.
+    - HTTP ``OPTIONS`` is passed through so CORS preflight is unaffected.
+    """
+
+    # POST routes rejected while sleeping (see POST /sleep, /wakeup).
+    DEFAULT_PROTECTED_INFERENCE_ROUTES = frozenset({
+        ('POST', '/v1/chat/completions'),
+        ('POST', '/v1/completions'),
+        ('POST', '/generate'),
+    })
+
+    def __init__(
+        self,
+        app: ASGIApp,
+        is_sleeping: Callable[[], bool],
+        protected_routes: frozenset[tuple[str, str]] | None = None,
+    ) -> None:
+        self.app = app
+        self.is_sleeping = is_sleeping
+        self.protected_routes = protected_routes or type(self).DEFAULT_PROTECTED_INFERENCE_ROUTES
+
+    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
+        if scope['type'] not in ('http', 'websocket'):
+            return self.app(scope, receive, send)
+        if scope['type'] == 'http' and scope['method'] == 'OPTIONS':
+            return self.app(scope, receive, send)
+        if scope['type'] == 'http':
+            root_path = scope.get('root_path', '')
+            url_path = URL(scope=scope).path.removeprefix(root_path)
+            key = (scope['method'], url_path)
+            if key in self.protected_routes and self.is_sleeping():
+                response = JSONResponse(
+                    content={
+                        'error': (
+                            'Engine is sleeping; call POST /wakeup before inference '
+                            '(e.g. tags=weights&tags=kv_cache).'
+                        ),
+                    },
+                    status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+                )
+                return response(scope, receive, send)
+        return self.app(scope, receive, send)
+
+
 class AuthenticationMiddleware:
     """Pure ASGI middleware that authenticates each request by checking if the
     Authorization Bearer token exists and equals anyof "{api_key}".
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index f95b2b93ca..5271262f7b 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -285,7 +285,7 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig):
         self._tm_model = tm_model
         return model_comm
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep the model."""
         with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
             for _ in e.map(self.model_comm.sleep, range(self.gpu_count), [level] * self.gpu_count):
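Not part of the diff itself: a minimal client-side sketch of the sleep/wakeup flow this change enables. It assumes an api_server running on lmdeploy's default port 23333 and that the existing POST /wakeup route accepts repeated `tags` query parameters, as hinted by the middleware's error message; the model name and payload are made-up examples.

# Hypothetical usage sketch, not part of the change; port, model name and payload
# are illustrative assumptions.
import requests

BASE = 'http://localhost:23333'

# Put the engine to sleep: the handler now validates `level`, flags the engine as
# sleeping via prepare_sleep(), aborts running sessions, then awaits the backend's
# async sleep().
assert requests.post(f'{BASE}/sleep', params={'level': 1}).status_code == 200

# While sleeping, EngineSleepingMiddleware answers protected inference routes with
# 503 before they ever reach the engine.
resp = requests.post(f'{BASE}/v1/completions', json={'model': 'internlm2', 'prompt': 'hi'})
print(resp.status_code)  # expected: 503

# Wake the engine up again (both weights and kv_cache).
requests.post(f'{BASE}/wakeup', params=[('tags', 'weights'), ('tags', 'kv_cache')])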