@@ -155,13 +155,11 @@ def update_step_context(cls, step_context):
155155 """Update step context."""
156156
157157 block_num , block_size , * _ = step_context .kv_caches [0 ][0 ].shape
158- is_unpaged_prefill = False
158+ is_prefill_no_cache = False
159159 if not step_context .is_decoding :
160- is_unpaged_prefill = all ((step_context .q_seqlens == step_context .kv_seqlens ).tolist ())
160+ is_prefill_no_cache = all ((step_context .q_seqlens == step_context .kv_seqlens ).tolist ())
161161 if step_context .block_offsets .dtype != torch .int32 :
162162 step_context .block_offsets = step_context .block_offsets .to (torch .int32 )
163- if not (step_context .is_decoding or is_unpaged_prefill ):
164- step_context .block_offsets = step_context .block_offsets .repeat_interleave (step_context .q_seqlens , 0 )
165163 if step_context .kv_seqlens .dtype != torch .int32 :
166164 step_context .kv_seqlens = step_context .kv_seqlens .to (torch .int32 )
167165 if step_context .q_seqlens .dtype != torch .int32 :
@@ -175,7 +173,7 @@ def get_total_slots():
175173 cls .total_slots = cls .total_slots .view (block_num , block_size )
176174 return cls .total_slots
177175
def get_cpu_seqlens(is_decoding, is_prefill_no_cache):
    """Move the step's sequence-length tensors to CPU for the current phase.

    Returns:
        tuple: ``(q_seqlens_cpu, kv_seqlens_cpu)``. ``q_seqlens_cpu`` is
        ``None`` while decoding; when prefilling without a kv cache the two
        entries alias the same CPU tensor.
    """
    if is_decoding:
        return None, step_context.kv_seqlens.cpu()
    q_cpu = step_context.q_seqlens.cpu()
    if is_prefill_no_cache:
        # q and kv lengths are equal here, so reuse one tensor for both.
        return q_cpu, q_cpu
    return q_cpu, step_context.kv_seqlens.cpu()
200196
def get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None, kv_seqlens_cpu=None):
    """Convert the CPU sequence-length tensors into Python lists.

    Returns ``(None, None)`` while decoding; when prefilling without a kv
    cache the q and kv lists are the same object.
    """
    if is_decoding:
        return None, None
    q_list = q_seqlens_cpu.tolist()
    if is_prefill_no_cache:
        return q_list, q_list
    return q_list, kv_seqlens_cpu.tolist()
209205
def get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list=None, kv_seqlens_list=None):
    """Return ``(max_q_seq_len, max_kv_seq_len)`` for the current phase.

    Decoding always produces exactly one query token per sequence, so the
    max query length is 1 and the kv max is left as ``None``.
    """
    if is_decoding:
        return 1, None
    longest_q = max(q_seqlens_list)
    if is_prefill_no_cache:
        # q == kv lengths in cache-free prefill.
        return longest_q, longest_q
    return longest_q, max(kv_seqlens_list)
219215
220- def get_kv_start_indices_and_attention_mask (is_decoding , is_unpaged_prefill , q_seqlens_list , kv_seqlens_list ,
def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None):
    """Derive the q_seqlens value handed to the attention metadata.

    NOTE(review): the decoding and paged-prefill branches return cumulative
    lengths (decoding emits one token per sequence, so the cumsum is simply
    ``1..batch``) — presumably the backend expects cumulative query offsets
    in those phases; confirm against the attention kernel's contract.
    """
    if is_decoding:
        num_seqs = step_context.q_seqlens.size(0)
        return torch.arange(1, num_seqs + 1, dtype=torch.int32)
    if is_prefill_no_cache:
        # Cache-free prefill keeps the per-sequence lengths as-is.
        return q_seqlens_cpu
    return q_seqlens_cpu.cumsum(dim=0)
223+
224+ def get_kv_start_indices_and_attention_mask (is_decoding , is_prefill_no_cache , q_seqlens_list , kv_seqlens_list ,
221225 max_q_seq_len , max_kv_seq_len ):
222226 kv_start_indices , attention_mask = [], []
223227 if is_decoding :
@@ -236,25 +240,17 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
236240 slots = slot_tables [history_length :kv_seq_len ]
237241 kv_start_indices .append (slots )
238242
239- if not is_unpaged_prefill :
240- single_attention_mask = torch .triu (
241- torch .ones (q_seq_len ,
242- step_context .block_offsets .shape [1 ] * block_size ,
243- dtype = torch .bool ,
244- device = step_context .block_offsets .device ),
245- diagonal = kv_seq_len - q_seq_len + 1 ,
246- )
247- attention_mask .append (single_attention_mask )
248-
249- if is_unpaged_prefill :
243+ if is_prefill_no_cache :
250244 attention_mask .append (
251245 torch .triu (torch .ones (max_q_seq_len ,
252246 max_kv_seq_len ,
253247 dtype = step_context .kv_caches [0 ][0 ].dtype ,
254248 device = step_context .block_offsets .device ),
255249 diagonal = max_kv_seq_len - max_q_seq_len + 1 ))
256250 else :
257- attention_mask = [torch .cat (attention_mask )]
251+ attention_mask .append (
252+ torch .triu (torch .ones (2048 , 2048 , dtype = torch .bool , device = step_context .block_offsets .device ),
253+ diagonal = 1 ))
258254
259255 kv_start_indices = torch .cat (kv_start_indices )
260256
@@ -357,16 +353,16 @@ def get_moe_group_name(group):
357353 group_name = backend .get_hccl_comm_name (local_rank )
358354 return group_name
359355
360- q_seqlens_cpu , kv_seqlens_cpu , kv_seqlens_expanded = get_cpu_seqlens (step_context .is_decoding ,
361- is_unpaged_prefill )
362- q_seqlens_list , kv_seqlens_list = get_list_seqlens (step_context .is_decoding , is_unpaged_prefill , q_seqlens_cpu ,
356+ q_seqlens_cpu , kv_seqlens_cpu = get_cpu_seqlens (step_context .is_decoding , is_prefill_no_cache )
357+ q_seqlens_list , kv_seqlens_list = get_list_seqlens (step_context .is_decoding , is_prefill_no_cache , q_seqlens_cpu ,
363358 kv_seqlens_cpu )
364- max_q_seq_len , max_kv_seq_len = get_max_seqlens (step_context .is_decoding , is_unpaged_prefill , q_seqlens_list ,
359+ max_q_seq_len , max_kv_seq_len = get_max_seqlens (step_context .is_decoding , is_prefill_no_cache , q_seqlens_list ,
365360 kv_seqlens_list )
366361 kv_start_indices , attention_mask = get_kv_start_indices_and_attention_mask (step_context .is_decoding ,
367- is_unpaged_prefill , q_seqlens_list ,
362+ is_prefill_no_cache , q_seqlens_list ,
368363 kv_seqlens_list , max_q_seq_len ,
369364 max_kv_seq_len )
365+ q_seqlens_cpu = update_q_seqlens (step_context .is_decoding , is_prefill_no_cache , q_seqlens_cpu )
370366
371367 if not cls .enable_graph and step_context .kv_quant_policy == 8 :
372368 record_file = os .getenv ('ASCEND_QUANT_RECORD_FILE' )
@@ -400,13 +396,11 @@ def get_moe_group_name(group):
400396 # Otherwise, q_start_loc is None.
401397 q_start_loc = cu_seqlens ,
402398 q_seqlens = q_seqlens_cpu ,
403- # kv_seqlens_expanded is only expanded in paged prefill,
404- # otherwise it equals kv_seqlens_cpu
405- kv_seqlens = kv_seqlens_expanded ,
399+ kv_seqlens = kv_seqlens_cpu ,
406400 kv_start_indices = kv_start_indices ,
407401 block_size = block_size ,
408402 attention_mask = attention_mask ,
409- is_unpaged_prefill = is_unpaged_prefill ,
403+ is_prefill_no_cache = is_prefill_no_cache ,
410404 max_q_seq_len = max_q_seq_len ,
411405 max_kv_seq_len = max_kv_seq_len ,
412406 quant_policy = step_context .kv_quant_policy ,
0 commit comments