@@ -39,6 +39,7 @@ struct Params {
3939
4040 neq1 : u32 ,
4141 rq3 : u32 ,
42+ K : u32 ,
4243 scale : f32 ,
4344};
4445
@@ -62,11 +63,14 @@ fn main(
6263 let iq3 = seq_id / params . rq3 ;
6364
6465 let state_size = S_V * S_V ;
65- let state_base = (seq_id * params . h + head_id ) * state_size ;
66+ let state_in_base = (seq_id * params . K * params . h + head_id ) * state_size ;
67+ let state_out_base = (seq_id * params . h + head_id ) * state_size ;
68+ let state_size_per_snap = state_size * params . h * params . n_seqs ;
69+ let shift = i32 (params . n_tokens ) - i32 (params . K );
6670
6771 var state : array <f32 , S_V >;
6872 for (var i = 0u ; i < S_V ; i ++ ) {
69- state [i ] = src_state [state_base + col * S_V + i ];
73+ state [i ] = src_state [state_in_base + col * S_V + i ];
7074 }
7175
7276 var attn_off = (seq_id * params . n_tokens * params . h + head_id ) * S_V ;
@@ -123,10 +127,22 @@ fn main(
123127 dst [attn_off + col ] = attn_col * params . scale ;
124128 attn_off += S_V * params . h ;
125129
130+ if (params . K > 1u ) {
131+ let target_slot = i32 (t ) - shift ;
132+ if (target_slot >= 0 && target_slot < i32 (params . K )) {
133+ let slot_base = params . s_off + u32 (target_slot ) * state_size_per_snap + state_out_base ;
134+ for (var i = 0u ; i < S_V ; i ++ ) {
135+ dst [slot_base + col * S_V + i ] = state [i ];
136+ }
137+ }
138+ }
139+
126140 workgroupBarrier ();
127141 }
128142
129- for (var i = 0u ; i < S_V ; i ++ ) {
130- dst [params . s_off + state_base + col * S_V + i ] = state [i ];
143+ if (params . K == 1u ) {
144+ for (var i = 0u ; i < S_V ; i ++ ) {
145+ dst [params . s_off + state_out_base + col * S_V + i ] = state [i ];
146+ }
131147 }
132148}
0 commit comments