Skip to content

Commit cbcdfa8

Browse files
committed
async sleep and sync wakeup
1 parent ea9aa7a commit cbcdfa8

File tree

9 files changed

+44
-15
lines changed

9 files changed

+44
-15
lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,9 @@ def update_params(self, request: Any):
443443
"""Update params."""
444444
self.executor.update_params(request)
445445

446-
def sleep(self, level: int = 1):
446+
async def sleep(self, level: int = 1):
447447
"""Sleep."""
448-
self.executor.sleep(level)
448+
await self.executor.sleep(level)
449449

450450
def wakeup(self, tags: list[str] | None = None):
451451
"""Wakeup."""

lmdeploy/pytorch/engine/executor/mp_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,14 @@ def warmup(self):
373373
"""Build cache engine."""
374374
self.collective_rpc('warmup')
375375

376+
async def sleep(self, level: int = 1):
377+
"""Sleep."""
378+
await self.collective_rpc_async('sleep', args=(level, ), return_mask=0)
379+
380+
def wakeup(self, tags: list[str] | None = None):
381+
"""Wakeup."""
382+
self.collective_rpc('wakeup', args=(tags, ), return_mask=0)
383+
376384
async def _prefetch_outputs(self):
377385
while True:
378386
out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,18 @@ def collective_rpc(self,
321321
kwargs = dict()
322322
return ray.get([getattr(worker, method).remote(*args, **kwargs) for worker in self.workers], timeout=timeout)
323323

324+
async def collective_rpc_async(self,
325+
method: str,
326+
args: tuple[Any] = None,
327+
kwargs: dict[str, Any] = None):
328+
"""Collective async rpc."""
329+
if args is None:
330+
args = list()
331+
if kwargs is None:
332+
kwargs = dict()
333+
tasks = [getattr(worker, method).remote(*args, **kwargs) for worker in self.workers]
334+
return await asyncio.gather(*tasks)
335+
324336
def build_model(self):
325337
"""Build model."""
326338
self.collective_rpc('build_model')
@@ -353,9 +365,9 @@ def warmup(self):
353365
"""Build cache engine."""
354366
self.collective_rpc('warmup')
355367

356-
def sleep(self, level: int = 1):
368+
async def sleep(self, level: int = 1):
357369
"""Sleep."""
358-
self.collective_rpc('sleep', (level, ))
370+
await self.collective_rpc_async('sleep', (level, ))
359371

360372
def wakeup(self, tags: list[str] | None = None):
361373
"""Wakeup."""

lmdeploy/pytorch/engine/executor/uni_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0):
108108
assert dp_rank == 0
109109
return await self.model_agent.get_output_async()
110110

111+
async def sleep(self, level: int = 1):
112+
"""Sleep."""
113+
await self.model_agent.sleep(level)
114+
115+
def wakeup(self, tags: list[str] | None = None):
116+
"""Wakeup."""
117+
self.model_agent.wakeup(tags)
118+
111119
def get_input_processor(self):
112120
"""Get input processor."""
113121
return self.model_agent.get_input_processor()

lmdeploy/pytorch/engine/mp_engine/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def end_session(self, session_id: int):
5353
"""End session."""
5454
return self._collective_rpc('end_session', session_id)
5555

56-
def sleep(self, level: int):
56+
async def sleep(self, level: int):
5757
"""sleep."""
58-
return self._collective_rpc('sleep', level)
58+
return await self._collective_rpc_async('sleep', level)
5959

6060
def wakeup(self, tags: list[str] | None = None):
6161
"""Wakeup."""

lmdeploy/pytorch/engine/mp_engine/base_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
100100
"""
101101
return self.engine.p2p_drop_connect(drop_conn_request)
102102

103-
def sleep(self, level: int = 1):
103+
async def sleep(self, level: int = 1):
104104
"""sleep."""
105-
return self.engine.sleep(level)
105+
return await self.engine.sleep(level)
106106

107107
def wakeup(self, tags: list[str] | None = None):
108108
"""Wakeup."""

lmdeploy/serve/core/async_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def _if_session_stale(self, session: Session,
209209
epoch = session.epoch
210210
if epoch is None or epoch == self.epoch:
211211
return None
212-
logger.info(
213-
f'[generate] session {session.session_id} dropped (session.epoch={epoch}, epoch={self.epoch})')
212+
logger.info(f'[generate] drop stale session {session.session_id} '
213+
f'(session.epoch={epoch}, async_engine.epoch={self.epoch})')
214214
return GenOut(response='',
215215
history_token_len=session.step,
216216
input_token_len=input_token_len,
@@ -241,15 +241,15 @@ def prepare_sleep(self):
241241
self.sleeping_tags = {'weights', 'kv_cache'}
242242
self.is_sleeping = True
243243

244-
def sleep(self, level: int = 1):
244+
async def sleep(self, level: int = 1):
245245
"""Sleep the model.
246246
247247
Args:
248248
level (int): The sleep level. Level 1 sleep will offload the model
249249
weights and discard the kv cache. Level 2 sleep will
250250
discard both the model weights and the kv cache.
251251
"""
252-
self.engine.sleep(level)
252+
await self.engine.sleep(level)
253253
self.sleeping_tags = {'weights', 'kv_cache'}
254254
self.is_sleeping = True
255255

@@ -460,7 +460,7 @@ def is_error(status):
460460
async with session.request_handle() as handle:
461461
if session.epoch is not None and session.epoch != self.epoch:
462462
logger.info(f'[generate] session {session_id} got aborted before starting inference, '
463-
f'session.epoch={session.epoch}, epoch={self.epoch}')
463+
f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}')
464464
metrics_processor.increase_failed_requests('abort')
465465
yield GenOut(response='',
466466
history_token_len=0,

lmdeploy/serve/managers/session_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def request_handle(self):
105105

106106
async def async_abort(self):
107107
"""Abort the session."""
108-
logger.info(f'[session] Aborting session {self.session_id}, epoch={self.epoch}')
108+
logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}')
109109
if self._handle is not None:
110110
await self._handle.async_cancel(self.session_id)
111111

@@ -216,6 +216,7 @@ def get(self, session_id: int | None = None, **kwargs) -> Session:
216216

217217
async def async_abort_all(self):
218218
"""Abort all sessions."""
219+
logger.info(f'[SessionManager] aborting all {len(self.sessions)} sessions')
219220
tasks = []
220221
for session in list(self.sessions.values()):
221222
tasks.append(session.async_abort())

lmdeploy/serve/openai/api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1192,7 +1192,7 @@ async def sleep(raw_request: Request = None):
11921192
async_engine = VariableInterface.async_engine
11931193
async_engine.prepare_sleep()
11941194
await async_engine.stop_all_session()
1195-
async_engine.sleep(level)
1195+
await async_engine.sleep(level)
11961196
return Response(status_code=200)
11971197

11981198

0 commit comments

Comments (0)