@@ -19,7 +19,7 @@ def _gen_cumsum_pad0_kernel(
1919
2020 for start_index in range (0 , size , BLOCK ):
2121 current_offs = start_index + offs
22- in_data = tl .load (b_q_seq_len + offs , mask = current_offs < size , other = 0 )
22+ in_data = tl .load (b_q_seq_len + current_offs , mask = current_offs < size , other = 0 )
2323 in_data = tl .cumsum (in_data ) + start_value
2424 start_value = tl .max (in_data , 0 )
2525 tl .store (b1_cu_q_seq_len + current_offs + 1 , in_data , mask = current_offs < size )
@@ -30,7 +30,7 @@ def _gen_cumsum_pad0_kernel(
3030 start_value = tl .cast (0 , tl .int64 )
3131 for start_index in range (0 , size , BLOCK ):
3232 current_offs = start_index + offs
33- in_data = tl .load (b_kv_seq_len + offs * b_kv_seq_len_stride_0 , mask = current_offs < size , other = 0 )
33+ in_data = tl .load (b_kv_seq_len + current_offs * b_kv_seq_len_stride_0 , mask = current_offs < size , other = 0 )
3434 in_data = tl .cumsum (in_data ) + start_value
3535 start_value = tl .max (in_data , 0 )
3636 tl .store (b1_cu_kv_seq_len + current_offs + 1 , in_data , mask = current_offs < size )
0 commit comments