Skip to content

Commit 5514e24

Browse files
authored
fix prefill_params when prefill num_reqs > 1024 (#1336)
1 parent 3863844 commit 5514e24

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

lightllm/common/basemodel/triton_kernel/gen_prefill_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _gen_cumsum_pad0_kernel(
1919

2020
for start_index in range(0, size, BLOCK):
2121
current_offs = start_index + offs
22-
in_data = tl.load(b_q_seq_len + offs, mask=current_offs < size, other=0)
22+
in_data = tl.load(b_q_seq_len + current_offs, mask=current_offs < size, other=0)
2323
in_data = tl.cumsum(in_data) + start_value
2424
start_value = tl.max(in_data, 0)
2525
tl.store(b1_cu_q_seq_len + current_offs + 1, in_data, mask=current_offs < size)
@@ -30,7 +30,7 @@ def _gen_cumsum_pad0_kernel(
3030
start_value = tl.cast(0, tl.int64)
3131
for start_index in range(0, size, BLOCK):
3232
current_offs = start_index + offs
33-
in_data = tl.load(b_kv_seq_len + offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0)
33+
in_data = tl.load(b_kv_seq_len + current_offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0)
3434
in_data = tl.cumsum(in_data) + start_value
3535
start_value = tl.max(in_data, 0)
3636
tl.store(b1_cu_kv_seq_len + current_offs + 1, in_data, mask=current_offs < size)

0 commit comments

Comments
 (0)