Skip to content

Commit 6fbbec5

Browse files
committed
Fix SDPA vmap with GQA/MQA shapes (n_heads != n_kv_heads)
The ScaledDotProductAttention primitive relied on Custom::vmap which re-vmapped the fallback lambda. That lambda captured n_q_heads and n_kv_heads at creation time, causing shape mismatches (SIGSEGV/hang) when vmap changed the array dimensions. Add a dedicated vmap override that merges the vmap axis into the batch dimension and re-invokes scaled_dot_product_attention, which recomputes head counts from actual shapes. Falls back to Custom::vmap for sinks.
1 parent b98831a commit 6fbbec5

File tree

3 files changed

+145
-0
lines changed

3 files changed

+145
-0
lines changed

mlx/fast.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,83 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
5050
return {outputs, out_axes};
5151
}
5252

53+
std::pair<std::vector<array>, std::vector<int>> ScaledDotProductAttention::vmap(
54+
const std::vector<array>& inputs,
55+
const std::vector<int>& axes) {
56+
auto s = stream();
57+
58+
// Sinks require 1-D input; fall back to generic vmap for that case.
59+
if (has_sinks_) {
60+
return Custom::vmap(inputs, axes);
61+
}
62+
63+
// Determine vmap size from the first mapped input.
64+
int vmap_size = -1;
65+
for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
66+
if (axes[i] != -1) {
67+
vmap_size = inputs[i].shape(axes[i]);
68+
break;
69+
}
70+
}
71+
72+
auto prepare = [&s, vmap_size](const array& x, int ax) -> array {
73+
if (ax == -1) {
74+
return repeat(expand_dims(x, 0, s), vmap_size, 0, s);
75+
}
76+
if (ax != 0) {
77+
return moveaxis(x, ax, 0, s);
78+
}
79+
return x;
80+
};
81+
82+
auto q = prepare(inputs[0], axes[0]);
83+
auto k = prepare(inputs[1], axes[1]);
84+
auto v = prepare(inputs[2], axes[2]);
85+
86+
// [V, B, H, L, D] -> [V*B, H, L, D]
87+
auto merge_batch = [&s, vmap_size](const array& x) -> array {
88+
auto shape = x.shape();
89+
Shape new_shape = {vmap_size * shape[1]};
90+
new_shape.insert(new_shape.end(), shape.begin() + 2, shape.end());
91+
return reshape(x, std::move(new_shape), s);
92+
};
93+
94+
q = merge_batch(q);
95+
k = merge_batch(k);
96+
v = merge_batch(v);
97+
98+
std::optional<array> mask_arr;
99+
bool has_arr_mask = !do_causal_ && inputs.size() > 3;
100+
if (has_arr_mask) {
101+
mask_arr = merge_batch(prepare(inputs[3], axes[3]));
102+
}
103+
std::string mask_mode = do_causal_ ? "causal" : has_arr_mask ? "array" : "";
104+
105+
auto out = scaled_dot_product_attention(
106+
q, k, v, scale_, mask_mode, mask_arr, std::nullopt, s);
107+
108+
// [V*B, H, L, D] -> [V, B, H, L, D]
109+
auto split_batch = [&s, vmap_size](const array& x) -> array {
110+
auto shape = x.shape();
111+
Shape new_shape = {vmap_size, shape[0] / vmap_size};
112+
new_shape.insert(new_shape.end(), shape.begin() + 1, shape.end());
113+
return reshape(x, std::move(new_shape), s);
114+
};
115+
116+
out = split_batch(out);
117+
118+
// The re-invoked SDPA may produce a logsumexp sibling when training.
119+
if (output_logsumexp_) {
120+
assert(
121+
!out.siblings().empty() &&
122+
"vmap'd SDPA expected logsumexp sibling output");
123+
auto lse = split_batch(out.siblings()[0]);
124+
return {{out, lse}, {0, 0}};
125+
}
126+
127+
return {{out}, {0}};
128+
}
129+
53130
array rms_norm(
54131
const array& x,
55132
const std::optional<array>& weight,

mlx/fast_primitives.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ class ScaledDotProductAttention : public Custom {
244244
const std::vector<int>& argnums,
245245
const std::vector<array>& outputs) override;
246246

247+
std::pair<std::vector<array>, std::vector<int>> vmap(
248+
const std::vector<array>& inputs,
249+
const std::vector<int>& axes) override;
250+
247251
bool is_equivalent(const Primitive& other) const override;
248252

249253
DEFINE_NAME(ScaledDotProductAttention);

python/tests/test_fast_sdpa.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,70 @@ def test_sdpa_sliced(self):
642642
tolerance = {"rtol": 1e-2, "atol": 1e-2}
643643
self.assertTrue(mx.allclose(ref, out, **tolerance))
644644

645+
def test_sdpa_vmap_gqa(self):
    """Regression test for https://github.com/ml-explore/mlx/issues/3383"""
    D = 64
    scale = 1.0 / math.sqrt(D)
    B = 2
    atol = 1e-5

    # Each entry: (n_q_heads, n_kv_heads, seq_len, mask)
    configs = [
        (4, 4, 4, None),
        (4, 2, 4, None),
        (4, 1, 4, None),
        (8, 2, 4, None),
        (4, 2, 8, "causal"),
        (4, 1, 8, "causal"),
    ]

    for n_q, n_kv, L, mask in configs:
        with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv, L=L, mask=mask):
            q = mx.random.normal((B, n_q, L, D))
            k = mx.random.normal((B, n_kv, L, D))
            v = mx.random.normal((B, n_kv, L, D))
            mx.eval(q, k, v)

            def f(qi, ki, vi):
                # Per-example SDPA; vmap over the batch axis should match
                # running each example independently.
                attn = mx.fast.scaled_dot_product_attention(
                    qi[None], ki[None], vi[None], scale=scale, mask=mask
                )
                return attn[0]

            per_example = [f(q[i], k[i], v[i]) for i in range(B)]
            ref = mx.stack(per_example)
            out = mx.vmap(f)(q, k, v)
            mx.eval(out)

            self.assertListEqual(list(ref.shape), list(out.shape))
            self.assertTrue(mx.allclose(ref, out, atol=atol, rtol=1e-3))
680+
681+
def test_sdpa_vmap_gqa_grad(self):
    """Regression test for https://github.com/ml-explore/mlx/issues/3383"""
    D = 64
    scale = 1.0 / math.sqrt(D)
    B = 2
    tolerance = {"rtol": 1e-2, "atol": 1e-2}

    head_configs = [(4, 4), (4, 2), (4, 1)]
    for n_q, n_kv in head_configs:
        with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv):
            q = mx.random.normal((B, n_q, 4, D))
            k = mx.random.normal((B, n_kv, 4, D))
            v = mx.random.normal((B, n_kv, 4, D))
            mx.eval(q, k, v)

            def loss(qi, ki, vi):
                # Scalar objective so grad w.r.t. q is well defined.
                attn = mx.fast.scaled_dot_product_attention(
                    qi[None], ki[None], vi[None], scale=scale
                )
                return mx.mean(attn)

            grad_fn = mx.grad(loss)
            ref = mx.stack([grad_fn(q[i], k[i], v[i]) for i in range(B)])
            out = mx.vmap(grad_fn)(q, k, v)
            mx.eval(out)

            self.assertListEqual(list(ref.shape), list(out.shape))
            self.assertTrue(mx.allclose(ref, out, **tolerance))
708+
645709

646710
if __name__ == "__main__":
647711
mlx_tests.MLXTestRunner(failfast=True)

0 commit comments

Comments
 (0)