-
Notifications
You must be signed in to change notification settings - Fork 692
Expand file tree
/
Copy pathbase_worker.py
More file actions
157 lines (123 loc) · 5.09 KB
/
base_worker.py
File metadata and controls
157 lines (123 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
from lmdeploy.messages import EngineOutput
from lmdeploy.pytorch.disagg.conn.protocol import (
DistServeConnectionRequest,
DistServeDropConnectionRequest,
DistServeInitRequest,
)
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
if TYPE_CHECKING:
from lmdeploy.pytorch.engine.engine import Engine
class EngineInstancePool:
    """Pool of reusable engine instances shared by concurrent requests.

    Instances are kept in an ``asyncio.Queue``; callers borrow one for the
    duration of a call and it is returned afterwards.
    """

    def __init__(self, engine):
        from lmdeploy.pytorch.engine import Engine
        self.engine: Engine = engine
        # Twice the max batch size — otherwise a sequence cannot be
        # stopped in time when every slot is busy.
        self.num_instance = 2 * self.engine.engine_config.max_batch_size
        # Built lazily on first use; see instance().
        self.pool = None

    def create_instance_pool(self, num_instance: int):
        """Return a queue pre-filled with ``num_instance`` engine instances."""
        queue = asyncio.Queue(maxsize=num_instance)
        for _ in range(num_instance):
            queue.put_nowait(self.engine.create_instance())
        return queue

    @asynccontextmanager
    async def instance(self):
        """Borrow an instance from the pool; it is returned on exit."""
        # Lazily build the pool the first time it is needed.  There is no
        # await between the check and the assignment, so this is safe
        # against concurrent coroutines.
        if self.pool is None:
            self.pool = self.create_instance_pool(self.num_instance)
        borrowed = await self.pool.get()
        try:
            yield borrowed
        finally:
            self.pool.put_nowait(borrowed)

    async def async_end(self, session_id: int):
        """End the given session."""
        async with self.instance() as inst:
            return await inst.async_end(session_id)

    async def async_cancel(self, session_id: int):
        """Stop the ongoing streaming inference of a session."""
        async with self.instance() as inst:
            return await inst.async_cancel(session_id)

    async def async_stream_infer(self, *args, **kwargs):
        """Forward a streaming inference request to a pooled instance."""
        async with self.instance() as inst:
            async for out in inst.async_stream_infer(*args, **kwargs):
                yield out
class EngineWorkerBase:
    """Common delegation layer for engine workers.

    Starts the engine request loop on construction and forwards
    engine-level calls to the wrapped engine and instance-level calls to
    a shared :class:`EngineInstancePool`.
    """

    def __init__(self, engine: 'Engine'):
        engine.start_loop()
        self.engine = engine
        self.instance_pool = EngineInstancePool(engine)

    def end_session(self, session_id: int):
        """Terminate the session identified by ``session_id``."""
        return self.engine.end_session(session_id)

    def get_engine_config(self):
        """Return the wrapped engine's configuration."""
        return self.engine.get_engine_config()

    def get_schedule_metrics(self):
        """Return scheduling metrics from the wrapped engine."""
        return self.engine.get_schedule_metrics()

    def p2p_initialize(self, conn_request: DistServeInitRequest):
        """Initialize the rdma link."""
        return self.engine.p2p_initialize(conn_request)

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        """Establish the rdma connection."""
        return self.engine.p2p_connect(conn_request)

    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
        """Drop a p2p connection.

        1. drop engine connection (zmq connection)
        2. TODO(JimyMa) drop RDMA Connection.
        """
        return self.engine.p2p_drop_connect(drop_conn_request)

    async def sleep(self, level: int = 1):
        """Put the engine to sleep at the given level."""
        return await self.engine.sleep(level)

    async def wakeup(self, tags: list[str] | None = None):
        """Wake the engine up, optionally restricted to ``tags``."""
        return await self.engine.wakeup(tags)

    def update_params(self, request: Any):
        """Update engine parameters from ``request``."""
        return self.engine.update_params(request)

    def close(self) -> None:
        """Shut down the wrapped engine."""
        self.engine.close()

    async def instance_async_end(self, session_id: int):
        """End the given session through a pooled instance."""
        return await self.instance_pool.async_end(session_id)

    async def instance_async_cancel(self, session_id: int):
        """Stop the current streaming inference through a pooled instance."""
        return await self.instance_pool.async_cancel(session_id)

    async def instance_async_stream_infer(self, *args, **kwargs):
        """Relay a streaming inference request through the instance pool."""
        async for out in self.instance_pool.async_stream_infer(*args, **kwargs):
            yield out
class EngineOutputGather:
    """Accumulate incremental ``EngineOutput`` chunks, keyed by stream id."""

    def __init__(self):
        # stream_id -> accumulated EngineOutput
        self._output = {}

    def get(self, stream_id):
        """Return the accumulator for ``stream_id``, creating it on demand."""
        try:
            return self._output[stream_id]
        except KeyError:
            gathered = EngineOutput(status=None, token_ids=[], logprobs=[])
            self._output[stream_id] = gathered
            return gathered

    def add(self, stream_id, result):
        """Merge the tokens and logprobs of ``result`` into the accumulator.

        Non-``EngineOutput`` results are ignored.
        """
        if isinstance(result, EngineOutput):
            gathered = self.get(stream_id)
            gathered.token_ids.extend(result.token_ids or [])
            gathered.logprobs.extend(result.logprobs or [])

    def pop(self, stream_id, result):
        """Move everything gathered for ``stream_id`` into ``result``.

        The accumulator entry is removed.  Non-``EngineOutput`` results
        are passed through unchanged.
        """
        if not isinstance(result, EngineOutput):
            return result
        gathered = self._output.pop(stream_id)
        result.token_ids = gathered.token_ids or []
        # An empty logprobs list is normalized to None.
        result.logprobs = gathered.logprobs or None
        return result