-
Notifications
You must be signed in to change notification settings - Fork 318
Expand file tree
/
Copy pathpd_loop.py
More file actions
256 lines (216 loc) · 10.7 KB
/
pd_loop.py
File metadata and controls
256 lines (216 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import asyncio
import pickle
import websockets
import ujson as json
import socket
import httpx
import base64
import weakref
from typing import Dict, Optional, Union, List
from websockets import ClientConnection
from lightllm.server.pd_io_struct import NodeRole, ObjType
from lightllm.server.httpserver.async_queue import AsyncQueue
from lightllm.utils.net_utils import get_hostname_ip
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_lightllm_websocket_max_message_size
from lightllm.server.httpserver.manager import HttpServerManager
from ..pd_io_struct import PD_Master_Obj
from lightllm.server.core.objs import StartArgs
from lightllm.server.core.objs import SamplingParams
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
# Module-level logger used by the functions below.
logger = init_logger(__name__)
async def timer_log(manager: HttpServerManager):
    """Periodically print the mean first-token cost and mean per-token cost.

    Loops forever, waking every 30 seconds; meant to run as a background task.

    Args:
        manager: object exposing ``first_time_costs`` and ``per_token_costs``,
            each with a ``print_log(label)`` method.
    """
    interval_seconds = 30
    while True:
        await asyncio.sleep(interval_seconds)
        first_cost_stats = manager.first_time_costs
        first_cost_stats.print_log("mean first cost")
        token_cost_stats = manager.per_token_costs
        token_cost_stats.print_log("mean per token cost")
async def pd_handle_loop(manager: HttpServerManager):
    """Discover pd_master nodes and keep one handler task running per node.

    Resolves ``manager.host_ip`` from ``manager.args.host`` (``0.0.0.0`` is
    replaced by ``get_hostname_ip()``) and starts ``timer_log`` in the background.
    It then loops forever. Every 30 seconds it fetches the current pd_master
    objects, cancels the handler task of every node that is no longer listed,
    and starts ``_pd_handle_task`` for every newly listed node. Errors are
    logged and the loop retries after 10 seconds.

    Args:
        manager: the HTTP server manager; ``manager.host_ip`` is set here.
    """
    assert manager.args.host not in ["127.0.0.1", "localhost"], "pd mode must specify host ip"
    if manager.args.host in ["0.0.0.0"]:
        manager.host_ip = get_hostname_ip()
    else:
        manager.host_ip = manager.args.host

    # Keep a strong reference: the event loop only holds weak references to tasks,
    # so an unreferenced task may be garbage-collected while still running.
    timer_log_task = asyncio.create_task(timer_log(manager))  # noqa: F841

    id_to_handle_task: Dict[int, asyncio.Task] = {}
    while True:
        try:
            id_to_pd_master_obj = await _get_pd_master_objs(manager.args)
            logger.info(f"get pd_master_objs {id_to_pd_master_obj}")
            if id_to_pd_master_obj is not None:
                # Iterate over a snapshot: entries are popped inside the loop, and
                # mutating a dict while iterating over it raises RuntimeError.
                for node_id, handle_task in list(id_to_handle_task.items()):
                    if node_id not in id_to_pd_master_obj:
                        handle_task.cancel()
                        id_to_handle_task.pop(node_id, None)
                        logger.info(f"pd_handle_task {handle_task} cancelled")
                for node_id, pd_master_obj in id_to_pd_master_obj.items():
                    if node_id not in id_to_handle_task:
                        id_to_handle_task[node_id] = asyncio.create_task(_pd_handle_task(manager, pd_master_obj))
            await asyncio.sleep(30)
        except Exception as e:
            logger.exception(str(e))
            await asyncio.sleep(10)
async def _delayed_abort_task(manager: HttpServerManager, group_req_id, retry_count: int):
    """Retry ``manager.abort(group_req_id)`` up to ``retry_count`` times, 5 seconds apart.

    Stops at the first attempt whose result is truthy.
    """
    for _ in range(retry_count):
        await asyncio.sleep(5.0)
        if await manager.abort(group_req_id):
            break


async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_Obj):
    """Register with one pd_master, serve its requests, and forward results back.

    Connects to ``ws://{pd_master_obj.host_ip_port}/pd_register`` and sends this
    node's registration as a JSON text frame. It then receives pickled messages
    from the pd_master:

    * ``ObjType.REQ``: start ``_pd_process_generate`` for the request.
    * ``ObjType.ABORT``: abort the group request; if the first attempt fails,
      retry in the background (4 attempts, 5 seconds apart).
    * ``ObjType.NIXL_REQ_DECODE_NODE_INFO``: attach the decode-node info to the
      waiting request's ``asyncio.Event`` and set it.

    Generated outputs are pushed back by ``_up_tokens_to_pd_master`` through a
    forwarding queue. On any error the uploader is cancelled and the connection
    is retried after 10 seconds. On cancellation the function returns.
    """
    # Queue of generated outputs consumed by the uploader task; reused across reconnects.
    forwarding_queue = AsyncQueue()
    # Strong references to fire-and-forget tasks: the event loop keeps only weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    background_tasks = set()

    def _spawn(coro) -> asyncio.Task:
        task = asyncio.create_task(coro)
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return task

    while True:
        forwarding_tokens_task = None
        try:
            uri = f"ws://{pd_master_obj.host_ip_port}/pd_register"
            async with websockets.connect(
                uri,
                max_size=get_lightllm_websocket_max_message_size(),
                # Key setting: (high, low) water marks of the incoming frame buffer.
                max_queue=(2048 * 1024, 2048 * 1023),
            ) as websocket:
                # Disable Nagle's algorithm so small packets are sent immediately.
                sock = websocket.transport.get_extra_info("socket")
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # NOTE(review): vars() returns the live __dict__ of manager.args, so the
                # assignment below also rewrites manager.args.host in place.
                args_dict = vars(manager.args)
                args_dict["host"] = manager.host_ip
                # Send the registration message.
                regist_json = {
                    "node_id": manager.args.pd_node_id,
                    "client_ip_port": f"{manager.host_ip}:{manager.args.port}",
                    "mode": manager.pd_mode.value,
                    "start_args": args_dict,
                }
                await websocket.send(json.dumps(regist_json))
                logger.info(f"Sent registration JSON: {regist_json}")
                # Uploader task: forwards queued outputs to the pd_master.
                forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))
                # Weak values: an entry disappears once nothing else references its event.
                group_req_id_to_event: Dict[int, asyncio.Event] = weakref.WeakValueDictionary()
                # Receive requests from the pd_master and run inference; the uploader
                # task sends the generated tokens back.
                while True:
                    recv_bytes = await websocket.recv()
                    # SECURITY: pickle.loads can execute arbitrary code on crafted input,
                    # so the pd_master endpoint must be trusted (private network only).
                    obj = pickle.loads(recv_bytes)
                    if obj[0] == ObjType.REQ:
                        prompt, sampling_params, multimodal_params = obj[1]
                        group_req_id = sampling_params.group_request_id
                        nixl_pd_event = asyncio.Event()
                        group_req_id_to_event[group_req_id] = nixl_pd_event
                        _spawn(
                            _pd_process_generate(
                                manager=manager,
                                prompt=prompt,
                                sampling_params=sampling_params,
                                multimodal_params=multimodal_params,
                                forwarding_queue=forwarding_queue,
                                nixl_pd_upload_websocket=websocket,
                                nixl_pd_event=nixl_pd_event,
                            )
                        )
                    elif obj[0] == ObjType.ABORT:
                        group_req_id = obj[1]
                        logger.warning(f"recv cmd aborted req id {group_req_id}")
                        if not (await manager.abort(group_req_id)):
                            _spawn(_delayed_abort_task(manager, group_req_id=group_req_id, retry_count=4))
                    elif obj[0] == ObjType.NIXL_REQ_DECODE_NODE_INFO:
                        _, group_req_id, decode_node_info = obj
                        nixl_pd_event = group_req_id_to_event.pop(group_req_id, None)
                        if nixl_pd_event is None:
                            logger.error(f"error in find nixl_pd_event, info: {obj}")
                            continue
                        # The event doubles as a carrier for the decode-node info.
                        nixl_pd_event.decode_node_info = decode_node_info
                        nixl_pd_event.set()
                    else:
                        logger.error(f"recevie error obj {str(obj)}")
        except asyncio.CancelledError:
            # The task was cancelled (e.g. the pd_master was removed): stop for good.
            logger.warning(f"forwarding_tokens_task {pd_master_obj} cancelled")
            if forwarding_tokens_task is not None:
                forwarding_tokens_task.cancel()
            return
        except Exception as e:
            logger.error("connetion to pd_master has error")
            logger.exception(str(e))
            if forwarding_tokens_task is not None:
                forwarding_tokens_task.cancel()
            await asyncio.sleep(10)
            # Drain outputs queued for the dead connection before reconnecting.
            await forwarding_queue.get_all_data()
            logger.info("reconnection to pd_master")
async def _get_pd_master_objs(args: StartArgs) -> Optional[Dict[int, PD_Master_Obj]]:
    """Return every known pd_master object keyed by node id, or None on failure.

    If no config server is configured (``config_server_host`` or
    ``config_server_port`` unset), the single pd_master given by
    ``--pd_master_ip``/``--pd_master_port`` is returned under node id 0.

    Otherwise the mapping is fetched from the config server's
    ``/registered_objects`` HTTP endpoint. Its JSON ``data`` field is expected
    to be a base64-encoded pickle of ``Dict[int, PD_Master_Obj]``.

    Returns:
        The node-id -> PD_Master_Obj mapping. None is returned when the config
        server answers with a non-200 status. None is also returned when any
        exception occurs, after logging it and sleeping 10 seconds.
    """
    use_config_server = args.config_server_host and args.config_server_port
    # Without a config server, connect to the one pd_master set by the
    # --pd_master_ip / --pd_master_port startup args, with default node_id 0.
    if not use_config_server:
        ans = dict()
        ans[0] = PD_Master_Obj(node_id=0, host_ip_port=f"{args.pd_master_ip}:{args.pd_master_port}")
        return ans

    # Discover all pd_master nodes via the config server. httpx only speaks
    # HTTP(S); a ws:// URL would be rejected as an unsupported protocol.
    uri = f"http://{args.config_server_host}:{args.config_server_port}/registered_objects"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(uri)
            if response.status_code == 200:
                base64data = response.json()["data"]
                # SECURITY: pickle.loads can execute arbitrary code on crafted input,
                # so the config server must be trusted (private network only).
                # JSON could not yield the declared int keys / PD_Master_Obj values.
                id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data))
                return id_to_pd_master_obj
            else:
                logger.error(f"get pd_master_objs error {response.status_code}")
                return None
    except Exception as e:
        logger.exception(str(e))
        await asyncio.sleep(10)
        return None
# Task that runs inference for one request received from the pd_master.
async def _pd_process_generate(
    manager: HttpServerManager,
    prompt: Union[str, List[int]],
    sampling_params: SamplingParams,
    multimodal_params: Dict,
    forwarding_queue: AsyncQueue,
    nixl_pd_upload_websocket: ClientConnection,
    nixl_pd_event: asyncio.Event,
):
    """Run ``manager.generate`` for one request and queue its outputs for upload.

    Each ``(sub_req_id, request_output, metadata, finish_status)`` item from
    the generator is tagged with ``metadata["node_mode"] = manager.args.run_mode``
    and put on ``forwarding_queue``. Items with a negative ``sub_req_id`` are
    health-check probes and are not forwarded.

    ``NixlPrefillNodeStopGenToken`` ends generation with an info log. Any other
    ``Exception`` is logged with its traceback. Cancellation and other
    ``BaseException`` subclasses propagate instead of being swallowed.
    """
    try:
        async for sub_req_id, request_output, metadata, finish_status in manager.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            multimodal_params=multimodal_params,
            request=None,
            nixl_pd_upload_websocket=nixl_pd_upload_websocket,
            nixl_pd_event=nixl_pd_event,
        ):
            # In PD mode, token data goes to the forwarding queue; requests with an
            # id < 0 are health probes and are not forwarded.
            is_health_check_req = sub_req_id < 0
            if is_health_check_req:
                continue
            metadata["node_mode"] = manager.args.run_mode
            await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status))
    except NixlPrefillNodeStopGenToken as e:
        logger.info(f"nixl prefill node stop gen token for group_request_id {e.group_request_id}")
    except Exception as e:
        # Catch Exception, not BaseException, so task cancellation still propagates.
        logger.exception(str(e))
# Task that forwards generated tokens to the pd_master.
async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket: ClientConnection):
    """Forever drain ``forwarding_queue`` and push each batch to the pd_master.

    Each non-empty batch is sent as one binary frame containing the pickled
    tuple ``(ObjType.TOKEN_PACKS, handle_list, load_info)``. ``load_info``
    comes from ``_get_load_info``. Pickle matches the receive side of this
    connection and can carry the enum and project objects inside the batch.
    """
    while True:
        handle_list = await forwarding_queue.wait_to_get_all_data()
        if handle_list:
            load_info: dict = _get_load_info()
            await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list, load_info)))
# Collect this node's load information.
def _get_load_info() -> dict:
    """Build the load report that accompanies each token pack sent to the pd_master.

    Averages ``shared_token_load.get_dynamic_max_load(i)`` over the
    ``max(1, dp // nnodes)`` dp ranks on this node. The per-rank value is the
    current token-capacity usage rate, used by the pd_master for scheduling.

    Returns:
        ``{"total_token_usage_rate": <mean usage>, "client_ip_port": "<host_ip>:<port>"}``

    Raises:
        AssertionError: if ``g_objs.shared_token_load`` has not been initialized.
    """
    from lightllm.server.api_http import g_objs

    assert g_objs.shared_token_load is not None, "shared_token_load is not initialized"
    start_args = g_objs.args
    dp_count = max(1, start_args.dp // start_args.nnodes)
    token_load = g_objs.shared_token_load
    usage_rates = [float(token_load.get_dynamic_max_load(rank)) for rank in range(dp_count)]
    average_usage = sum(usage_rates) / len(usage_rates)
    return {
        "total_token_usage_rate": average_usage,
        "client_ip_port": f"{g_objs.httpserver_manager.host_ip}:{start_args.port}",
    }