@@ -48,7 +48,8 @@ def kernel(
4848 o_stride_m ,
4949 o_stride_n ,
5050 scale ,
51- seq_len ,
51+ seq_len_q ,
52+ seq_len_k_v ,
5253 EMB_DIM : tl .constexpr ,
5354 BLOCK_SIZE_M : tl .constexpr ,
5455 BLOCK_SIZE_N : tl .constexpr ,
@@ -62,7 +63,7 @@ def kernel(
6263 q_off = off_z * q_stride_z + off_h * q_stride_h
6364 q_block_ptr = tl .make_block_ptr (
6465 base = q_ptr + q_off ,
65- shape = (seq_len , EMB_DIM ),
66+ shape = (seq_len_q , EMB_DIM ),
6667 strides = (q_stride_m , q_stride_k ),
6768 offsets = (offs_m_start , 0 ),
6869 block_shape = (BLOCK_SIZE_M , EMB_DIM ),
@@ -71,7 +72,7 @@ def kernel(
7172 k_off = off_z * k_stride_z + off_h * k_stride_h
7273 k_block_ptr = tl .make_block_ptr (
7374 base = k_ptr + k_off ,
74- shape = (EMB_DIM , seq_len ),
75+ shape = (EMB_DIM , seq_len_k_v ),
7576 strides = (k_stride_k , k_stride_n ),
7677 offsets = (0 , 0 ),
7778 block_shape = (EMB_DIM , BLOCK_SIZE_N ),
@@ -80,7 +81,7 @@ def kernel(
8081 v_off = off_z * v_stride_z + off_h * v_stride_h
8182 v_block_ptr = tl .make_block_ptr (
8283 base = v_ptr + v_off ,
83- shape = (seq_len , EMB_DIM ),
84+ shape = (seq_len_k_v , EMB_DIM ),
8485 strides = (v_stride_k , v_stride_n ),
8586 offsets = (0 , 0 ),
8687 block_shape = (BLOCK_SIZE_N , EMB_DIM ),
@@ -89,7 +90,7 @@ def kernel(
8990 o_off = off_z * o_stride_z + off_h * o_stride_h
9091 o_block_ptr = tl .make_block_ptr (
9192 base = o_ptr + o_off ,
92- shape = (seq_len , EMB_DIM ),
93+ shape = (seq_len_q , EMB_DIM ),
9394 strides = (o_stride_m , o_stride_n ),
9495 offsets = (offs_m_start , 0 ),
9596 block_shape = (BLOCK_SIZE_M , EMB_DIM ),
@@ -103,10 +104,10 @@ def kernel(
103104 l_i = tl .full ((BLOCK_SIZE_M ,), 1 , dtype = tl .float32 )
104105 m_i = tl .full ((BLOCK_SIZE_M ,), float ("-inf" ), dtype = tl .float32 )
105106
106- for i in range (0 , tl .cdiv (seq_len , BLOCK_SIZE_N )):
107+ for i in range (0 , tl .cdiv (seq_len_k_v , BLOCK_SIZE_N )):
107108 k = tl .load (k_block_ptr , boundary_check = (0 , 1 ))
108109
109- mask = i * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N ) < seq_len
110+ mask = i * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N ) < seq_len_k_v
110111 qk = tl .where (mask , tl .dot (q , k ), float ("-inf" ))
111112
112113 m_ij = tl .maximum (m_i , tl .max (qk , 1 ))
0 commit comments