diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index a668fe9abd..bdb875691d 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -50,6 +50,83 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
   return {outputs, out_axes};
 }
 
+std::pair<std::vector<array>, std::vector<int>> ScaledDotProductAttention::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto s = stream();
+
+  // Sinks require 1-D input; fall back to generic vmap for that case.
+  if (has_sinks_) {
+    return Custom::vmap(inputs, axes);
+  }
+
+  // Determine vmap size from the first mapped input.
+  int vmap_size = -1;
+  for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
+    if (axes[i] != -1) {
+      vmap_size = inputs[i].shape(axes[i]);
+      break;
+    }
+  }
+
+  auto prepare = [&s, vmap_size](const array& x, int ax) -> array {
+    if (ax == -1) {
+      return repeat(expand_dims(x, 0, s), vmap_size, 0, s);
+    }
+    if (ax != 0) {
+      return moveaxis(x, ax, 0, s);
+    }
+    return x;
+  };
+
+  auto q = prepare(inputs[0], axes[0]);
+  auto k = prepare(inputs[1], axes[1]);
+  auto v = prepare(inputs[2], axes[2]);
+
+  // [V, B, H, L, D] -> [V*B, H, L, D]
+  auto merge_batch = [&s, vmap_size](const array& x) -> array {
+    auto shape = x.shape();
+    Shape new_shape = {vmap_size * shape[1]};
+    new_shape.insert(new_shape.end(), shape.begin() + 2, shape.end());
+    return reshape(x, std::move(new_shape), s);
+  };
+
+  q = merge_batch(q);
+  k = merge_batch(k);
+  v = merge_batch(v);
+
+  std::optional<array> mask_arr;
+  bool has_arr_mask = !do_causal_ && inputs.size() > 3;
+  if (has_arr_mask) {
+    mask_arr = merge_batch(prepare(inputs[3], axes[3]));
+  }
+  std::string mask_mode = do_causal_ ? "causal" : has_arr_mask ? "array" : "";
+
+  auto out = scaled_dot_product_attention(
+      q, k, v, scale_, mask_mode, mask_arr, std::nullopt, s);
+
+  // [V*B, H, L, D] -> [V, B, H, L, D]
+  auto split_batch = [&s, vmap_size](const array& x) -> array {
+    auto shape = x.shape();
+    Shape new_shape = {vmap_size, shape[0] / vmap_size};
+    new_shape.insert(new_shape.end(), shape.begin() + 1, shape.end());
+    return reshape(x, std::move(new_shape), s);
+  };
+
+  out = split_batch(out);
+
+  // The re-invoked SDPA may produce a logsumexp sibling when training.
+  if (output_logsumexp_) {
+    assert(
+        !out.siblings().empty() &&
+        "vmap'd SDPA expected logsumexp sibling output");
+    auto lse = split_batch(out.siblings()[0]);
+    return {{out, lse}, {0, 0}};
+  }
+
+  return {{out}, {0}};
+}
+
 array rms_norm(
     const array& x,
     const std::optional<array>& weight,
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 4434830875..0489e85d83 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -244,6 +244,10 @@ class ScaledDotProductAttention : public Custom {
       const std::vector<int>& argnums,
       const std::vector<array>& outputs) override;
 
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
   bool is_equivalent(const Primitive& other) const override;
 
   DEFINE_NAME(ScaledDotProductAttention);
diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py
index 7606373ce4..f9ce077ac3 100644
--- a/python/tests/test_fast_sdpa.py
+++ b/python/tests/test_fast_sdpa.py
@@ -642,6 +642,46 @@ def test_sdpa_sliced(self):
         tolerance = {"rtol": 1e-2, "atol": 1e-2}
         self.assertTrue(mx.allclose(ref, out, **tolerance))
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal kernel required")
+    def test_sdpa_vmap_uses_fused_kernel(self):
+        """Verify vmap'd SDPA dispatches the fused Metal kernel, not the
+        decomposed fallback. The fused kernel and the fallback use different
+        accumulation orders, producing distinguishable float16 results for
+        large enough shapes. We check that vmap output matches the
+        non-vmapped (fused) output exactly."""
+        D = 64
+        L = 128  # L > 8 → sdpa_full kernel path
+        scale = 1.0 / math.sqrt(D)
+        B = 4
+
+        for n_q, n_kv in [(32, 32), (32, 8), (16, 4)]:
+            with self.subTest(n_q_heads=n_q, n_kv_heads=n_kv):
+                mx.random.seed(42)
+                q = mx.random.normal((B, n_q, L, D), dtype=mx.float16)
+                k = mx.random.normal((B, n_kv, L, D), dtype=mx.float16)
+                v = mx.random.normal((B, n_kv, L, D), dtype=mx.float16)
+                mx.eval(q, k, v)
+
+                # Non-vmapped SDPA — uses fused Metal kernel
+                kernel_out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
+
+                # Vmapped SDPA — should also use fused kernel
+                def f(qi, ki, vi):
+                    return mx.fast.scaled_dot_product_attention(
+                        qi[None], ki[None], vi[None], scale=scale
+                    )[0]
+
+                vmap_out = mx.vmap(f)(q, k, v)
+                mx.eval(kernel_out, vmap_out)
+
+                # Fused kernel output must match exactly (same kernel, same
+                # accumulation). If vmap fell back to the decomposed path,
+                # float16 rounding differences would cause a mismatch.
+                self.assertTrue(
+                    mx.array_equal(kernel_out, vmap_out),
+                    f"vmap path did not use fused kernel for ({n_q},{n_kv})",
+                )
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)