@@ -2756,7 +2756,9 @@ template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta
27562756// state_in: [S*S*H, n_seqs]
27572757// output (no intermediates): [S*H, n_tokens*n_seqs + S*n_seqs] (activations + final state)
27582758// output (keep_intermediates): [S*H, n_tokens*n_seqs + S*S*H*n_seqs + (n_seq_tokens-1)*S*S*H*n_seqs]
2759- // layout: [activations] [final_state] [snapshots: t=1 state, t=2 state, ...] (reverse order, skip t=0)
2759+ // layout: [activations] [final_state] [snapshots: state after confirmed token, state after draft skip, ...]
2760+ // saves state after each token EXCEPT the last (which is stored as final state)
2761+ // snap_idx = t_in_seq (0-indexed per sequence), skip t_end-1
27602762// grid=(S/NSG, S/4, H*n_seqs), threadgroup=(NW, 4, 1)
27612763template <short NSG, bool KEEP_INTERMEDIATES>
27622764kernel void kernel_gated_linear_attn_impl (
@@ -2832,20 +2834,18 @@ kernel void kernel_gated_linear_attn_impl(
28322834 dst_out[(t - t_start) * S * H + h_idx * S + j] = out_j;
28332835 }
28342836
2835- // Store intermediate state for rollback (skip t=0, reverse-index per CPU semantics)
2836- if (KEEP_INTERMEDIATES) {
2837+ // Store intermediate state for rollback (save state after each token except the last)
2838+ // The snapshot at index (t - t_start) is the state AFTER confirmed tokens,
2839+ // BEFORE the draft token. On draft rejection, this is what we roll back to.
2840+ if (KEEP_INTERMEDIATES && t < t_end - 1 ) {
28372841 const short t_in_seq = t - t_start;
2838- if (t_in_seq > 0 ) {
2839- const short snap_idx = T_per_seq - 1 - t_in_seq;
2840- // Snapshot layout: after final state, indexed by (snap_idx * n_seqs + seq_idx) * H * S * S + h_idx * S * S
2841- const int64_t state_elems_all = (int64_t )S * S * H * n_seqs;
2842- device float * snap = (device float *) dst + state_base + state_elems_all
2843- + ((int64_t )snap_idx * n_seqs + seq_idx) * H * S * S + h_idx * S * S;
2844- FOR_UNROLL (short i = 0 ; i < NSG; i++) {
2845- const short ii = i_start + tx*NSG + i;
2846- if (ii < S && j < S) {
2847- snap[ii * S + j] = state_vals[i];
2848- }
2842+ const int64_t state_elems_all = (int64_t )S * S * H * n_seqs;
2843+ device float * snap = (device float *) dst + state_base + state_elems_all
2844+ + ((int64_t )t_in_seq * n_seqs + seq_idx) * H * S * S + h_idx * S * S;
2845+ FOR_UNROLL (short i = 0 ; i < NSG; i++) {
2846+ const short ii = i_start + tx*NSG + i;
2847+ if (ii < S && j < S) {
2848+ snap[ii * S + j] = state_vals[i];
28492849 }
28502850 }
28512851 }
0 commit comments