
Commit 5bb250c

Author: Ljubomir Josifovski
Bugfixes to the GLA Metal kernel:

1) The grid dispatch was wrong: (S/nsg, S/4, H*n_seqs) is corrected to (1, S/4, H*n_seqs). The buggy version dispatched 32x too many threadgroups in x, all of them computing the same i-dimension.

2) The kernel was missing from the scheduler routing, so it was never routed to even when present.

Throughput now looks like TG 54 tok/s (from 32 t/s) and PP 115 tok/s (from 75 t/s). A major win.
1 parent aa37e54 commit 5bb250c

3 files changed

Lines changed: 8 additions & 4 deletions


ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 2 additions & 0 deletions

@@ -1191,6 +1191,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return true;
         case GGML_OP_GATED_DELTA_NET:
             return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            return has_simdgroup_reduction && op->src[0]->ne[0] % 32 == 0;
         case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 1 deletion

@@ -1717,7 +1717,7 @@ int ggml_metal_op_gated_linear_attn(ggml_metal_op_t ctx, int idx) {
     const int H = ne01; // num_heads
     const int n_seqs = ne41;
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, S / nsg, S / 4, H * n_seqs, 32, nsg, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, S / 4, H * n_seqs, 32, nsg, 1);
 
     return 1;
 }

tools/server/server-context.cpp

Lines changed: 5 additions & 3 deletions

@@ -3149,11 +3149,13 @@ struct server_context_impl {
                 __func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt.size(), n);
         }
 
-        // ctx_mtp has no analogous checkpointing — auto-mirror
-        // wipes its tail; the next prefill ubatch repopulates
-        // it via the streaming hook.
         llama_context_seq_rm(slot.ctx, slot.id, ckpt.pos_max + 1, -1);
 
+        // Roll back the MTP model's KV cache to match the trunk.
+        // Without this, ctx_mtp retains stale draft positions that
+        // corrupt subsequent draft generation.
+        common_speculative_accept(slot.spec.get(), (uint16_t)(accepted.size() - 1));
+
         slot.prompt.tokens.keep_first(ckpt.n_tokens);
         slot.smpl = std::move(smpl_save);
