77
88
99def arrangement (
10- q , k , v , scale , o , BLOCK_SIZE_M = BLOCK_SIZE_M , BLOCK_SIZE_N = BLOCK_SIZE_N
10+ q , k , v , scale , q_start , o , BLOCK_SIZE_M = BLOCK_SIZE_M , BLOCK_SIZE_N = BLOCK_SIZE_N
1111):
1212 def arrange_q_or_o (input ):
1313 arranged = input .tile ((1 , 1 , BLOCK_SIZE_M , - 1 ))
@@ -26,10 +26,17 @@ def arrange_k_or_v(input):
2626
2727 q_arranged = arrange_q_or_o (q )
2828
29- return q_arranged , arrange_k_or_v (k ), arrange_k_or_v (v ), scale , arrange_q_or_o (o )
29+ return (
30+ q_arranged ,
31+ arrange_k_or_v (k ),
32+ arrange_k_or_v (v ),
33+ scale ,
34+ q_start ,
35+ arrange_q_or_o (o ),
36+ )
3037
3138
32- def application (q , k , v , scale , o ):
39+ def application (q , k , v , scale , q_start , o ):
3340 q_loaded = (q * scale * 1.44269504089 ).to (q .dtype )
3441
3542 acc = ntl .zeros ((q .shape [- 2 ], q .shape [- 1 ]), dtype = ntl .float32 )
@@ -38,7 +45,11 @@ def application(q, k, v, scale, o):
3845
3946 for i in range (k .shape [0 ]):
4047 qk = ntl .dot (q_loaded , ntl .trans (k [i ]))
41- qk = ntl .where (k [i ].offsets (- 2 ) < k .source .shape [- 2 ], qk , float ("-inf" ))
48+ qk = ntl .where (
49+ (q .offsets (- 2 ) + q_start )[:, None ] >= k [i ].offsets (- 2 ),
50+ qk ,
51+ float ("-inf" ),
52+ )
4253
4354 m_ij = ntl .maximum (m_i , ntl .max (qk , 1 ))
4455 p = ntl .exp2 (qk - m_ij [:, None ])
@@ -53,8 +64,8 @@ def application(q, k, v, scale, o):
5364 o = acc .to (o .dtype ) # noqa: F841
5465
5566
56- shape_options = (None , None , None , {"constexpr" : True , "upper_bound" : 128 })
57- q , k , v , o = (Tensor (4 , shape_options = shape_options ) for _ in range (4 ))
58- tensors = (q , k , v , Tensor (0 ), o )
67+ _shape_options = (None , None , None , {"constexpr" : True , "upper_bound" : 128 })
68+ _q , _k , _v , _o = (Tensor (4 , shape_options = _shape_options ) for _ in range (4 ))
69+ tensors = (_q , _k , _v , Tensor (0 ), Tensor ( 0 ), _o )
5970
6071kernel = ninetoothed .make (arrangement , application , tensors )
0 commit comments