Skip to content

Commit f7f7546

Browse files
authored
Reject requests on stale session or sleeping engine (#4496)
* bind epoch to session
* reject requests when engine sleeps
* implement EngineSleepingMiddleware
* validate sleep request
* fix race window
* async sleep and sync wakeup
* change to async sleep in turbomind
1 parent 9f33332 commit f7f7546

11 files changed

Lines changed: 168 additions & 29 deletions

File tree

lmdeploy/pytorch/engine/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,9 @@ def update_params(self, request: Any):
443443
"""Update params."""
444444
self.executor.update_params(request)
445445

446-
def sleep(self, level: int = 1):
446+
async def sleep(self, level: int = 1):
447447
"""Sleep."""
448-
self.executor.sleep(level)
448+
await self.executor.sleep(level)
449449

450450
def wakeup(self, tags: list[str] | None = None):
451451
"""Wakeup."""

lmdeploy/pytorch/engine/executor/mp_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,14 @@ def warmup(self):
373373
"""Build cache engine."""
374374
self.collective_rpc('warmup')
375375

376+
async def sleep(self, level: int = 1):
377+
"""Sleep."""
378+
await self.collective_rpc_async('sleep', args=(level, ), return_mask=0)
379+
380+
def wakeup(self, tags: list[str] | None = None):
381+
"""Wakeup."""
382+
self.collective_rpc('wakeup', args=(tags, ), return_mask=0)
383+
376384
async def _prefetch_outputs(self):
377385
while True:
378386
out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,18 @@ def collective_rpc(self,
321321
kwargs = dict()
322322
return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)
323323

324+
async def collective_rpc_async(self,
325+
method: str,
326+
args: tuple[Any] = None,
327+
kwargs: dict[str, Any] = None):
328+
"""Collective async rpc."""
329+
if args is None:
330+
args = list()
331+
if kwargs is None:
332+
kwargs = dict()
333+
tasks = [getattr(worker, method).remote(*args, **kwargs) for worker in self.workers]
334+
return await asyncio.gather(*tasks)
335+
324336
def build_model(self):
325337
"""Build model."""
326338
self.collective_rpc('build_model')
@@ -353,9 +365,9 @@ def warmup(self):
353365
"""Build cache engine."""
354366
self.collective_rpc('warmup')
355367

356-
def sleep(self, level: int = 1):
368+
async def sleep(self, level: int = 1):
357369
"""Sleep."""
358-
self.collective_rpc('sleep', (level, ))
370+
await self.collective_rpc_async('sleep', (level, ))
359371

360372
def wakeup(self, tags: list[str] | None = None):
361373
"""Wakeup."""

lmdeploy/pytorch/engine/executor/uni_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0):
108108
assert dp_rank == 0
109109
return await self.model_agent.get_output_async()
110110

111+
async def sleep(self, level: int = 1):
112+
"""Sleep."""
113+
await self.model_agent.sleep(level)
114+
115+
def wakeup(self, tags: list[str] | None = None):
116+
"""Wakeup."""
117+
self.model_agent.wakeup(tags)
118+
111119
def get_input_processor(self):
112120
"""Get input processor."""
113121
return self.model_agent.get_input_processor()

lmdeploy/pytorch/engine/mp_engine/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def end_session(self, session_id: int):
5353
"""End session."""
5454
return self._collective_rpc('end_session', session_id)
5555

56-
def sleep(self, level: int):
56+
async def sleep(self, level: int):
5757
"""sleep."""
58-
return self._collective_rpc('sleep', level)
58+
return await self._collective_rpc_async('sleep', level)
5959

6060
def wakeup(self, tags: list[str] | None = None):
6161
"""Wakeup."""

lmdeploy/pytorch/engine/mp_engine/base_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
100100
"""
101101
return self.engine.p2p_drop_connect(drop_conn_request)
102102

103-
def sleep(self, level: int = 1):
103+
async def sleep(self, level: int = 1):
104104
"""sleep."""
105-
return self.engine.sleep(level)
105+
return await self.engine.sleep(level)
106106

107107
def wakeup(self, tags: list[str] | None = None):
108108
"""Wakeup."""

lmdeploy/serve/core/async_engine.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class GenOut:
4646
history_token_len: int
4747
input_token_len: int
4848
generate_token_len: int
49-
finish_reason: Literal['stop', 'length', 'error'] | None = None
49+
finish_reason: Literal['stop', 'length', 'error', 'abort'] | None = None
5050
token_ids: list[int] | None = None
5151
logprobs: list[dict[int, float]] | None = None
5252
logits: Any = None
@@ -201,6 +201,23 @@ def _build_stat_loggers(self):
201201
# set stats loggers of metrics processor
202202
metrics_processor.stat_loggers = self.stat_loggers
203203

204+
def _if_session_stale(self, session: Session,
205+
input_token_len: int) -> GenOut | None:
206+
"""If ``session.epoch`` was stamped by api_server and
207+
``stop_all_session`` ran since then (the engine epoch changed), drop
208+
the session."""
209+
epoch = session.epoch
210+
if epoch is None or epoch == self.epoch:
211+
return None
212+
logger.info(f'[generate] drop stale session {session.session_id} '
213+
f'(session.epoch={epoch}, async_engine.epoch={self.epoch})')
214+
return GenOut(response='',
215+
history_token_len=session.step,
216+
input_token_len=input_token_len,
217+
generate_token_len=0,
218+
finish_reason='abort',
219+
token_ids=[])
220+
204221
async def get_schedule_metrics(self):
205222
result = self.engine.get_schedule_metrics()
206223
if asyncio.iscoroutine(result):
@@ -215,19 +232,24 @@ async def do_log_stats(self):
215232

216233
async def stop_all_session(self):
217234
"""Stop all running sessions."""
218-
logger.info('stop all sessions')
235+
logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
219236
self.epoch += 1
220237
await self.session_mgr.async_abort_all()
221238

222-
def sleep(self, level: int = 1):
239+
def prepare_sleep(self):
240+
"""Reject new inference requests before backend sleep starts."""
241+
self.sleeping_tags = {'weights', 'kv_cache'}
242+
self.is_sleeping = True
243+
244+
async def sleep(self, level: int = 1):
223245
"""Sleep the model.
224246
225247
Args:
226248
level (int): The sleep level. Level 1 sleep will offload the model
227249
weights and discard the kv cache. Level 2 sleep will
228250
discard both the model weights and the kv cache.
229251
"""
230-
self.engine.sleep(level)
252+
await self.engine.sleep(level)
231253
self.sleeping_tags = {'weights', 'kv_cache'}
232254
self.is_sleeping = True
233255

@@ -342,7 +364,8 @@ async def generate(
342364
do_preprocess (bool): whether pre-process the messages. Default to
343365
True, which means chat_template will be applied.
344366
"""
345-
epoch = self.epoch
367+
metrics_processor.increase_total_requests()
368+
346369
if (messages is not None) ^ (input_ids is None):
347370
raise ValueError('You must specify exactly one of messages or input_ids')
348371
if isinstance(session_id, Session):
@@ -389,6 +412,7 @@ async def generate(
389412

390413
if gen_config.max_new_tokens == 0:
391414
logger.info(f'run out of tokens. session={session_id}.')
415+
metrics_processor.increase_failed_requests('error')
392416
yield GenOut(response='',
393417
history_token_len=session.step,
394418
input_token_len=len(input_ids),
@@ -403,6 +427,7 @@ async def generate(
403427
or gen_config.output_logits == 'all'):
404428
errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
405429
'when prefix caching is ON')
430+
metrics_processor.increase_failed_requests('error')
406431
yield GenOut(response=errmsg,
407432
history_token_len=session.step,
408433
input_token_len=len(input_ids),
@@ -424,10 +449,18 @@ def is_error(status):
424449
if not gen_config.ignore_eos:
425450
stop_ids = gen_config.stop_token_ids or []
426451

427-
metrics_processor.increase_total_requests()
452+
453+
stale = self._if_session_stale(session, len(prompt_input['input_ids']))
454+
if stale is not None:
455+
metrics_processor.increase_failed_requests('abort')
456+
yield stale
457+
if sequence_end:
458+
self.session_mgr.remove(session)
459+
return
428460
async with session.request_handle() as handle:
429-
if epoch != self.epoch:
430-
logger.info(f'[generate] session {session_id} got aborted before starting inference')
461+
if session.epoch is not None and session.epoch != self.epoch:
462+
logger.info(f'[generate] session {session_id} got aborted before starting inference, '
463+
f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}')
431464
metrics_processor.increase_failed_requests('abort')
432465
yield GenOut(response='',
433466
history_token_len=0,

lmdeploy/serve/managers/session_manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs):
2424
self.history: list[tuple[Any, str]] = []
2525
self.gen_config: GenerationConfig | None = None
2626
self.step: int = 0
27+
# Set by api_server to AsyncEngine.epoch when a request binds a session;
28+
# generate() drops work if stop_all_session() bumped epoch after bind.
29+
self.epoch: int | None = None
2730
# event to wait for the session to be active
2831
self._active: asyncio.Event | None = None
2932
self._handle = None # inference instance
@@ -64,6 +67,7 @@ def reset(self):
6467
self.history = []
6568
self.gen_config = None
6669
self.step = 0
70+
self.epoch = None
6771
self._active = None
6872
self._handle = None
6973
self._session_mgr = None
@@ -101,7 +105,7 @@ async def request_handle(self):
101105

102106
async def async_abort(self):
103107
"""Abort the session."""
104-
logger.info(f'[session] Aborting session {self.session_id}')
108+
logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}')
105109
if self._handle is not None:
106110
await self._handle.async_cancel(self.session_id)
107111

@@ -205,13 +209,14 @@ def get(self, session_id: int | None = None, **kwargs) -> Session:
205209
session.update(**kwargs)
206210
return session
207211
else:
208-
logger.info(f'[SessionManager] session {session_id} not found. Creating...')
212+
logger.debug(f'[SessionManager] session {session_id} not found. Creating...')
209213
session = Session(session_id, self, **kwargs)
210214
self.sessions[session_id] = session
211215
return session
212216

213217
async def async_abort_all(self):
214218
"""Abort all sessions."""
219+
logger.info(f'[SessionManager] aborting all {len(self.sessions)} sessions')
215220
tasks = []
216221
for session in list(self.sessions.values()):
217222
tasks.append(session.async_abort())

lmdeploy/serve/openai/api_server.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from __future__ import annotations
3+
24
# yapf: disable
35
import asyncio
46
import copy
@@ -10,7 +12,7 @@
1012
from contextlib import asynccontextmanager
1113
from functools import partial
1214
from http import HTTPStatus
13-
from typing import Literal
15+
from typing import TYPE_CHECKING, Literal
1416

1517
import uvicorn
1618
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
@@ -76,10 +78,13 @@
7678
)
7779
from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
7880
from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
79-
from lmdeploy.serve.utils.server_utils import validate_json_request
81+
from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware, EngineSleepingMiddleware, validate_json_request
8082
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
8183
from lmdeploy.utils import get_logger
8284

85+
if TYPE_CHECKING:
86+
from lmdeploy.serve.managers import Session
87+
8388
# yapf: enable
8489

8590
logger = get_logger('lmdeploy')
@@ -100,12 +105,15 @@ class VariableInterface:
100105
enable_abort_handling: bool = False
101106

102107
@staticmethod
103-
def get_session(session_id: int) -> int:
108+
def get_session(session_id: int) -> Session:
104109
session_mgr = VariableInterface.get_session_manager()
105110
if session_id == -1:
106-
return session_mgr.get()
111+
session = session_mgr.get()
107112
else:
108-
return session_mgr.get(session_id)
113+
session = session_mgr.get(session_id)
114+
# Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``.
115+
session.epoch = VariableInterface.async_engine.epoch
116+
return session
109117

110118
@staticmethod
111119
def get_session_manager():
@@ -769,7 +777,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
769777
error_check_ret = check_request(request)
770778
if error_check_ret is not None:
771779
return error_check_ret
772-
773780
json_request = await raw_request.json()
774781
migration_request = json_request.pop('migration_request', None)
775782
with_cache = json_request.pop('with_cache', False)
@@ -963,6 +970,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
963970
error_check_ret = check_request(request)
964971
if error_check_ret is not None:
965972
return error_check_ret
973+
966974
session = VariableInterface.get_session(request.session_id)
967975

968976
prompt = request.prompt
@@ -1175,7 +1183,16 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
11751183
@router.post('/sleep', dependencies=[Depends(validate_json_request)])
11761184
async def sleep(raw_request: Request = None):
11771185
level = raw_request.query_params.get('level', '1')
1178-
VariableInterface.async_engine.sleep(int(level))
1186+
try:
1187+
level = int(level)
1188+
except (TypeError, ValueError):
1189+
return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be an integer.')
1190+
if level not in (1, 2):
1191+
return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.')
1192+
async_engine = VariableInterface.async_engine
1193+
async_engine.prepare_sleep()
1194+
await async_engine.stop_all_session()
1195+
await async_engine.sleep(level)
11791196
return Response(status_code=200)
11801197

11811198

@@ -1526,10 +1543,13 @@ def serve(model_path: str,
15261543
)
15271544

15281545
if api_keys is not None and (tokens := [key for key in api_keys if key]):
1529-
from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware
1530-
15311546
app.add_middleware(AuthenticationMiddleware, tokens=tokens)
15321547

1548+
def is_engine_sleeping() -> bool:
1549+
eng = VariableInterface.async_engine
1550+
return eng is not None and eng.is_sleeping
1551+
app.add_middleware(EngineSleepingMiddleware, is_sleeping=is_engine_sleeping)
1552+
15331553
# set the maximum number of concurrent requests
15341554
if max_concurrent_requests is not None:
15351555
app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests)

0 commit comments

Comments (0)