Skip to content

Commit 3863844

Browse files
authored
pd nixl upgrade write mode to transfer kv (#1324)
1 parent 105d57f commit 3863844

13 files changed

Lines changed: 729 additions & 360 deletions

File tree

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
44
ARG PYTHON_VERSION=3.10
55
ARG MAMBA_VERSION=24.7.1-0
66
ARG VLLM_VERSION=0.21.0
7-
ARG NIXL_REF=v1.1.0
7+
ARG NIXL_REF=v1.2.0
88
ARG FLASH_MLA_REF=47c35a7
99
ARG DEEPGEMM_REF=891d57b4db1071624b5c8fa0d1e51cb317fa709f
1010
ARG TARGETPLATFORM

lightllm/common/kv_trans_kernel/nixl_kv_trans.py

Lines changed: 142 additions & 126 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
@triton.jit
1111
def _page_io(
1212
mem_index_ptr,
13+
token_num,
14+
page_write_head_num,
1315
k_page_ptr,
1416
k_page_stride_size,
1517
k_page_stride_layer_num,
@@ -45,88 +47,91 @@ def _page_io(
4547
k_stride_size = tl.cast(k_stride_size, dtype=tl.int64)
4648
v_stride_size = tl.cast(v_stride_size, dtype=tl.int64)
4749

48-
tid = tl.program_id(0)
49-
kv_head_id = tl.program_id(1)
50-
page_head_id = page_head_start + kv_head_id
50+
start_index = tl.program_id(0)
51+
grid_num = tl.num_programs(0)
5152

52-
mem_index = tl.load(mem_index_ptr + tid)
53-
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
54-
if NEED_MASK:
55-
mask = off_dim < head_dim
56-
else:
57-
mask = None
53+
for tid in tl.range(start_index, token_num, step=grid_num):
54+
for kv_head_id in tl.range(page_write_head_num):
5855

59-
for layer_index in tl.range(layer_num, num_stages=3):
60-
if IS_WRITE:
61-
k_tensor = tl.load(
62-
k_ptr
63-
+ layer_index * k_stride_layer_num
64-
+ mem_index * k_stride_size
65-
+ kv_head_id * k_stride_head
66-
+ off_dim * k_stride_dim,
67-
mask=mask,
68-
)
69-
v_tensor = tl.load(
70-
v_ptr
71-
+ layer_index * v_stride_layer_num
72-
+ mem_index * v_stride_size
73-
+ kv_head_id * v_stride_head
74-
+ off_dim * v_stride_dim,
75-
mask=mask,
76-
)
77-
tl.store(
78-
k_page_ptr
79-
+ tid * k_page_stride_size
80-
+ layer_index * k_page_stride_layer_num
81-
+ page_head_id * k_page_stride_head
82-
+ off_dim * k_page_stride_dim,
83-
k_tensor,
84-
mask=mask,
85-
)
86-
tl.store(
87-
v_page_ptr
88-
+ tid * v_page_stride_size
89-
+ layer_index * v_page_stride_layer_num
90-
+ page_head_id * v_page_stride_head
91-
+ off_dim * v_page_stride_dim,
92-
v_tensor,
93-
mask=mask,
94-
)
95-
else:
96-
k_page_tensor = tl.load(
97-
k_page_ptr
98-
+ tid * k_page_stride_size
99-
+ layer_index * k_page_stride_layer_num
100-
+ page_head_id * k_page_stride_head
101-
+ off_dim * k_page_stride_dim,
102-
mask=mask,
103-
)
104-
v_page_tensor = tl.load(
105-
v_page_ptr
106-
+ tid * v_page_stride_size
107-
+ layer_index * v_page_stride_layer_num
108-
+ page_head_id * v_page_stride_head
109-
+ off_dim * v_page_stride_dim,
110-
mask=mask,
111-
)
112-
tl.store(
113-
k_ptr
114-
+ layer_index * k_stride_layer_num
115-
+ mem_index * k_stride_size
116-
+ kv_head_id * k_stride_head
117-
+ off_dim * k_stride_dim,
118-
k_page_tensor,
119-
mask=mask,
120-
)
121-
tl.store(
122-
v_ptr
123-
+ layer_index * v_stride_layer_num
124-
+ mem_index * v_stride_size
125-
+ kv_head_id * v_stride_head
126-
+ off_dim * v_stride_dim,
127-
v_page_tensor,
128-
mask=mask,
129-
)
56+
page_head_id = page_head_start + kv_head_id
57+
mem_index = tl.load(mem_index_ptr + tid)
58+
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
59+
if NEED_MASK:
60+
mask = off_dim < head_dim
61+
else:
62+
mask = None
63+
64+
for layer_index in tl.range(layer_num, num_stages=3):
65+
if IS_WRITE:
66+
k_tensor = tl.load(
67+
k_ptr
68+
+ layer_index * k_stride_layer_num
69+
+ mem_index * k_stride_size
70+
+ kv_head_id * k_stride_head
71+
+ off_dim,
72+
mask=mask,
73+
)
74+
v_tensor = tl.load(
75+
v_ptr
76+
+ layer_index * v_stride_layer_num
77+
+ mem_index * v_stride_size
78+
+ kv_head_id * v_stride_head
79+
+ off_dim,
80+
mask=mask,
81+
)
82+
tl.store(
83+
k_page_ptr
84+
+ tid * k_page_stride_size
85+
+ layer_index * k_page_stride_layer_num
86+
+ page_head_id * k_page_stride_head
87+
+ off_dim,
88+
k_tensor,
89+
mask=mask,
90+
)
91+
tl.store(
92+
v_page_ptr
93+
+ tid * v_page_stride_size
94+
+ layer_index * v_page_stride_layer_num
95+
+ page_head_id * v_page_stride_head
96+
+ off_dim,
97+
v_tensor,
98+
mask=mask,
99+
)
100+
else:
101+
k_page_tensor = tl.load(
102+
k_page_ptr
103+
+ tid * k_page_stride_size
104+
+ layer_index * k_page_stride_layer_num
105+
+ page_head_id * k_page_stride_head
106+
+ off_dim,
107+
mask=mask,
108+
)
109+
v_page_tensor = tl.load(
110+
v_page_ptr
111+
+ tid * v_page_stride_size
112+
+ layer_index * v_page_stride_layer_num
113+
+ page_head_id * v_page_stride_head
114+
+ off_dim,
115+
mask=mask,
116+
)
117+
tl.store(
118+
k_ptr
119+
+ layer_index * k_stride_layer_num
120+
+ mem_index * k_stride_size
121+
+ kv_head_id * k_stride_head
122+
+ off_dim,
123+
k_page_tensor,
124+
mask=mask,
125+
)
126+
tl.store(
127+
v_ptr
128+
+ layer_index * v_stride_layer_num
129+
+ mem_index * v_stride_size
130+
+ kv_head_id * v_stride_head
131+
+ off_dim,
132+
v_page_tensor,
133+
mask=mask,
134+
)
130135
return
131136

132137

@@ -169,10 +174,17 @@ def page_io(
169174
page_head_start = tp_index * (page_write_head_num)
170175

171176
token_num = len(mem_indexes)
172-
grid = (token_num, page_write_head_num)
177+
grid = (128,)
178+
179+
assert k_page_tensor.stride(3) == 1
180+
assert v_page_tensor.stride(3) == 1
181+
assert k_buffer.stride(3) == 1
182+
assert v_buffer.stride(3) == 1
173183

174184
_page_io[grid](
175185
mem_index_ptr=mem_indexes,
186+
token_num=token_num,
187+
page_write_head_num=page_write_head_num,
176188
k_page_ptr=k_page_tensor,
177189
k_page_stride_size=k_page_tensor.stride(0),
178190
k_page_stride_layer_num=k_page_tensor.stride(1),
@@ -207,6 +219,7 @@ def page_io(
207219
@triton.jit
208220
def _mla_page_io(
209221
mem_index_ptr,
222+
token_num,
210223
page_ptr,
211224
page_stride_size,
212225
page_stride_layer_num,
@@ -227,52 +240,54 @@ def _mla_page_io(
227240
kv_stride_layer_num = tl.cast(kv_stride_layer_num, dtype=tl.int64)
228241
kv_stride_size = tl.cast(kv_stride_size, dtype=tl.int64)
229242

230-
tid = tl.program_id(0)
243+
start_index = tl.program_id(0)
244+
grid_num = tl.num_programs(0)
231245

232-
mem_index = tl.load(mem_index_ptr + tid)
233-
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
234-
if NEED_MASK:
235-
mask = off_dim < head_dim
236-
else:
237-
mask = None
238-
239-
for layer_index in tl.range(layer_num, num_stages=3):
240-
if IS_WRITE:
241-
kv_tensor = tl.load(
242-
kv_ptr
243-
+ layer_index * kv_stride_layer_num
244-
+ mem_index * kv_stride_size
245-
+ 0 * kv_stride_head
246-
+ off_dim * kv_stride_dim,
247-
mask=mask,
248-
)
249-
tl.store(
250-
page_ptr
251-
+ tid * page_stride_size
252-
+ layer_index * page_stride_layer_num
253-
+ 0 * page_stride_head
254-
+ off_dim * page_stride_dim,
255-
kv_tensor,
256-
mask=mask,
257-
)
246+
for tid in tl.range(start_index, token_num, step=grid_num):
247+
mem_index = tl.load(mem_index_ptr + tid)
248+
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
249+
if NEED_MASK:
250+
mask = off_dim < head_dim
258251
else:
259-
page_tensor = tl.load(
260-
page_ptr
261-
+ tid * page_stride_size
262-
+ layer_index * page_stride_layer_num
263-
+ 0 * page_stride_head
264-
+ off_dim * page_stride_dim,
265-
mask=mask,
266-
)
267-
tl.store(
268-
kv_ptr
269-
+ layer_index * kv_stride_layer_num
270-
+ mem_index * kv_stride_size
271-
+ 0 * kv_stride_head
272-
+ off_dim * kv_stride_dim,
273-
page_tensor,
274-
mask=mask,
275-
)
252+
mask = None
253+
254+
for layer_index in tl.range(layer_num, num_stages=3):
255+
if IS_WRITE:
256+
kv_tensor = tl.load(
257+
kv_ptr
258+
+ layer_index * kv_stride_layer_num
259+
+ mem_index * kv_stride_size
260+
+ 0 * kv_stride_head
261+
+ off_dim * kv_stride_dim,
262+
mask=mask,
263+
)
264+
tl.store(
265+
page_ptr
266+
+ tid * page_stride_size
267+
+ layer_index * page_stride_layer_num
268+
+ 0 * page_stride_head
269+
+ off_dim * page_stride_dim,
270+
kv_tensor,
271+
mask=mask,
272+
)
273+
else:
274+
page_tensor = tl.load(
275+
page_ptr
276+
+ tid * page_stride_size
277+
+ layer_index * page_stride_layer_num
278+
+ 0 * page_stride_head
279+
+ off_dim * page_stride_dim,
280+
mask=mask,
281+
)
282+
tl.store(
283+
kv_ptr
284+
+ layer_index * kv_stride_layer_num
285+
+ mem_index * kv_stride_size
286+
+ 0 * kv_stride_head
287+
+ off_dim * kv_stride_dim,
288+
page_tensor,
289+
mask=mask,
290+
)
276291
return
277292

278293

@@ -290,10 +305,11 @@ def mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
290305
assert page_head_num == kv_head_num == 1
291306

292307
token_num = len(mem_indexes)
293-
grid = (token_num,)
308+
grid = (64,)
294309

295310
_mla_page_io[grid](
296311
mem_index_ptr=mem_indexes,
312+
token_num=token_num,
297313
page_ptr=page_tensor,
298314
page_stride_size=page_tensor.stride(0),
299315
page_stride_layer_num=page_tensor.stride(1),

lightllm/server/httpserver/manager.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -380,9 +380,9 @@ async def generate(
380380
pickle.dumps((ObjType.NIXL_UPLOAD_NP_PROMPT_IDS, group_request_id, prompt_ids))
381381
)
382382
try:
383-
await asyncio.wait_for(nixl_pd_event.wait(), timeout=80)
383+
await asyncio.wait_for(nixl_pd_event.wait(), timeout=180)
384384
except asyncio.TimeoutError:
385-
logger.error(f"nixl np node wait nixl_pd_event 36s time out, group_req_id {group_request_id}")
385+
logger.error(f"nixl np node wait nixl_pd_event 180s time out, group_req_id {group_request_id}")
386386
raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out")
387387

388388
decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info

lightllm/server/httpserver_for_pd_master/manager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -331,7 +331,7 @@ async def fetch_nixl_stream(
331331
)
332332

333333
try:
334-
await asyncio.wait_for(up_status_event.wait(), timeout=60)
334+
await asyncio.wait_for(up_status_event.wait(), timeout=180)
335335
except asyncio.TimeoutError:
336336
logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
337337
raise ServerBusyError()

lightllm/server/pd_io_struct.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,8 @@ class NIXLChunckedTransTask:
273273
first_gen_token_id: Optional[int]
274274
first_gen_token_logprob: Optional[float]
275275

276+
nixl_write_stage: Optional[str] = None
277+
276278
# transfer params
277279
nixl_src_page_index: Optional[int] = None
278280
nixl_dst_page_index: Optional[int] = None
@@ -284,6 +286,7 @@ class NIXLChunckedTransTask:
284286
start_trans_time: float = None # 用于标记传输开始的时间。同时标记是否正在传输中
285287

286288
error_info: Optional[str] = None
289+
transfer_time_out_secs: int = 66
287290

288291
def __post_init__(self):
289292
if self.start_kv_index < 0 or self.end_kv_index < self.start_kv_index:
@@ -300,7 +303,7 @@ def time_out(self) -> bool:
300303
return True
301304
return False
302305
else:
303-
if time.time() - self.start_trans_time > self.time_out_secs + 88:
306+
if time.time() - self.start_trans_time > self.transfer_time_out_secs:
304307
return True
305308
else:
306309
return False

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -178,14 +178,17 @@ def _create_nixl_trans_task(
178178
):
179179
# 确定传输设备
180180
if req_obj.nixl_trans_device_id == -1:
181+
if not hasattr(self, "nixl_iter_device_id"):
182+
self.nixl_iter_device_id = 0
183+
req_obj.nixl_trans_device_id = self.nixl_iter_device_id
181184
# only self.is_master_in_dp will be used.
182-
req_obj.nixl_trans_device_id = random.randint(0, self.node_world_size - 1)
185+
self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size
183186

184187
trans_task = NIXLChunckedTransTask(
185188
request_id=req_obj.req_id,
186189
start_kv_index=kv_start_index,
187190
end_kv_index=kv_end_index,
188-
time_out_secs=80,
191+
time_out_secs=180,
189192
pd_master_node_id=req_obj.sampling_param.pd_master_node_id,
190193
prefill_dp_index=None,
191194
decode_dp_index=self.dp_rank_in_node,

0 commit comments

Comments
 (0)