Skip to content

Commit 01ec36b

Browse files
committed
fix
1 parent 43b0487 commit 01ec36b

5 files changed

Lines changed: 174 additions & 128 deletions

File tree

lightllm/common/kv_trans_kernel/nixl_kv_trans.py

Lines changed: 142 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
@triton.jit
1111
def _page_io(
1212
mem_index_ptr,
13+
token_num,
14+
page_write_head_num,
1315
k_page_ptr,
1416
k_page_stride_size,
1517
k_page_stride_layer_num,
@@ -45,88 +47,91 @@ def _page_io(
4547
k_stride_size = tl.cast(k_stride_size, dtype=tl.int64)
4648
v_stride_size = tl.cast(v_stride_size, dtype=tl.int64)
4749

48-
tid = tl.program_id(0)
49-
kv_head_id = tl.program_id(1)
50-
page_head_id = page_head_start + kv_head_id
50+
start_index = tl.program_id(0)
51+
grid_num = tl.num_programs(0)
5152

52-
mem_index = tl.load(mem_index_ptr + tid)
53-
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
54-
if NEED_MASK:
55-
mask = off_dim < head_dim
56-
else:
57-
mask = None
53+
for tid in tl.range(start_index, token_num, step=grid_num):
54+
for kv_head_id in tl.range(page_write_head_num):
5855

59-
for layer_index in tl.range(layer_num, num_stages=3):
60-
if IS_WRITE:
61-
k_tensor = tl.load(
62-
k_ptr
63-
+ layer_index * k_stride_layer_num
64-
+ mem_index * k_stride_size
65-
+ kv_head_id * k_stride_head
66-
+ off_dim * k_stride_dim,
67-
mask=mask,
68-
)
69-
v_tensor = tl.load(
70-
v_ptr
71-
+ layer_index * v_stride_layer_num
72-
+ mem_index * v_stride_size
73-
+ kv_head_id * v_stride_head
74-
+ off_dim * v_stride_dim,
75-
mask=mask,
76-
)
77-
tl.store(
78-
k_page_ptr
79-
+ tid * k_page_stride_size
80-
+ layer_index * k_page_stride_layer_num
81-
+ page_head_id * k_page_stride_head
82-
+ off_dim * k_page_stride_dim,
83-
k_tensor,
84-
mask=mask,
85-
)
86-
tl.store(
87-
v_page_ptr
88-
+ tid * v_page_stride_size
89-
+ layer_index * v_page_stride_layer_num
90-
+ page_head_id * v_page_stride_head
91-
+ off_dim * v_page_stride_dim,
92-
v_tensor,
93-
mask=mask,
94-
)
95-
else:
96-
k_page_tensor = tl.load(
97-
k_page_ptr
98-
+ tid * k_page_stride_size
99-
+ layer_index * k_page_stride_layer_num
100-
+ page_head_id * k_page_stride_head
101-
+ off_dim * k_page_stride_dim,
102-
mask=mask,
103-
)
104-
v_page_tensor = tl.load(
105-
v_page_ptr
106-
+ tid * v_page_stride_size
107-
+ layer_index * v_page_stride_layer_num
108-
+ page_head_id * v_page_stride_head
109-
+ off_dim * v_page_stride_dim,
110-
mask=mask,
111-
)
112-
tl.store(
113-
k_ptr
114-
+ layer_index * k_stride_layer_num
115-
+ mem_index * k_stride_size
116-
+ kv_head_id * k_stride_head
117-
+ off_dim * k_stride_dim,
118-
k_page_tensor,
119-
mask=mask,
120-
)
121-
tl.store(
122-
v_ptr
123-
+ layer_index * v_stride_layer_num
124-
+ mem_index * v_stride_size
125-
+ kv_head_id * v_stride_head
126-
+ off_dim * v_stride_dim,
127-
v_page_tensor,
128-
mask=mask,
129-
)
56+
page_head_id = page_head_start + kv_head_id
57+
mem_index = tl.load(mem_index_ptr + tid)
58+
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
59+
if NEED_MASK:
60+
mask = off_dim < head_dim
61+
else:
62+
mask = None
63+
64+
for layer_index in tl.range(layer_num, num_stages=3):
65+
if IS_WRITE:
66+
k_tensor = tl.load(
67+
k_ptr
68+
+ layer_index * k_stride_layer_num
69+
+ mem_index * k_stride_size
70+
+ kv_head_id * k_stride_head
71+
+ off_dim,
72+
mask=mask,
73+
)
74+
v_tensor = tl.load(
75+
v_ptr
76+
+ layer_index * v_stride_layer_num
77+
+ mem_index * v_stride_size
78+
+ kv_head_id * v_stride_head
79+
+ off_dim,
80+
mask=mask,
81+
)
82+
tl.store(
83+
k_page_ptr
84+
+ tid * k_page_stride_size
85+
+ layer_index * k_page_stride_layer_num
86+
+ page_head_id * k_page_stride_head
87+
+ off_dim,
88+
k_tensor,
89+
mask=mask,
90+
)
91+
tl.store(
92+
v_page_ptr
93+
+ tid * v_page_stride_size
94+
+ layer_index * v_page_stride_layer_num
95+
+ page_head_id * v_page_stride_head
96+
+ off_dim,
97+
v_tensor,
98+
mask=mask,
99+
)
100+
else:
101+
k_page_tensor = tl.load(
102+
k_page_ptr
103+
+ tid * k_page_stride_size
104+
+ layer_index * k_page_stride_layer_num
105+
+ page_head_id * k_page_stride_head
106+
+ off_dim,
107+
mask=mask,
108+
)
109+
v_page_tensor = tl.load(
110+
v_page_ptr
111+
+ tid * v_page_stride_size
112+
+ layer_index * v_page_stride_layer_num
113+
+ page_head_id * v_page_stride_head
114+
+ off_dim,
115+
mask=mask,
116+
)
117+
tl.store(
118+
k_ptr
119+
+ layer_index * k_stride_layer_num
120+
+ mem_index * k_stride_size
121+
+ kv_head_id * k_stride_head
122+
+ off_dim,
123+
k_page_tensor,
124+
mask=mask,
125+
)
126+
tl.store(
127+
v_ptr
128+
+ layer_index * v_stride_layer_num
129+
+ mem_index * v_stride_size
130+
+ kv_head_id * v_stride_head
131+
+ off_dim,
132+
v_page_tensor,
133+
mask=mask,
134+
)
130135
return
131136

132137

@@ -169,10 +174,17 @@ def page_io(
169174
page_head_start = tp_index * (page_write_head_num)
170175

171176
token_num = len(mem_indexes)
172-
grid = (token_num, page_write_head_num)
177+
grid = (128,)
178+
179+
assert k_page_tensor.stride(3) == 1
180+
assert v_page_tensor.stride(3) == 1
181+
assert k_buffer.stride(3) == 1
182+
assert v_buffer.stride(3) == 1
173183

174184
_page_io[grid](
175185
mem_index_ptr=mem_indexes,
186+
token_num=token_num,
187+
page_write_head_num=page_write_head_num,
176188
k_page_ptr=k_page_tensor,
177189
k_page_stride_size=k_page_tensor.stride(0),
178190
k_page_stride_layer_num=k_page_tensor.stride(1),
@@ -207,6 +219,7 @@ def page_io(
207219
@triton.jit
208220
def _mla_page_io(
209221
mem_index_ptr,
222+
token_num,
210223
page_ptr,
211224
page_stride_size,
212225
page_stride_layer_num,
@@ -227,52 +240,54 @@ def _mla_page_io(
227240
kv_stride_layer_num = tl.cast(kv_stride_layer_num, dtype=tl.int64)
228241
kv_stride_size = tl.cast(kv_stride_size, dtype=tl.int64)
229242

230-
tid = tl.program_id(0)
243+
start_index = tl.program_id(0)
244+
grid_num = tl.num_programs(0)
231245

232-
mem_index = tl.load(mem_index_ptr + tid)
233-
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
234-
if NEED_MASK:
235-
mask = off_dim < head_dim
236-
else:
237-
mask = None
238-
239-
for layer_index in tl.range(layer_num, num_stages=3):
240-
if IS_WRITE:
241-
kv_tensor = tl.load(
242-
kv_ptr
243-
+ layer_index * kv_stride_layer_num
244-
+ mem_index * kv_stride_size
245-
+ 0 * kv_stride_head
246-
+ off_dim * kv_stride_dim,
247-
mask=mask,
248-
)
249-
tl.store(
250-
page_ptr
251-
+ tid * page_stride_size
252-
+ layer_index * page_stride_layer_num
253-
+ 0 * page_stride_head
254-
+ off_dim * page_stride_dim,
255-
kv_tensor,
256-
mask=mask,
257-
)
246+
for tid in tl.range(start_index, token_num, step=grid_num):
247+
mem_index = tl.load(mem_index_ptr + tid)
248+
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
249+
if NEED_MASK:
250+
mask = off_dim < head_dim
258251
else:
259-
page_tensor = tl.load(
260-
page_ptr
261-
+ tid * page_stride_size
262-
+ layer_index * page_stride_layer_num
263-
+ 0 * page_stride_head
264-
+ off_dim * page_stride_dim,
265-
mask=mask,
266-
)
267-
tl.store(
268-
kv_ptr
269-
+ layer_index * kv_stride_layer_num
270-
+ mem_index * kv_stride_size
271-
+ 0 * kv_stride_head
272-
+ off_dim * kv_stride_dim,
273-
page_tensor,
274-
mask=mask,
275-
)
252+
mask = None
253+
254+
for layer_index in tl.range(layer_num, num_stages=3):
255+
if IS_WRITE:
256+
kv_tensor = tl.load(
257+
kv_ptr
258+
+ layer_index * kv_stride_layer_num
259+
+ mem_index * kv_stride_size
260+
+ 0 * kv_stride_head
261+
+ off_dim * kv_stride_dim,
262+
mask=mask,
263+
)
264+
tl.store(
265+
page_ptr
266+
+ tid * page_stride_size
267+
+ layer_index * page_stride_layer_num
268+
+ 0 * page_stride_head
269+
+ off_dim * page_stride_dim,
270+
kv_tensor,
271+
mask=mask,
272+
)
273+
else:
274+
page_tensor = tl.load(
275+
page_ptr
276+
+ tid * page_stride_size
277+
+ layer_index * page_stride_layer_num
278+
+ 0 * page_stride_head
279+
+ off_dim * page_stride_dim,
280+
mask=mask,
281+
)
282+
tl.store(
283+
kv_ptr
284+
+ layer_index * kv_stride_layer_num
285+
+ mem_index * kv_stride_size
286+
+ 0 * kv_stride_head
287+
+ off_dim * kv_stride_dim,
288+
page_tensor,
289+
mask=mask,
290+
)
276291
return
277292

278293

@@ -290,10 +305,11 @@ def mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
290305
assert page_head_num == kv_head_num == 1
291306

292307
token_num = len(mem_indexes)
293-
grid = (token_num,)
308+
grid = (64,)
294309

295310
_mla_page_io[grid](
296311
mem_index_ptr=mem_indexes,
312+
token_num=token_num,
297313
page_ptr=page_tensor,
298314
page_stride_size=page_tensor.stride(0),
299315
page_stride_layer_num=page_tensor.stride(1),

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _create_nixl_trans_task(
185185
request_id=req_obj.req_id,
186186
start_kv_index=kv_start_index,
187187
end_kv_index=kv_end_index,
188-
time_out_secs=80,
188+
time_out_secs=180,
189189
pd_master_node_id=req_obj.sampling_param.pd_master_node_id,
190190
prefill_dp_index=None,
191191
decode_dp_index=self.dp_rank_in_node,

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@ def _init_env(
4949
task_out_queue: mp.Queue,
5050
up_status_in_queue: Optional[mp.SimpleQueue],
5151
):
52+
import os
53+
54+
# -------------------------------------------------------------------------
55+
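# Background (PD NIXL with multiple processes sharing one GPU):
56+
# Each physical decode GPU hosts at least two independent CUDA processes: model_infer (decode inference) and
57+
# nixl_decode_trans (which copies KV pages from the prefill side into the decode KV cache).
58+
# With lm_eval at batch=64, a large number of read_page calls are issued concurrently in a short window; the copies queue on copy_cuda_stream
59+
# while inference runs on a stream owned by another process, so the two cannot be coordinated via cudaStreamWaitEvent.
60+
# The read_page_gpu_time in the logs (a CUDA event delta) therefore also counts "waiting for a GPU time slice /
61+
# contending with inference for SMs", which shows up as spikes of tens of seconds but does not mean a single memcpy is really that slow.
62+
#
63+
# Approach: rely on NVIDIA MPS (Multi-Process Service) so that multiple processes on the same GPU
64+
# share a context and are scheduled at client level; set the environment variable below in the child process
65+
# **before** it imports torch / creates its CUDA context (hence it must be the first thing in this function).
66+
#
67+
# CUDA_MPS_CLIENT_PRIORITY="0":
68+
# Under MPS, a smaller value means a higher priority. The decode-side KV copy is on the decode critical path (the KV must
69+
# be written into the decode KV cache before the first token can be produced), so the trans process is given the highest priority to reduce
70+
# the queueing amplification caused by being starved by inference on the same GPU. The cluster must already be running nvidia-cuda-mps-control / mps-server,
71+
# otherwise this variable has no effect. The command to start MPS is: nvidia-cuda-mps-control -d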
72+
# -------------------------------------------------------------------------
73+
os.environ["CUDA_MPS_CLIENT_PRIORITY"] = "0"
74+
5275
torch.backends.cudnn.enabled = False
5376
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_decode_trans:Device{device_id}")
5477

lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _create_nixl_trans_task(
102102
request_id=req_obj.req_id,
103103
start_kv_index=kv_start_index,
104104
end_kv_index=kv_end_index,
105-
time_out_secs=82,
105+
time_out_secs=182,
106106
pd_master_node_id=req_obj.sampling_param.pd_master_node_id,
107107
prefill_dp_index=self.dp_rank_in_node,
108108
decode_dp_index=None,

lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def _init_env(
4242
task_in_queue: mp.Queue,
4343
task_out_queue: mp.Queue,
4444
):
45+
46+
import os
47+
48+
# The prefill node does not necessarily need MPS for coordination, so its priority is set to 1.
49+
# It does not cause serious blocking by itself.
50+
os.environ["CUDA_MPS_CLIENT_PRIORITY"] = "1"
51+
4552
torch.backends.cudnn.enabled = False
4653
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::nixl_prefill_trans:Device{device_id}")
4754

0 commit comments

Comments
 (0)