@@ -46,7 +46,7 @@ class GenOut:
4646 history_token_len : int
4747 input_token_len : int
4848 generate_token_len : int
49- finish_reason : Literal ['stop' , 'length' , 'error' ] | None = None
49+ finish_reason : Literal ['stop' , 'length' , 'error' , 'abort' ] | None = None
5050 token_ids : list [int ] | None = None
5151 logprobs : list [dict [int , float ]] | None = None
5252 logits : Any = None
@@ -201,6 +201,23 @@ def _build_stat_loggers(self):
201201 # set stats loggers of metrics processor
202202 metrics_processor .stat_loggers = self .stat_loggers
203203
204+ def _if_session_stale (self , session : Session ,
205+ input_token_len : int ) -> GenOut | None :
206+ """If ``session.epoch`` was stamped by api_server and
207+ ``stop_all_session`` ran since then (the engine epoch changed), drop
208+ the session."""
209+ epoch = session .epoch
210+ if epoch is None or epoch == self .epoch :
211+ return None
212+ logger .info (f'[generate] drop stale session { session .session_id } '
213+ f'(session.epoch={ epoch } , async_engine.epoch={ self .epoch } )' )
214+ return GenOut (response = '' ,
215+ history_token_len = session .step ,
216+ input_token_len = input_token_len ,
217+ generate_token_len = 0 ,
218+ finish_reason = 'abort' ,
219+ token_ids = [])
220+
204221 async def get_schedule_metrics (self ):
205222 result = self .engine .get_schedule_metrics ()
206223 if asyncio .iscoroutine (result ):
@@ -215,19 +232,24 @@ async def do_log_stats(self):
215232
216233 async def stop_all_session (self ):
217234 """Stop all running sessions."""
218- logger .info ('stop all sessions' )
235+ logger .info (f 'stop all sessions, epoch { self . epoch } -> { self . epoch + 1 } ' )
219236 self .epoch += 1
220237 await self .session_mgr .async_abort_all ()
221238
222- def sleep (self , level : int = 1 ):
239+ def prepare_sleep (self ):
240+ """Reject new inference requests before backend sleep starts."""
241+ self .sleeping_tags = {'weights' , 'kv_cache' }
242+ self .is_sleeping = True
243+
244+ async def sleep (self , level : int = 1 ):
223245 """Sleep the model.
224246
225247 Args:
226248 level (int): The sleep level. Level 1 sleep will offload the model
227249 weights and discard the kv cache. Level 2 sleep will
228250 discard both the model weights and the kv cache.
229251 """
230- self .engine .sleep (level )
252+ await self .engine .sleep (level )
231253 self .sleeping_tags = {'weights' , 'kv_cache' }
232254 self .is_sleeping = True
233255
@@ -342,7 +364,8 @@ async def generate(
342364 do_preprocess (bool): whether pre-process the messages. Default to
343365 True, which means chat_template will be applied.
344366 """
345- epoch = self .epoch
367+ metrics_processor .increase_total_requests ()
368+
346369 if (messages is not None ) ^ (input_ids is None ):
347370 raise ValueError ('You must specify exactly one of messages or input_ids' )
348371 if isinstance (session_id , Session ):
@@ -389,6 +412,7 @@ async def generate(
389412
390413 if gen_config .max_new_tokens == 0 :
391414 logger .info (f'run out of tokens. session={ session_id } .' )
415+ metrics_processor .increase_failed_requests ('error' )
392416 yield GenOut (response = '' ,
393417 history_token_len = session .step ,
394418 input_token_len = len (input_ids ),
@@ -403,6 +427,7 @@ async def generate(
403427 or gen_config .output_logits == 'all' ):
404428 errmsg = ('lmdeploy does not support outputting all token\' s logits or last_hidden_state '
405429 'when prefix caching is ON' )
430+ metrics_processor .increase_failed_requests ('error' )
406431 yield GenOut (response = errmsg ,
407432 history_token_len = session .step ,
408433 input_token_len = len (input_ids ),
@@ -424,10 +449,18 @@ def is_error(status):
424449 if not gen_config .ignore_eos :
425450 stop_ids = gen_config .stop_token_ids or []
426451
427- metrics_processor .increase_total_requests ()
452+
453+ stale = self ._if_session_stale (session , len (prompt_input ['input_ids' ]))
454+ if stale is not None :
455+ metrics_processor .increase_failed_requests ('abort' )
456+ yield stale
457+ if sequence_end :
458+ self .session_mgr .remove (session )
459+ return
428460 async with session .request_handle () as handle :
429- if epoch != self .epoch :
430- logger .info (f'[generate] session { session_id } got aborted before starting inference' )
461+ if session .epoch is not None and session .epoch != self .epoch :
462+ logger .info (f'[generate] session { session_id } got aborted before starting inference, '
463+ f'session.epoch={ session .epoch } , async_engine.epoch={ self .epoch } ' )
431464 metrics_processor .increase_failed_requests ('abort' )
432465 yield GenOut (response = '' ,
433466 history_token_len = 0 ,