Skip to content

Commit cac2edf

Browse files
committed
neo moe inference speedup
1 parent 8f8ed44 commit cac2edf

5 files changed

Lines changed: 44 additions & 43 deletions

File tree

lightllm/models/neo_chat_moe/infer_struct.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def __init__(self):
2020
def init_some_extra_state(self, model: LlamaTpPartModel):
2121
LlamaInferStateInfo.init_some_extra_state(self, model)
2222
if self.is_prefill:
23+
self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
24+
non_blocking=True
25+
)
2326
self.position_ids = self.get_neo_position(self.multimodal_params)
2427
else:
2528
b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
@@ -95,5 +98,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
9598
b_ready_cache_len=self.b_ready_cache_len,
9699
b_q_seq_len=self.b_q_seq_len,
97100
b_start_loc=self.b_q_start_loc,
101+
b_image_token_tag=self.b_image_token_tag,
98102
)
99103
return position_ids

lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def _context_attention_kernel(
182182
infer_state.b_ready_cache_len,
183183
infer_state.max_q_seq_len,
184184
infer_state.req_manager.req_to_token_indexs,
185+
infer_state.b_image_token_tag,
185186
)
186187
o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
187188
o3 = o3[:, :, : self.head_dim_].contiguous()

lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ def _fwd_kernel(
3434
stride_req_to_tokens_s,
3535
kv_group_num,
3636
b_prompt_cache_len,
37+
b_image_token_tag,
3738
H: tl.constexpr,
38-
BLOCK_DMODEL: tl.constexpr,
39+
QK_HEAD_DIM: tl.constexpr,
40+
V_HEAD_DIM: tl.constexpr,
3941
BLOCK_M: tl.constexpr,
4042
BLOCK_N: tl.constexpr,
4143
):
@@ -53,16 +55,19 @@ def _fwd_kernel(
5355
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
5456

5557
block_start_loc = BLOCK_M * start_m
58+
if block_start_loc >= cur_batch_seq_len:
59+
return
5660

5761
offs_n = tl.arange(0, BLOCK_N)
58-
offs_d = tl.arange(0, BLOCK_DMODEL)
62+
offs_d_qk = tl.arange(0, QK_HEAD_DIM)
63+
offs_d_v = tl.arange(0, V_HEAD_DIM)
5964
offs_m = block_start_loc + tl.arange(0, BLOCK_M)
6065

6166
# Q pointers
6267
off_q = (
6368
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
6469
+ cur_head * stride_qh
65-
+ offs_d[None, :] * stride_qd
70+
+ offs_d_qk[None, :] * stride_qd
6671
)
6772

6873
q_valid = offs_m < cur_batch_seq_len
@@ -71,24 +76,14 @@ def _fwd_kernel(
7176
# online softmax state
7277
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
7378
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
74-
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
75-
76-
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
79+
acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
7780
block_end_loc = total_len
7881

7982
# absolute q positions in the request
8083
q_pos = prompt_cache_len + offs_m # [M]
84+
q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False)
8185

82-
# q_gid from packed position_ids (aligned with Q rows)
83-
q_gid = tl.load(
84-
position_ids + cur_batch_in_all_start_index + offs_m,
85-
mask=q_valid,
86-
other=-2147483648,
87-
).to(tl.int32)
88-
89-
BIG = tl.full([BLOCK_N], 1000000000, tl.int32) # ensure != any normal gid
90-
91-
for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
86+
for start_n in range(0, block_end_loc, BLOCK_N):
9287
start_n = tl.multiple_of(start_n, BLOCK_N)
9388

9489
k_pos = start_n + offs_n # [N]
@@ -102,32 +97,13 @@ def _fwd_kernel(
10297
).to(tl.int64)
10398

10499
# load K
105-
off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
100+
off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd
106101
k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0)
107-
108-
qk = tl.dot(q, k)
109-
110-
# k_gid:
111-
# - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false
112-
# - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len)
113-
k_in_new = k_pos >= prompt_cache_len
114-
k_new_idx = (k_pos - prompt_cache_len).to(tl.int32) # [N] valid only when k_in_new
115-
k_gid_new = tl.load(
116-
position_ids + cur_batch_in_all_start_index + k_new_idx,
117-
mask=k_valid & k_in_new,
118-
other=-2147483647,
119-
).to(tl.int32)
120-
121-
k_gid = tl.where(
122-
k_in_new,
123-
k_gid_new,
124-
(k_pos.to(tl.int32) + BIG),
125-
)
102+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
103+
qk += tl.dot(q, k)
126104

127105
# mask: causal OR same gid (only possible inside NEW part)
128-
mask = (q_pos[:, None] >= k_pos[None, :]) | (q_gid[:, None] == k_gid[None, :])
129-
mask = mask & q_valid[:, None] & k_valid[None, :]
130-
106+
mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]
131107
qk = tl.where(mask, qk * sm_scale, -1.0e8)
132108

133109
# online softmax
@@ -141,7 +117,7 @@ def _fwd_kernel(
141117
acc = acc * alpha[:, None]
142118

143119
# load V
144-
off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
120+
off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd
145121
v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0)
146122

147123
p = p.to(v.dtype)
@@ -154,7 +130,7 @@ def _fwd_kernel(
154130
off_o = (
155131
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
156132
+ cur_head * stride_oh
157-
+ offs_d[None, :] * stride_od
133+
+ offs_d_v[None, :] * stride_od
158134
)
159135
tl.store(Out + off_o, acc, mask=q_valid[:, None])
160136

@@ -172,6 +148,7 @@ def context_attention_fwd_neo(
172148
b_prompt_cache_len,
173149
max_input_len,
174150
req_to_token_indexs,
151+
b_image_token_tag,
175152
):
176153
# minimal safety: position_ids must cover packed q rows
177154
assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0])
@@ -220,8 +197,10 @@ def context_attention_fwd_neo(
220197
req_to_token_indexs.stride(1),
221198
kv_group_num=kv_group_num,
222199
b_prompt_cache_len=b_prompt_cache_len,
200+
b_image_token_tag=b_image_token_tag,
223201
H=head,
224-
BLOCK_DMODEL=Lk,
202+
QK_HEAD_DIM=Lk,
203+
V_HEAD_DIM=Lk // 2,
225204
BLOCK_M=BLOCK_M,
226205
BLOCK_N=BLOCK_N,
227206
num_warps=num_warps,

lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def _get_neo_position_triton(
1616
b_ready_cache_len: torch.Tensor,
1717
b_q_seq_len: torch.Tensor,
1818
b_start_loc: torch.Tensor,
19+
b_image_token_tag: torch.Tensor,
1920
BLOCK_SIZE: tl.constexpr,
2021
) -> torch.Tensor:
2122
cur_batch = tl.program_id(0)
@@ -36,6 +37,13 @@ def _get_neo_position_triton(
3637
t_pos = local_image_start_idx + off * 0
3738
h_pos = off // image_w
3839
w_pos = off % image_w
40+
tl.store(
41+
b_image_token_tag + off + image_start_idx,
42+
True,
43+
mask=(off < image_len)
44+
& (off + local_image_start_idx - cache_len < q_seq_len)
45+
& (local_image_start_idx - cache_len + off >= 0),
46+
)
3947
tl.store(
4048
position_ids + off + image_start_idx,
4149
t_pos,
@@ -87,6 +95,7 @@ def get_neo_position_triton(
8795
b_ready_cache_len: torch.Tensor,
8896
b_q_seq_len: torch.Tensor,
8997
b_start_loc: torch.Tensor,
98+
b_image_token_tag: torch.Tensor,
9099
) -> torch.Tensor:
91100

92101
batch_size = b_q_seq_len.shape[0]
@@ -105,6 +114,7 @@ def get_neo_position_triton(
105114
b_ready_cache_len=b_ready_cache_len,
106115
b_q_seq_len=b_q_seq_len,
107116
b_start_loc=b_start_loc,
117+
b_image_token_tag=b_image_token_tag,
108118
BLOCK_SIZE=BLOCK_SIZE,
109119
)
110120

@@ -121,6 +131,7 @@ def test():
121131
.expand(3, -1)
122132
.contiguous()
123133
)
134+
b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda")
124135
position_ids[1:].zero_()
125136
b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
126137
b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
@@ -135,8 +146,10 @@ def test():
135146
b_ready_cache_len,
136147
b_q_seq_len,
137148
b_start_loc,
149+
b_image_token_tag,
138150
)
139151

152+
print(b_image_token_tag)
140153
print(position_ids)
141154
# old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
142155

@@ -172,3 +185,7 @@ def test():
172185
[0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]],
173186
device='cuda:0', dtype=torch.int32)
174187
"""
188+
189+
190+
if __name__ == "__main__":
191+
test()

lightllm/models/neo_chat_moe/vision_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,6 @@ def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=655
136136
)
137137
pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size)
138138

139-
print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
139+
# print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
140140

141141
return pixel_values, grid_hw

0 commit comments

Comments (0)