@@ -34,8 +34,10 @@ def _fwd_kernel(
3434 stride_req_to_tokens_s ,
3535 kv_group_num ,
3636 b_prompt_cache_len ,
37+ b_image_token_tag ,
3738 H : tl .constexpr ,
38- BLOCK_DMODEL : tl .constexpr ,
39+ QK_HEAD_DIM : tl .constexpr ,
40+ V_HEAD_DIM : tl .constexpr ,
3941 BLOCK_M : tl .constexpr ,
4042 BLOCK_N : tl .constexpr ,
4143):
@@ -53,16 +55,19 @@ def _fwd_kernel(
5355 cur_batch_req_idx = tl .load (B_req_idx + cur_batch )
5456
5557 block_start_loc = BLOCK_M * start_m
58+ if block_start_loc >= cur_batch_seq_len :
59+ return
5660
5761 offs_n = tl .arange (0 , BLOCK_N )
58- offs_d = tl .arange (0 , BLOCK_DMODEL )
62+ offs_d_qk = tl .arange (0 , QK_HEAD_DIM )
63+ offs_d_v = tl .arange (0 , V_HEAD_DIM )
5964 offs_m = block_start_loc + tl .arange (0 , BLOCK_M )
6065
6166 # Q pointers
6267 off_q = (
6368 (cur_batch_in_all_start_index + offs_m [:, None ]) * stride_qbs
6469 + cur_head * stride_qh
65- + offs_d [None , :] * stride_qd
70+ + offs_d_qk [None , :] * stride_qd
6671 )
6772
6873 q_valid = offs_m < cur_batch_seq_len
@@ -71,24 +76,14 @@ def _fwd_kernel(
7176 # online softmax state
7277 m_i = tl .zeros ([BLOCK_M ], dtype = tl .float32 ) - float ("inf" )
7378 l_i = tl .zeros ([BLOCK_M ], dtype = tl .float32 )
74- acc = tl .zeros ([BLOCK_M , BLOCK_DMODEL ], dtype = tl .float32 )
75-
76- block_mask = tl .where (block_start_loc < cur_batch_seq_len , 1 , 0 )
79+ acc = tl .zeros ([BLOCK_M , V_HEAD_DIM ], dtype = tl .float32 )
7780 block_end_loc = total_len
7881
7982 # absolute q positions in the request
8083 q_pos = prompt_cache_len + offs_m # [M]
84+ q_image_token_tag = tl .load (b_image_token_tag + cur_batch_in_all_start_index + offs_m , mask = q_valid , other = False )
8185
82- # q_gid from packed position_ids (aligned with Q rows)
83- q_gid = tl .load (
84- position_ids + cur_batch_in_all_start_index + offs_m ,
85- mask = q_valid ,
86- other = - 2147483648 ,
87- ).to (tl .int32 )
88-
89- BIG = tl .full ([BLOCK_N ], 1000000000 , tl .int32 ) # ensure != any normal gid
90-
91- for start_n in range (0 , block_mask * block_end_loc , BLOCK_N ):
86+ for start_n in range (0 , block_end_loc , BLOCK_N ):
9287 start_n = tl .multiple_of (start_n , BLOCK_N )
9388
9489 k_pos = start_n + offs_n # [N]
@@ -102,32 +97,13 @@ def _fwd_kernel(
10297 ).to (tl .int64 )
10398
10499 # load K
105- off_k = kv_loc [None , :] * stride_kbs + cur_kv_head * stride_kh + offs_d [:, None ] * stride_kd
100+ off_k = kv_loc [None , :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk [:, None ] * stride_kd
106101 k = tl .load (K + off_k , mask = k_valid [None , :], other = 0.0 )
107-
108- qk = tl .dot (q , k )
109-
110- # k_gid:
111- # - for cached keys (k_pos < prompt_cache_len): set to BIG + k_pos so equality is always false
112- # - for new keys (k_pos >= prompt_cache_len): read from packed position_ids by (k_pos - prompt_cache_len)
113- k_in_new = k_pos >= prompt_cache_len
114- k_new_idx = (k_pos - prompt_cache_len ).to (tl .int32 ) # [N] valid only when k_in_new
115- k_gid_new = tl .load (
116- position_ids + cur_batch_in_all_start_index + k_new_idx ,
117- mask = k_valid & k_in_new ,
118- other = - 2147483647 ,
119- ).to (tl .int32 )
120-
121- k_gid = tl .where (
122- k_in_new ,
123- k_gid_new ,
124- (k_pos .to (tl .int32 ) + BIG ),
125- )
102+ qk = tl .zeros ([BLOCK_M , BLOCK_N ], dtype = tl .float32 )
103+ qk += tl .dot (q , k )
126104
127105 # mask: causal OR q_image_token_tag is set for the query row
128- mask = (q_pos [:, None ] >= k_pos [None , :]) | (q_gid [:, None ] == k_gid [None , :])
129- mask = mask & q_valid [:, None ] & k_valid [None , :]
130-
106+ mask = (q_pos [:, None ] >= k_pos [None , :]) | q_image_token_tag [:, None ]
131107 qk = tl .where (mask , qk * sm_scale , - 1.0e8 )
132108
133109 # online softmax
@@ -141,7 +117,7 @@ def _fwd_kernel(
141117 acc = acc * alpha [:, None ]
142118
143119 # load V
144- off_v = kv_loc [:, None ] * stride_vbs + cur_kv_head * stride_vh + offs_d [None , :] * stride_vd
120+ off_v = kv_loc [:, None ] * stride_vbs + cur_kv_head * stride_vh + offs_d_v [None , :] * stride_vd
145121 v = tl .load (V + off_v , mask = k_valid [:, None ], other = 0.0 )
146122
147123 p = p .to (v .dtype )
@@ -154,7 +130,7 @@ def _fwd_kernel(
154130 off_o = (
155131 (cur_batch_in_all_start_index + offs_m [:, None ]) * stride_obs
156132 + cur_head * stride_oh
157- + offs_d [None , :] * stride_od
133+ + offs_d_v [None , :] * stride_od
158134 )
159135 tl .store (Out + off_o , acc , mask = q_valid [:, None ])
160136
@@ -172,6 +148,7 @@ def context_attention_fwd_neo(
172148 b_prompt_cache_len ,
173149 max_input_len ,
174150 req_to_token_indexs ,
151+ b_image_token_tag ,
175152):
176153 # minimal safety: position_ids must cover packed q rows
177154 assert position_ids .numel () >= q .shape [0 ], (position_ids .numel (), q .shape [0 ])
@@ -220,8 +197,10 @@ def context_attention_fwd_neo(
220197 req_to_token_indexs .stride (1 ),
221198 kv_group_num = kv_group_num ,
222199 b_prompt_cache_len = b_prompt_cache_len ,
200+ b_image_token_tag = b_image_token_tag ,
223201 H = head ,
224- BLOCK_DMODEL = Lk ,
202+ QK_HEAD_DIM = Lk ,
203+ V_HEAD_DIM = Lk // 2 ,
225204 BLOCK_M = BLOCK_M ,
226205 BLOCK_N = BLOCK_N ,
227206 num_warps = num_warps ,
0 commit comments