Skip to content

Commit f2f1380

Browse files
Brooooooklyn authored and claude committed
[Metal] Fused Flash Attention backward (VJP) kernels
Add fused Flash Attention backward pass (VJP) kernels for Apple Silicon GPUs, implementing the two-kernel architecture from Flash Attention 2 (Dao, 2023). The fused backward eliminates O(L^2) attention matrix materialization, reducing peak memory by 70-95% with auto-dispatch routing between fused and unfused paths. Key additions: - Two Metal kernels: steel_attention_vjp_dq and steel_attention_vjp_dkv - JIT compilation with baked constants for dead-code elimination - Delta precomputation as lazy MLX graph ops - Threadgroup memory aliasing (23KB -> 14.8KB for D=128 dKV) - Sparse block mask support via function constant gating - Auto-dispatch: causal L thresholds + 1GB memory ceiling - Support for D={64,96,128}, float16/bfloat16, causal, GQA - 2-pass vector kernel LSE output for VJP logsumexp Performance (M3 Max, B=1 H=32 causal float16, fused vs unfused): D=64: 1.22-1.40x faster D=96: 1.29-1.35x faster D=128: 0.77-0.81x (memory trade-off, 70-82% savings) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fd6d304 commit f2f1380

22 files changed

+3326
-125
lines changed

benchmarks/python/sdpa_vector_vjp_bench.py

Lines changed: 430 additions & 0 deletions
Large diffs are not rendered by default.

mlx/backend/cuda/scaled_dot_product_attention.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,14 @@ void ScaledDotProductAttention::eval_gpu(
626626
}
627627
}
628628

629-
bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) {
629+
bool ScaledDotProductAttentionVJP::use_fallback(
630+
const array& q,
631+
const array& /* k */,
632+
Stream s,
633+
bool,
634+
bool,
635+
bool,
636+
int) {
630637
// The frontend adds a padding mask when sequence length is not a multiple of
631638
// tile size.
632639
if (q.shape(2) % 128 != 0) {
@@ -642,8 +649,8 @@ void ScaledDotProductAttentionVJP::eval_gpu(
642649

643650
auto& s = stream();
644651

645-
assert(inputs.size() >= 6);
646-
int primals_size = inputs.size() - 3;
652+
assert(inputs.size() >= 7); // primals(>=3) + O + LSE + dO + delta
653+
int primals_size = inputs.size() - 4;
647654
bool has_arr_mask = primals_size > 3 + has_sinks_;
648655

649656
array q = prepare_sdpa_input(inputs[0], s);
@@ -659,7 +666,7 @@ void ScaledDotProductAttentionVJP::eval_gpu(
659666
}
660667
std::optional<array> sinks;
661668
if (has_sinks_) {
662-
sinks = prepare_sdpa_sinks(inputs.back(), s);
669+
sinks = prepare_sdpa_sinks(inputs[primals_size - 1], s);
663670
}
664671

665672
assert(outputs.size() == 3);

mlx/backend/metal/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ if(MLX_METAL_JIT)
8282
make_jit_source(gemv_masked)
8383

8484
make_jit_source(steel/attn/kernels/steel_attention)
85+
make_jit_source(steel/attn/kernels/steel_attention_vjp_dq)
86+
make_jit_source(steel/attn/kernels/steel_attention_vjp_dkv)
8587

8688
make_jit_source(
8789
steel/gemm/gemm_nax kernels/steel/utils.h kernels/steel/gemm/nax.h

mlx/backend/metal/jit/includes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ const char* steel_conv_3d();
4545
const char* steel_conv_general();
4646
const char* gemv_masked();
4747
const char* steel_attention();
48+
const char* steel_attention_vjp_dq();
49+
const char* steel_attention_vjp_dkv();
4850

4951
const char* gemm_nax();
5052
const char* steel_gemm_fused_nax();

mlx/backend/metal/jit_kernels.cpp

Lines changed: 174 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
// Copyright © 2024 Apple Inc.
2+
#include <cmath>
3+
#include <cstdio>
4+
#include <cstring>
5+
26
#include "mlx/backend/common/compiled.h"
37
#include "mlx/backend/metal/jit/includes.h"
48
#include "mlx/backend/metal/kernels.h"
@@ -1076,6 +1080,85 @@ MTL::ComputePipelineState* get_gather_qmm_nax_kernel(
10761080
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
10771081
}
10781082

1083+
namespace {
1084+
1085+
// Produce a valid MSL float literal with enough precision for exact round-trip.
1086+
std::string float_to_msl(float v) {
1087+
char buf[32];
1088+
std::snprintf(buf, sizeof(buf), "%.9gf", v);
1089+
return std::string(buf);
1090+
}
1091+
1092+
// Encode a float's bits as hex for use in cache keys (exact matching).
1093+
std::string scale_to_hex(float v) {
1094+
uint32_t bits;
1095+
std::memcpy(&bits, &v, sizeof(bits));
1096+
char buf[16];
1097+
std::snprintf(buf, sizeof(buf), "%08x", bits);
1098+
return std::string(buf);
1099+
}
1100+
1101+
// Shared implementation for VJP dQ/dKV JIT kernel dispatch.
1102+
// Both kernels use identical lib_name construction, #define baking, and caching.
1103+
MTL::ComputePipelineState* get_steel_attention_vjp_kernel_impl(
1104+
metal::Device& d,
1105+
const std::string& kernel_name,
1106+
const array& q,
1107+
int bq,
1108+
int bk,
1109+
int bd,
1110+
int wm,
1111+
int wn,
1112+
int gqa_factor,
1113+
float scale,
1114+
float scale_log2,
1115+
bool align_Q,
1116+
bool align_K,
1117+
bool do_causal,
1118+
bool has_block_mask,
1119+
const char* shader_source(),
1120+
const char* template_name) {
1121+
std::string lib_name = kernel_name;
1122+
lib_name += "_gqa" + std::to_string(gqa_factor);
1123+
lib_name += "_s" + scale_to_hex(scale);
1124+
lib_name += align_Q ? "_aQ" : "_nQ";
1125+
lib_name += align_K ? "_aK" : "_nK";
1126+
lib_name += do_causal ? "_c" : "_nc";
1127+
lib_name += has_block_mask ? "_bm" : "_nbm";
1128+
1129+
auto lib = d.get_library(lib_name, [&]() {
1130+
std::string defines;
1131+
defines += "#define VJP_GQA_FACTOR " + std::to_string(gqa_factor) + "\n";
1132+
defines += "#define VJP_SCALE " + float_to_msl(scale) + "\n";
1133+
defines += "#define VJP_SCALE_LOG2 " + float_to_msl(scale_log2) + "\n";
1134+
defines += "#define VJP_BAKED_FC 1\n";
1135+
defines +=
1136+
"#define VJP_ALIGN_Q " + std::string(align_Q ? "true" : "false") + "\n";
1137+
defines +=
1138+
"#define VJP_ALIGN_K " + std::string(align_K ? "true" : "false") + "\n";
1139+
defines += "#define VJP_DO_CAUSAL " +
1140+
std::string(do_causal ? "true" : "false") + "\n";
1141+
defines += "#define VJP_HAS_BLOCK_MASK " +
1142+
std::string(has_block_mask ? "true" : "false") + "\n";
1143+
1144+
std::string kernel_source;
1145+
concatenate(
1146+
kernel_source,
1147+
metal::utils(),
1148+
defines,
1149+
shader_source(),
1150+
get_template_definition(
1151+
kernel_name,
1152+
template_name,
1153+
get_type_string(q.dtype()),
1154+
bq, bk, bd, wm, wn));
1155+
return kernel_source;
1156+
});
1157+
return d.get_kernel(kernel_name, lib);
1158+
}
1159+
1160+
} // namespace
1161+
10791162
MTL::ComputePipelineState* get_steel_attention_kernel(
10801163
metal::Device& d,
10811164
const std::string& kernel_name,
@@ -1087,16 +1170,53 @@ MTL::ComputePipelineState* get_steel_attention_kernel(
10871170
int bd,
10881171
int wm,
10891172
int wn,
1090-
const array& m) {
1091-
const auto& lib_name = kernel_name;
1173+
const array& m,
1174+
int gqa_factor,
1175+
float scale,
1176+
bool align_Q,
1177+
bool align_K,
1178+
bool has_mask,
1179+
bool do_causal,
1180+
bool has_sinks,
1181+
bool output_logsumexp) {
1182+
std::string lib_name = kernel_name;
1183+
lib_name += "_gqa" + std::to_string(gqa_factor);
1184+
lib_name += "_s" + scale_to_hex(scale);
1185+
lib_name += align_Q ? "_aQ" : "_nQ";
1186+
lib_name += align_K ? "_aK" : "_nK";
1187+
lib_name += has_mask ? "_m" : "_nm";
1188+
lib_name += do_causal ? "_c" : "_nc";
1189+
lib_name += has_sinks ? "_sk" : "_nsk";
1190+
lib_name += output_logsumexp ? "_lse" : "_nlse";
1191+
1192+
float scale_log2 = static_cast<float>(scale * M_LOG2E);
1193+
10921194
auto lib = d.get_library(lib_name, [&]() {
1195+
std::string defines;
1196+
defines += "#define FWD_GQA_FACTOR " + std::to_string(gqa_factor) + "\n";
1197+
defines += "#define FWD_SCALE_LOG2 " + float_to_msl(scale_log2) + "\n";
1198+
defines += "#define FWD_BAKED_FC 1\n";
1199+
defines +=
1200+
"#define FWD_ALIGN_Q " + std::string(align_Q ? "true" : "false") + "\n";
1201+
defines +=
1202+
"#define FWD_ALIGN_K " + std::string(align_K ? "true" : "false") + "\n";
1203+
defines +=
1204+
"#define FWD_HAS_MASK " + std::string(has_mask ? "true" : "false") +
1205+
"\n";
1206+
defines += "#define FWD_DO_CAUSAL " +
1207+
std::string(do_causal ? "true" : "false") + "\n";
1208+
defines += "#define FWD_HAS_SINKS " +
1209+
std::string(has_sinks ? "true" : "false") + "\n";
1210+
defines += "#define FWD_OUTPUT_LOGSUMEXP " +
1211+
std::string(output_logsumexp ? "true" : "false") + "\n";
10931212
std::string kernel_source;
10941213
concatenate(
10951214
kernel_source,
10961215
metal::utils(),
1216+
defines,
10971217
metal::steel_attention(),
10981218
get_template_definition(
1099-
lib_name,
1219+
kernel_name,
11001220
"attention",
11011221
get_type_string(q.dtype()),
11021222
bq,
@@ -1107,7 +1227,57 @@ MTL::ComputePipelineState* get_steel_attention_kernel(
11071227
get_type_string(m.dtype())));
11081228
return kernel_source;
11091229
});
1110-
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
1230+
return d.get_kernel(kernel_name, lib);
1231+
}
1232+
1233+
MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel(
1234+
metal::Device& d,
1235+
const std::string& kernel_name,
1236+
const std::string& /*hash_name*/,
1237+
const metal::MTLFCList& /*func_consts*/,
1238+
const array& q,
1239+
int bq,
1240+
int bk,
1241+
int bd,
1242+
int wm,
1243+
int wn,
1244+
int gqa_factor,
1245+
float scale,
1246+
float scale_log2,
1247+
bool align_Q,
1248+
bool align_K,
1249+
bool do_causal,
1250+
bool has_block_mask) {
1251+
return get_steel_attention_vjp_kernel_impl(
1252+
d, kernel_name, q, bq, bk, bd, wm, wn,
1253+
gqa_factor, scale, scale_log2, align_Q, align_K, do_causal,
1254+
has_block_mask,
1255+
metal::steel_attention_vjp_dq, "attention_vjp_dq");
1256+
}
1257+
1258+
MTL::ComputePipelineState* get_steel_attention_vjp_dkv_kernel(
1259+
metal::Device& d,
1260+
const std::string& kernel_name,
1261+
const std::string& /*hash_name*/,
1262+
const metal::MTLFCList& /*func_consts*/,
1263+
const array& q,
1264+
int bq,
1265+
int bk,
1266+
int bd,
1267+
int wm,
1268+
int wn,
1269+
int gqa_factor,
1270+
float scale,
1271+
float scale_log2,
1272+
bool align_Q,
1273+
bool align_K,
1274+
bool do_causal,
1275+
bool has_block_mask) {
1276+
return get_steel_attention_vjp_kernel_impl(
1277+
d, kernel_name, q, bq, bk, bd, wm, wn,
1278+
gqa_factor, scale, scale_log2, align_Q, align_K, do_causal,
1279+
has_block_mask,
1280+
metal::steel_attention_vjp_dkv, "attention_vjp_dkv");
11111281
}
11121282

11131283
MTL::ComputePipelineState* get_steel_attention_nax_kernel(

mlx/backend/metal/kernels.h

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,53 @@ MTL::ComputePipelineState* get_steel_attention_kernel(
344344
int bd,
345345
int wm,
346346
int wn,
347-
const array& m);
347+
const array& m,
348+
int gqa_factor,
349+
float scale,
350+
bool align_Q,
351+
bool align_K,
352+
bool has_mask,
353+
bool do_causal,
354+
bool has_sinks,
355+
bool output_logsumexp);
356+
357+
MTL::ComputePipelineState* get_steel_attention_vjp_dq_kernel(
358+
metal::Device& d,
359+
const std::string& kernel_name,
360+
const std::string& hash_name,
361+
const metal::MTLFCList& func_consts,
362+
const array& q,
363+
int bq,
364+
int bk,
365+
int bd,
366+
int wm,
367+
int wn,
368+
int gqa_factor,
369+
float scale,
370+
float scale_log2,
371+
bool align_Q,
372+
bool align_K,
373+
bool do_causal,
374+
bool has_block_mask = false);
375+
376+
MTL::ComputePipelineState* get_steel_attention_vjp_dkv_kernel(
377+
metal::Device& d,
378+
const std::string& kernel_name,
379+
const std::string& hash_name,
380+
const metal::MTLFCList& func_consts,
381+
const array& q,
382+
int bq,
383+
int bk,
384+
int bd,
385+
int wm,
386+
int wn,
387+
int gqa_factor,
388+
float scale,
389+
float scale_log2,
390+
bool align_Q,
391+
bool align_K,
392+
bool do_causal,
393+
bool has_block_mask = false);
348394

349395
MTL::ComputePipelineState* get_steel_attention_nax_kernel(
350396
metal::Device& d,

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ set(STEEL_ATTN_HEADERS
9898
steel/attn/transforms.h
9999
steel/attn/kernels/steel_attention.h)
100100

101+
set(STEEL_ATTN_VJP_HEADERS
102+
steel/defines.h
103+
steel/utils.h
104+
steel/attn/attn.h
105+
steel/attn/loader.h
106+
steel/attn/mma.h
107+
steel/attn/transforms.h
108+
steel/attn/params.h
109+
steel/attn/kernels/steel_attention_vjp_dq.h
110+
steel/attn/kernels/steel_attention_vjp_dkv.h)
111+
101112
set(STEEL_NAX_HEADERS
102113
steel/defines.h
103114
steel/utils.h
@@ -153,6 +164,8 @@ if(NOT MLX_METAL_JIT)
153164
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
154165
build_kernel(gemv_masked steel/utils.h)
155166
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
167+
build_kernel(steel/attn/kernels/steel_attention_vjp_dq ${STEEL_ATTN_VJP_HEADERS})
168+
build_kernel(steel/attn/kernels/steel_attention_vjp_dkv ${STEEL_ATTN_VJP_HEADERS})
156169

157170
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
158171
26.2))

0 commit comments

Comments (0)