Skip to content

Commit aa37e54

Browse files
author
Ljubomir Josifovski
committed
more diags (by DS), but MTP output still garbage
1 parent 51ac2d5 commit aa37e54

1 file changed

Lines changed: 14 additions & 14 deletions

File tree

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2756,7 +2756,9 @@ template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta
27562756
// state_in: [S*S*H, n_seqs]
27572757
// output (no intermediates): [S*H, n_tokens*n_seqs + S*n_seqs] (activations + final state)
27582758
// output (keep_intermediates): [S*H, n_tokens*n_seqs + S*S*H*n_seqs + (n_seq_tokens-1)*S*S*H*n_seqs]
2759-
// layout: [activations] [final_state] [snapshots: t=1 state, t=2 state, ...] (reverse order, skip t=0)
2759+
// layout: [activations] [final_state] [snapshots: state after confirmed token, state after draft skip, ...] (forward order, indexed by t_in_seq; last token omitted)
2760+
// saves state after each token EXCEPT the last (which is stored as final state)
2761+
// snap_idx = t_in_seq = t - t_start (0-indexed per sequence); the last token (t == t_end-1) writes no snapshot
27602762
// grid=(S/NSG, S/4, H*n_seqs), threadgroup=(NW, 4, 1)
27612763
template<short NSG, bool KEEP_INTERMEDIATES>
27622764
kernel void kernel_gated_linear_attn_impl(
@@ -2832,20 +2834,18 @@ kernel void kernel_gated_linear_attn_impl(
28322834
dst_out[(t - t_start) * S * H + h_idx * S + j] = out_j;
28332835
}
28342836

2835-
// Store intermediate state for rollback (skip t=0, reverse-index per CPU semantics)
2836-
if (KEEP_INTERMEDIATES) {
2837+
// Store intermediate state for rollback (save state after each token except the last)
2838+
// The snapshot at index (t - t_start) is the state AFTER confirmed tokens,
2839+
// BEFORE the draft token. On draft rejection, this is what we roll back to.
2840+
if (KEEP_INTERMEDIATES && t < t_end - 1) {
28372841
const short t_in_seq = t - t_start;
2838-
if (t_in_seq > 0) {
2839-
const short snap_idx = T_per_seq - 1 - t_in_seq;
2840-
// Snapshot layout: after final state, indexed by (snap_idx * n_seqs + seq_idx) * H * S * S + h_idx * S * S
2841-
const int64_t state_elems_all = (int64_t)S * S * H * n_seqs;
2842-
device float * snap = (device float *) dst + state_base + state_elems_all
2843-
+ ((int64_t)snap_idx * n_seqs + seq_idx) * H * S * S + h_idx * S * S;
2844-
FOR_UNROLL (short i = 0; i < NSG; i++) {
2845-
const short ii = i_start + tx*NSG + i;
2846-
if (ii < S && j < S) {
2847-
snap[ii * S + j] = state_vals[i];
2848-
}
2842+
const int64_t state_elems_all = (int64_t)S * S * H * n_seqs;
2843+
device float * snap = (device float *) dst + state_base + state_elems_all
2844+
+ ((int64_t)t_in_seq * n_seqs + seq_idx) * H * S * S + h_idx * S * S;
2845+
FOR_UNROLL (short i = 0; i < NSG; i++) {
2846+
const short ii = i_start + tx*NSG + i;
2847+
if (ii < S && j < S) {
2848+
snap[ii * S + j] = state_vals[i];
28492849
}
28502850
}
28512851
}

0 commit comments

Comments
 (0)