Skip to content

Commit 70d2682

Browse files
authored
fix cancel stopped seq (#4654)
* fix
* fix race after wakeup before warmup
* Revert "fix race after wakeup before warmup"
  (this reverts commit 94adf31)
1 parent 657d7bc commit 70d2682

3 files changed

Lines changed: 154 additions & 4 deletions

File tree

lmdeploy/pytorch/engine/engine.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ..adapter.adapter import AdapterManager
2424
from ..config import CacheConfig, ModelConfig
25-
from ..messages import SchedulerSequence, UpdateTokenMode
25+
from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode
2626
from ..paging import Scheduler
2727
from ..strategies import build_strategy_factory
2828
from .base import EngineBase
@@ -317,11 +317,18 @@ def _on_stop_session(self, reqs: list[Request], **kwargs):
317317
resp = req.data.get('response', True)
318318
resp_type = ResponseType.SESSION_NOT_EXIST
319319
if session_id in self.scheduler.sessions:
320-
self.scheduler.stop_session(session_id)
321320
session = self.scheduler.sessions[session_id]
321+
stopped_resp_ids = set()
322322
for seq in session.sequences.values():
323+
if seq.status not in (MessageStatus.STOPPED, MessageStatus.TO_BE_MIGRATED):
324+
continue
323325
_resp: Response = getattr(seq, 'resp', None)
324326
if _resp is not None:
327+
stopped_resp_ids.add(id(_resp))
328+
self.scheduler.stop_session(session_id)
329+
for seq in session.sequences.values():
330+
_resp: Response = getattr(seq, 'resp', None)
331+
if _resp is not None and id(_resp) not in stopped_resp_ids:
325332
self.req_manager.reject_request(_resp)
326333
resp_type = ResponseType.SUCCESS
327334
if resp:

lmdeploy/pytorch/engine/engine_loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,12 @@ def _log_resps(outputs: list[InferOutput]):
198198
def _send_resp(self, out: InferOutput):
199199
"""Send response."""
200200
logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None)
201-
if out.resp.is_done:
201+
if out.finish:
202+
resp_type = ResponseType.FINISH
203+
elif out.resp.is_done:
202204
resp_type = out.resp.type
203205
else:
204-
resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
206+
resp_type = ResponseType.SUCCESS
205207
response_reqs(self.req_manager,
206208
out.resp,
207209
resp_type,
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import asyncio
3+
from types import SimpleNamespace
4+
5+
import numpy as np
6+
7+
from lmdeploy.messages import ResponseType
8+
from lmdeploy.pytorch.engine.engine import Engine, InferOutput
9+
from lmdeploy.pytorch.engine.engine_loop import EngineLoop
10+
from lmdeploy.pytorch.engine.request import Request, RequestType, Response
11+
from lmdeploy.pytorch.messages import MessageStatus
12+
13+
14+
class FakeReqManager:
    """Request-manager test double that records every call it receives."""

    def __init__(self):
        # Tuples of (resp, req_type, reason), one per reject_request call.
        self.rejected = list()
        # Responses passed to response(), in call order.
        self.responses = list()

    def reject_request(self, resp, req_type=None, reason=''):
        """Record the rejection, mark ``resp`` as a finished CANCEL and set its event."""
        record = (resp, req_type, reason)
        self.rejected.append(record)
        resp.type, resp.is_done = ResponseType.CANCEL, True
        resp.event.set()

    def response(self, resp):
        """Record ``resp`` and set its event."""
        self.responses.append(resp)
        resp.event.set()
30+
31+
class FakeState:
    """Stand-in for a sequence state object that only supports ``stop``."""

    def __init__(self, seq):
        # Back-reference to the sequence whose status ``stop`` rewrites.
        self.seq = seq

    def stop(self):
        """Flip the owning sequence's status to ``MessageStatus.STOPPED``."""
        owner = self.seq
        owner.status = MessageStatus.STOPPED
39+
40+
class FakeSeq:
    """Stand-in sequence carrying a status, a response and a ``FakeState``."""

    def __init__(self, status, resp):
        self.status, self.resp = status, resp
        # The state keeps a reference back to this sequence so that
        # ``state.stop()`` can update ``self.status``.
        self.state = FakeState(self)
46+
47+
48+
class FakeScheduler:
    """Scheduler test double holding one session (id 1) with one sequence (id 0)."""

    def __init__(self, seq):
        only_session = SimpleNamespace(sequences={0: seq})
        self.sessions = {1: only_session}

    def stop_session(self, session_id):
        """Call ``state.stop()`` on every sequence of the given session."""
        target = self.sessions[session_id]
        for member in target.sequences.values():
            member.state.stop()
56+
57+
58+
def make_response(status=ResponseType.INTERNAL_ENGINE_ERROR):
    """Build a ``Response`` of type ``status`` with sender 0 and a fresh event."""
    fresh_event = asyncio.Event()
    return Response(type=status, sender_id=0, event=fresh_event)
60+
61+
62+
def make_stop_request():
    """Build a STOP_SESSION request for session 1 with its own response."""
    payload = {'session_id': 1}
    reply = make_response()
    return Request(type=RequestType.STOP_SESSION, sender_id=0, data=payload, resp=reply)
69+
70+
71+
def run_stop_session(seq_status):
    """Run ``Engine._on_stop_session`` against fakes for one sequence.

    The single sequence starts in ``seq_status``.

    Returns:
        A ``(stream_resp, stop_resp, req_manager)`` tuple: the response
        attached to the sequence, the response of the stop request, and the
        fake request manager that recorded the calls.
    """
    stream_resp = make_response()
    req_manager = FakeReqManager()

    def _forward_response(resp, resp_type):
        # Delegate to the real implementation, bound to the namespace below.
        return Engine._response(engine, resp, resp_type)

    engine = SimpleNamespace(
        scheduler=FakeScheduler(FakeSeq(seq_status, stream_resp)),
        req_manager=req_manager,
        _response=_forward_response,
    )
    stop_req = make_stop_request()

    Engine._on_stop_session(engine, [stop_req])
    return stream_resp, stop_req.resp, req_manager
84+
85+
86+
def test_stopped_sequence_is_not_cancelled_by_abort():
    """Stopping a session must leave an already STOPPED sequence's response untouched."""
    outcome = run_stop_session(MessageStatus.STOPPED)
    stream_resp, stop_resp, req_manager = outcome

    # The stop request itself succeeds ...
    assert stop_resp.type == ResponseType.SUCCESS
    # ... while the sequence's response is neither rejected nor signalled.
    assert stream_resp.type == ResponseType.INTERNAL_ENGINE_ERROR
    assert stream_resp.is_done is False
    assert stream_resp.event.is_set() is False
    assert req_manager.rejected == []
94+
95+
96+
def test_to_be_migrated_sequence_is_not_cancelled_by_abort():
    """Stopping a session must leave a TO_BE_MIGRATED sequence's response untouched."""
    outcome = run_stop_session(MessageStatus.TO_BE_MIGRATED)
    stream_resp, stop_resp, req_manager = outcome

    # The stop request itself succeeds ...
    assert stop_resp.type == ResponseType.SUCCESS
    # ... while the sequence's response is neither rejected nor signalled.
    assert stream_resp.type == ResponseType.INTERNAL_ENGINE_ERROR
    assert stream_resp.is_done is False
    assert stream_resp.event.is_set() is False
    assert req_manager.rejected == []
104+
105+
106+
def test_running_sequence_is_still_cancelled_by_abort():
    """Stopping a session must still reject the response of a RUNNING sequence."""
    outcome = run_stop_session(MessageStatus.RUNNING)
    stream_resp, stop_resp, req_manager = outcome

    assert stop_resp.type == ResponseType.SUCCESS
    # Exactly one rejection, which marks the response as a finished CANCEL.
    assert stream_resp.type == ResponseType.CANCEL
    assert stream_resp.is_done is True
    assert stream_resp.event.is_set() is True
    assert len(req_manager.rejected) == 1
114+
115+
116+
def test_finish_output_wins_over_stale_cancel_response():
    """A finished output is sent as FINISH even if its response was already a done CANCEL."""
    req_manager = FakeReqManager()
    fake_loop = SimpleNamespace(req_manager=req_manager)

    # Response previously marked as cancelled and done.
    stale = make_response(ResponseType.CANCEL)
    stale.is_done = True
    tokens = np.array([7])
    output = InferOutput(session_id=1, resp=stale, token_ids=tokens, finish=True)

    EngineLoop._send_resp(fake_loop, output)

    assert stale.type == ResponseType.FINISH
    assert stale.data['token_ids'].tolist() == [7]
    assert req_manager.responses == [stale]
128+
129+
130+
def test_cancel_output_without_finish_stays_cancelled():
    """An unfinished output keeps the CANCEL type of its already-done response."""
    req_manager = FakeReqManager()
    fake_loop = SimpleNamespace(req_manager=req_manager)

    # Response previously marked as cancelled and done.
    cancelled = make_response(ResponseType.CANCEL)
    cancelled.is_done = True
    no_tokens = np.array([])
    output = InferOutput(session_id=1, resp=cancelled, token_ids=no_tokens, finish=False)

    EngineLoop._send_resp(fake_loop, output)

    assert cancelled.type == ResponseType.CANCEL
    assert cancelled.data['token_ids'].tolist() == []
    assert req_manager.responses == [cancelled]

0 commit comments

Comments
 (0)