@@ -96,23 +96,24 @@ def kernel(
9696 order = (1 , 0 ),
9797 )
9898
99- q = (tl .load (q_block_ptr ) * scale * 1.44269504089 ).to (q_block_ptr .type .element_ty )
99+ q = tl .load (q_block_ptr , boundary_check = (0 , 1 ))
100+ q = (q * scale * 1.44269504089 ).to (q_block_ptr .type .element_ty )
100101
101102 acc = tl .zeros ((BLOCK_SIZE_M , EMB_DIM ), dtype = tl .float32 )
102103 l_i = tl .full ((BLOCK_SIZE_M ,), 1 , dtype = tl .float32 )
103104 m_i = tl .full ((BLOCK_SIZE_M ,), float ("-inf" ), dtype = tl .float32 )
104105
105- for _ in range (0 , tl .cdiv (seq_len , BLOCK_SIZE_N )):
106- k = tl .load (k_block_ptr )
106+ for i in range (0 , tl .cdiv (seq_len , BLOCK_SIZE_N )):
107+ k = tl .load (k_block_ptr , boundary_check = ( 0 , 1 ) )
107108
108- qk = tl .dot (q , k )
109+ mask = i * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N ) < seq_len
110+ qk = tl .where (mask , tl .dot (q , k ), float ("-inf" ))
109111
110112 m_ij = tl .maximum (m_i , tl .max (qk , 1 ))
111- qk -= m_ij [:, None ]
112- p = tl .exp2 (qk )
113+ p = tl .exp2 (qk - m_ij [:, None ])
113114 l_ij = tl .sum (p , 1 )
114115
115- v = tl .load (v_block_ptr )
116+ v = tl .load (v_block_ptr , boundary_check = ( 0 , 1 ) )
116117 alpha = tl .exp2 (m_i - m_ij )
117118 acc = acc * alpha [:, None ] + tl .dot (p .to (v_block_ptr .type .element_ty ), v )
118119 m_i = m_ij
@@ -123,4 +124,4 @@ def kernel(
123124
124125 acc /= l_i [:, None ]
125126
126- tl .store (o_block_ptr , acc .to (o_ptr .type .element_ty ))
127+ tl .store (o_block_ptr , acc .to (o_ptr .type .element_ty ), boundary_check = ( 0 , 1 ) )
0 commit comments