Skip to content

Commit 14700f3

Browse files
committed
Fix SDPA vmap with GQA/MQA shapes (n_heads != n_kv_heads)
The ScaledDotProductAttention primitive relied on Custom::vmap which re-vmapped the fallback lambda. That lambda captured n_q_heads and n_kv_heads at creation time, causing shape mismatches (SIGSEGV/hang) when vmap changed the array dimensions. Add a dedicated vmap override that merges the vmap axis into the batch dimension and re-invokes scaled_dot_product_attention, which recomputes head counts from actual shapes. Falls back to Custom::vmap for sinks.
1 parent b98831a commit 14700f3

3 files changed

Lines changed: 121 additions & 0 deletions

File tree

mlx/fast.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,83 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
5050
return {outputs, out_axes};
5151
}
5252

// Vmap rule for the fused SDPA primitive.
//
// Rather than delegating to Custom::vmap (whose fallback lambda captures
// n_q_heads/n_kv_heads at creation time and therefore breaks when vmap
// changes the array ranks), this folds the vmap axis into the batch
// dimension and re-invokes scaled_dot_product_attention, which recomputes
// the head counts from the actual input shapes.
//
// inputs: {q, k, v[, mask]} — q/k/v presumably [B, H, L, D] per the
//         merge/split comments below; TODO confirm mask broadcasting
//         shapes are always rank-4 here, since merge_batch collapses the
//         first two dimensions unconditionally.
// axes:   vmap axis per input, -1 for unmapped inputs.
// Returns the vmapped output(s) with out-axis 0.
std::pair<std::vector<array>, std::vector<int>> ScaledDotProductAttention::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  auto s = stream();

  // Sinks require 1-D input; fall back to generic vmap for that case.
  if (has_sinks_) {
    return Custom::vmap(inputs, axes);
  }

  // Determine vmap size from the first mapped input.
  int vmap_size = -1;
  for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
    if (axes[i] != -1) {
      vmap_size = inputs[i].shape(axes[i]);
      break;
    }
  }

  // Normalize each input so the vmap axis is axis 0: materialize unmapped
  // inputs by repeating, and move a non-leading vmap axis to the front.
  auto prepare = [&s, vmap_size](const array& x, int ax) -> array {
    if (ax == -1) {
      return repeat(expand_dims(x, 0, s), vmap_size, 0, s);
    }
    if (ax != 0) {
      return moveaxis(x, ax, 0, s);
    }
    return x;
  };

  auto q = prepare(inputs[0], axes[0]);
  auto k = prepare(inputs[1], axes[1]);
  auto v = prepare(inputs[2], axes[2]);

  // [V, B, H, L, D] -> [V*B, H, L, D]
  auto merge_batch = [&s, vmap_size](const array& x) -> array {
    auto shape = x.shape();
    Shape new_shape = {vmap_size * shape[1]};
    new_shape.insert(new_shape.end(), shape.begin() + 2, shape.end());
    return reshape(x, std::move(new_shape), s);
  };

  q = merge_batch(q);
  k = merge_batch(k);
  v = merge_batch(v);

  // An array mask, when present, is input 3 and gets the same
  // vmap-axis-to-batch treatment as q/k/v.
  std::optional<array> mask_arr;
  bool has_arr_mask = !do_causal_ && inputs.size() > 3;
  if (has_arr_mask) {
    mask_arr = merge_batch(prepare(inputs[3], axes[3]));
  }
  std::string mask_mode = do_causal_ ? "causal" : has_arr_mask ? "array" : "";

  // Re-invoke the public entry point so head counts (and the fused-kernel
  // dispatch decision) are derived from the merged shapes.
  auto out = scaled_dot_product_attention(
      q, k, v, scale_, mask_mode, mask_arr, std::nullopt, s);

  // [V*B, H, L, D] -> [V, B, H, L, D]
  auto split_batch = [&s, vmap_size](const array& x) -> array {
    auto shape = x.shape();
    Shape new_shape = {vmap_size, shape[0] / vmap_size};
    new_shape.insert(new_shape.end(), shape.begin() + 1, shape.end());
    return reshape(x, std::move(new_shape), s);
  };

  out = split_batch(out);

  // The re-invoked SDPA may produce a logsumexp sibling when training.
  if (output_logsumexp_) {
    assert(
        !out.siblings().empty() &&
        "vmap'd SDPA expected logsumexp sibling output");
    auto lse = split_batch(out.siblings()[0]);
    return {{out, lse}, {0, 0}};
  }

  return {{out}, {0}};
}
53130
array rms_norm(
54131
const array& x,
55132
const std::optional<array>& weight,

mlx/fast_primitives.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ class ScaledDotProductAttention : public Custom {
244244
const std::vector<int>& argnums,
245245
const std::vector<array>& outputs) override;
246246

  // Vmap rule that folds the vmap axis into the batch dimension and
  // re-invokes SDPA so head counts are recomputed from the actual shapes
  // (required for GQA/MQA where n_heads != n_kv_heads).
  std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes) override;
247251
bool is_equivalent(const Primitive& other) const override;
248252

249253
DEFINE_NAME(ScaledDotProductAttention);

python/tests/test_fast_sdpa.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,46 @@ def test_sdpa_sliced(self):
642642
tolerance = {"rtol": 1e-2, "atol": 1e-2}
643643
self.assertTrue(mx.allclose(ref, out, **tolerance))
644644

645+
@unittest.skipIf(not mx.metal.is_available(), "Metal kernel required")
646+
def test_sdpa_vmap_uses_fused_kernel(self):
647+
"""Verify vmap'd SDPA dispatches the fused Metal kernel, not the
648+
decomposed fallback. The fused kernel and the fallback use different
649+
accumulation orders, producing distinguishable float16 results for
650+
large enough shapes. We check that vmap output matches the
651+
non-vmapped (fused) output exactly."""
652+
D = 64
653+
L = 128 # L > 8 → sdpa_full kernel path
654+
scale = 1.0 / math.sqrt(D)
655+
B = 4
656+
657+
for n_q, n_kv in [(32, 32), (32, 8), (16, 4)]:
658+
with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv):
659+
mx.random.seed(42)
660+
q = mx.random.normal((B, n_q, L, D), dtype=mx.float16)
661+
k = mx.random.normal((B, n_kv, L, D), dtype=mx.float16)
662+
v = mx.random.normal((B, n_kv, L, D), dtype=mx.float16)
663+
mx.eval(q, k, v)
664+
665+
# Non-vmapped SDPA — uses fused Metal kernel
666+
kernel_out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
667+
668+
# Vmapped SDPA — should also use fused kernel
669+
def f(qi, ki, vi):
670+
return mx.fast.scaled_dot_product_attention(
671+
qi[None], ki[None], vi[None], scale=scale
672+
)[0]
673+
674+
vmap_out = mx.vmap(f)(q, k, v)
675+
mx.eval(kernel_out, vmap_out)
676+
677+
# Fused kernel output must match exactly (same kernel, same
678+
# accumulation). If vmap fell back to the decomposed path,
679+
# float16 rounding differences would cause a mismatch.
680+
self.assertTrue(
681+
mx.array_equal(kernel_out, vmap_out),
682+
f"vmap path did not use fused kernel for ({n_q},{n_kv})",
683+
)
684+
645685

646686
if __name__ == "__main__":
647687
mlx_tests.MLXTestRunner(failfast=True)

0 commit comments

Comments
 (0)