Skip to content

Commit 4e42445

Browse files
committed
Fix SDPA vmap with GQA/MQA shapes (n_heads != n_kv_heads)
The ScaledDotProductAttention primitive relied on Custom::vmap which re-vmapped the fallback lambda. That lambda captured n_q_heads and n_kv_heads at creation time, causing shape mismatches (SIGSEGV/hang) when vmap changed the array dimensions. Add a dedicated vmap override that merges the vmap axis into the batch dimension and re-invokes scaled_dot_product_attention, which recomputes head counts from actual shapes. Falls back to Custom::vmap for sinks.
1 parent b98831a commit 4e42445

3 files changed

Lines changed: 191 additions & 0 deletions

File tree

mlx/fast.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,83 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
5050
return {outputs, out_axes};
5151
}
5252

53+
std::pair<std::vector<array>, std::vector<int>> ScaledDotProductAttention::vmap(
54+
const std::vector<array>& inputs,
55+
const std::vector<int>& axes) {
56+
auto s = stream();
57+
58+
// Sinks require 1-D input; fall back to generic vmap for that case.
59+
if (has_sinks_) {
60+
return Custom::vmap(inputs, axes);
61+
}
62+
63+
// Determine vmap size from the first mapped input.
64+
int vmap_size = -1;
65+
for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
66+
if (axes[i] != -1) {
67+
vmap_size = inputs[i].shape(axes[i]);
68+
break;
69+
}
70+
}
71+
72+
auto prepare = [&s, vmap_size](const array& x, int ax) -> array {
73+
if (ax == -1) {
74+
return repeat(expand_dims(x, 0, s), vmap_size, 0, s);
75+
}
76+
if (ax != 0) {
77+
return moveaxis(x, ax, 0, s);
78+
}
79+
return x;
80+
};
81+
82+
auto q = prepare(inputs[0], axes[0]);
83+
auto k = prepare(inputs[1], axes[1]);
84+
auto v = prepare(inputs[2], axes[2]);
85+
86+
// [V, B, H, L, D] -> [V*B, H, L, D]
87+
auto merge_batch = [&s, vmap_size](const array& x) -> array {
88+
auto shape = x.shape();
89+
Shape new_shape = {vmap_size * shape[1]};
90+
new_shape.insert(new_shape.end(), shape.begin() + 2, shape.end());
91+
return reshape(x, std::move(new_shape), s);
92+
};
93+
94+
q = merge_batch(q);
95+
k = merge_batch(k);
96+
v = merge_batch(v);
97+
98+
std::optional<array> mask_arr;
99+
bool has_arr_mask = !do_causal_ && inputs.size() > 3;
100+
if (has_arr_mask) {
101+
mask_arr = merge_batch(prepare(inputs[3], axes[3]));
102+
}
103+
std::string mask_mode = do_causal_ ? "causal" : has_arr_mask ? "array" : "";
104+
105+
auto out = scaled_dot_product_attention(
106+
q, k, v, scale_, mask_mode, mask_arr, std::nullopt, s);
107+
108+
// [V*B, H, L, D] -> [V, B, H, L, D]
109+
auto split_batch = [&s, vmap_size](const array& x) -> array {
110+
auto shape = x.shape();
111+
Shape new_shape = {vmap_size, shape[0] / vmap_size};
112+
new_shape.insert(new_shape.end(), shape.begin() + 1, shape.end());
113+
return reshape(x, std::move(new_shape), s);
114+
};
115+
116+
out = split_batch(out);
117+
118+
// The re-invoked SDPA may produce a logsumexp sibling when training.
119+
if (output_logsumexp_) {
120+
assert(
121+
!out.siblings().empty() &&
122+
"vmap'd SDPA expected logsumexp sibling output");
123+
auto lse = split_batch(out.siblings()[0]);
124+
return {{out, lse}, {0, 0}};
125+
}
126+
127+
return {{out}, {0}};
128+
}
129+
53130
array rms_norm(
54131
const array& x,
55132
const std::optional<array>& weight,

mlx/fast_primitives.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ class ScaledDotProductAttention : public Custom {
244244
const std::vector<int>& argnums,
245245
const std::vector<array>& outputs) override;
246246

247+
// Vmap override: merges the vmap axis into the batch dimension and
// re-invokes scaled_dot_product_attention so head counts are recomputed
// from actual shapes; falls back to Custom::vmap when sinks are present.
std::pair<std::vector<array>, std::vector<int>> vmap(
248+
const std::vector<array>& inputs,
249+
const std::vector<int>& axes) override;
250+
247251
bool is_equivalent(const Primitive& other) const override;
248252

249253
DEFINE_NAME(ScaledDotProductAttention);

python/tests/test_fast_sdpa.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,116 @@ def test_sdpa_sliced(self):
642642
tolerance = {"rtol": 1e-2, "atol": 1e-2}
643643
self.assertTrue(mx.allclose(ref, out, **tolerance))
644644

645+
def test_sdpa_vmap_gqa(self):
    """vmap over SDPA with grouped/multi-query head layouts.

    Regression test for https://github.com/ml-explore/mlx/issues/3383:
    n_heads != n_kv_heads used to crash or hang under vmap.
    """
    head_dim = 64
    seq_len = 4
    scale = 1.0 / math.sqrt(head_dim)

    # (n_q_heads, n_kv_heads): MHA baseline, then GQA/MQA layouts,
    # then GQA with a larger head ratio.
    for n_q, n_kv in [(4, 4), (4, 2), (4, 1), (8, 2)]:
        with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv):
            batch = 2
            q = mx.random.normal((batch, n_q, seq_len, head_dim))
            k = mx.random.normal((batch, n_kv, seq_len, head_dim))
            v = mx.random.normal((batch, n_kv, seq_len, head_dim))
            mx.eval(q, k, v)

            def f(qi, ki, vi):
                return mx.fast.scaled_dot_product_attention(
                    qi[None], ki[None], vi[None], scale=scale
                )[0]

            # Reference: apply f to each batch element separately.
            ref = mx.stack([f(q[i], k[i], v[i]) for i in range(batch)])

            # vmap version must match the per-element reference.
            out = mx.vmap(f)(q, k, v)
            mx.eval(out)

            self.assertListEqual(list(ref.shape), list(out.shape))
            self.assertTrue(
                mx.allclose(ref, out, atol=1e-5, rtol=1e-3),
                f"vmap output mismatch for n_q={n_q}, n_kv={n_kv}",
            )
690+
691+
def test_sdpa_vmap_gqa_grad(self):
    """vmap(grad) over SDPA with grouped/multi-query head layouts.

    Regression test for https://github.com/ml-explore/mlx/issues/3383
    """
    head_dim = 64
    seq_len = 4
    scale = 1.0 / math.sqrt(head_dim)

    for n_q, n_kv in [(4, 4), (4, 2), (4, 1)]:
        with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv):
            batch = 2
            q = mx.random.normal((batch, n_q, seq_len, head_dim))
            k = mx.random.normal((batch, n_kv, seq_len, head_dim))
            v = mx.random.normal((batch, n_kv, seq_len, head_dim))
            mx.eval(q, k, v)

            def loss(qi, ki, vi):
                attn = mx.fast.scaled_dot_product_attention(
                    qi[None], ki[None], vi[None], scale=scale
                )
                return mx.mean(attn)

            out = mx.vmap(mx.grad(loss))(q, k, v)
            mx.eval(out)

            # Gradient w.r.t. q must carry q's shape.
            self.assertEqual(out.shape, q.shape)
721+
722+
def test_sdpa_vmap_gqa_with_mask(self):
    """vmap over causal-masked SDPA with GQA/MQA head layouts.

    Regression test for https://github.com/ml-explore/mlx/issues/3383
    """
    head_dim = 64
    seq_len = 8
    scale = 1.0 / math.sqrt(head_dim)

    for n_q, n_kv in [(4, 2), (4, 1)]:
        with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv):
            batch = 2
            q = mx.random.normal((batch, n_q, seq_len, head_dim))
            k = mx.random.normal((batch, n_kv, seq_len, head_dim))
            v = mx.random.normal((batch, n_kv, seq_len, head_dim))
            mx.eval(q, k, v)

            def f(qi, ki, vi):
                return mx.fast.scaled_dot_product_attention(
                    qi[None], ki[None], vi[None], scale=scale, mask="causal"
                )[0]

            # Per-element reference versus the vmapped call.
            ref = mx.stack([f(q[i], k[i], v[i]) for i in range(batch)])

            out = mx.vmap(f)(q, k, v)
            mx.eval(out)

            self.assertListEqual(list(ref.shape), list(out.shape))
            self.assertTrue(
                mx.allclose(ref, out, atol=1e-5, rtol=1e-3),
                f"vmap+causal mismatch for n_q={n_q}, n_kv={n_kv}",
            )
754+
645755

646756
if __name__ == "__main__":
647757
mlx_tests.MLXTestRunner(failfast=True)

0 commit comments

Comments
 (0)