Commit 985fdd0
[Metal] Fused Flash Attention backward (VJP) kernels
Add fused Flash Attention backward pass (VJP) kernels for Apple Silicon
GPUs, implementing the two-kernel architecture from Flash Attention 2
(Dao, 2023). The fused backward eliminates O(L^2) attention matrix
materialization, reducing peak memory by 70-95% with auto-dispatch
routing between fused and unfused paths.
Key additions:
- Two Metal kernels: steel_attention_vjp_dq and steel_attention_vjp_dkv
- JIT compilation with baked constants for dead-code elimination
- Delta precomputation as lazy MLX graph ops
- Threadgroup memory aliasing (23KB -> 14.8KB for D=128 dKV)
- Sparse block mask support via function constant gating
- Auto-dispatch: causal L thresholds + 1GB memory ceiling
- Support for D={64,96,128}, float16/bfloat16, causal, GQA
- 2-pass vector kernel LSE output for VJP logsumexp
Performance (M3 Max, B=1 H=32 causal float16, fused vs unfused):
D=64: 1.22-1.40x faster
D=96: 1.29-1.35x faster
D=128: 0.77-0.81x (memory trade-off, 70-82% savings)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent 38ad257 commit 985fdd0
File tree
22 files changed
+3326
-126
lines changed- benchmarks/python
- mlx
- backend
- cuda
- metal
- jit
- kernels
- steel/attn
- kernels
- no_gpu
- python/tests
22 files changed
+3326
-126
lines changedLarge diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
551 | 551 | | |
552 | 552 | | |
553 | 553 | | |
554 | | - | |
555 | 554 | | |
556 | 555 | | |
557 | 556 | | |
| |||
614 | 613 | | |
615 | 614 | | |
616 | 615 | | |
617 | | - | |
| 616 | + | |
| 617 | + | |
| 618 | + | |
| 619 | + | |
| 620 | + | |
| 621 | + | |
| 622 | + | |
| 623 | + | |
618 | 624 | | |
619 | 625 | | |
620 | 626 | | |
| |||
630 | 636 | | |
631 | 637 | | |
632 | 638 | | |
633 | | - | |
634 | | - | |
| 639 | + | |
| 640 | + | |
635 | 641 | | |
636 | 642 | | |
637 | 643 | | |
| |||
647 | 653 | | |
648 | 654 | | |
649 | 655 | | |
650 | | - | |
| 656 | + | |
651 | 657 | | |
652 | 658 | | |
653 | 659 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
82 | 82 | | |
83 | 83 | | |
84 | 84 | | |
| 85 | + | |
| 86 | + | |
85 | 87 | | |
86 | 88 | | |
87 | 89 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
45 | 45 | | |
46 | 46 | | |
47 | 47 | | |
| 48 | + | |
| 49 | + | |
48 | 50 | | |
49 | 51 | | |
50 | 52 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
2 | 6 | | |
3 | 7 | | |
4 | 8 | | |
| |||
1076 | 1080 | | |
1077 | 1081 | | |
1078 | 1082 | | |
| 1083 | + | |
| 1084 | + | |
| 1085 | + | |
| 1086 | + | |
| 1087 | + | |
| 1088 | + | |
| 1089 | + | |
| 1090 | + | |
| 1091 | + | |
| 1092 | + | |
| 1093 | + | |
| 1094 | + | |
| 1095 | + | |
| 1096 | + | |
| 1097 | + | |
| 1098 | + | |
| 1099 | + | |
| 1100 | + | |
| 1101 | + | |
| 1102 | + | |
| 1103 | + | |
| 1104 | + | |
| 1105 | + | |
| 1106 | + | |
| 1107 | + | |
| 1108 | + | |
| 1109 | + | |
| 1110 | + | |
| 1111 | + | |
| 1112 | + | |
| 1113 | + | |
| 1114 | + | |
| 1115 | + | |
| 1116 | + | |
| 1117 | + | |
| 1118 | + | |
| 1119 | + | |
| 1120 | + | |
| 1121 | + | |
| 1122 | + | |
| 1123 | + | |
| 1124 | + | |
| 1125 | + | |
| 1126 | + | |
| 1127 | + | |
| 1128 | + | |
| 1129 | + | |
| 1130 | + | |
| 1131 | + | |
| 1132 | + | |
| 1133 | + | |
| 1134 | + | |
| 1135 | + | |
| 1136 | + | |
| 1137 | + | |
| 1138 | + | |
| 1139 | + | |
| 1140 | + | |
| 1141 | + | |
| 1142 | + | |
| 1143 | + | |
| 1144 | + | |
| 1145 | + | |
| 1146 | + | |
| 1147 | + | |
| 1148 | + | |
| 1149 | + | |
| 1150 | + | |
| 1151 | + | |
| 1152 | + | |
| 1153 | + | |
| 1154 | + | |
| 1155 | + | |
| 1156 | + | |
| 1157 | + | |
| 1158 | + | |
| 1159 | + | |
| 1160 | + | |
| 1161 | + | |
1079 | 1162 | | |
1080 | 1163 | | |
1081 | 1164 | | |
| |||
1087 | 1170 | | |
1088 | 1171 | | |
1089 | 1172 | | |
1090 | | - | |
1091 | | - | |
| 1173 | + | |
| 1174 | + | |
| 1175 | + | |
| 1176 | + | |
| 1177 | + | |
| 1178 | + | |
| 1179 | + | |
| 1180 | + | |
| 1181 | + | |
| 1182 | + | |
| 1183 | + | |
| 1184 | + | |
| 1185 | + | |
| 1186 | + | |
| 1187 | + | |
| 1188 | + | |
| 1189 | + | |
| 1190 | + | |
| 1191 | + | |
| 1192 | + | |
| 1193 | + | |
1092 | 1194 | | |
| 1195 | + | |
| 1196 | + | |
| 1197 | + | |
| 1198 | + | |
| 1199 | + | |
| 1200 | + | |
| 1201 | + | |
| 1202 | + | |
| 1203 | + | |
| 1204 | + | |
| 1205 | + | |
| 1206 | + | |
| 1207 | + | |
| 1208 | + | |
| 1209 | + | |
| 1210 | + | |
| 1211 | + | |
1093 | 1212 | | |
1094 | 1213 | | |
1095 | 1214 | | |
1096 | 1215 | | |
| 1216 | + | |
1097 | 1217 | | |
1098 | 1218 | | |
1099 | | - | |
| 1219 | + | |
1100 | 1220 | | |
1101 | 1221 | | |
1102 | 1222 | | |
| |||
1107 | 1227 | | |
1108 | 1228 | | |
1109 | 1229 | | |
1110 | | - | |
| 1230 | + | |
| 1231 | + | |
| 1232 | + | |
| 1233 | + | |
| 1234 | + | |
| 1235 | + | |
| 1236 | + | |
| 1237 | + | |
| 1238 | + | |
| 1239 | + | |
| 1240 | + | |
| 1241 | + | |
| 1242 | + | |
| 1243 | + | |
| 1244 | + | |
| 1245 | + | |
| 1246 | + | |
| 1247 | + | |
| 1248 | + | |
| 1249 | + | |
| 1250 | + | |
| 1251 | + | |
| 1252 | + | |
| 1253 | + | |
| 1254 | + | |
| 1255 | + | |
| 1256 | + | |
| 1257 | + | |
| 1258 | + | |
| 1259 | + | |
| 1260 | + | |
| 1261 | + | |
| 1262 | + | |
| 1263 | + | |
| 1264 | + | |
| 1265 | + | |
| 1266 | + | |
| 1267 | + | |
| 1268 | + | |
| 1269 | + | |
| 1270 | + | |
| 1271 | + | |
| 1272 | + | |
| 1273 | + | |
| 1274 | + | |
| 1275 | + | |
| 1276 | + | |
| 1277 | + | |
| 1278 | + | |
| 1279 | + | |
| 1280 | + | |
1111 | 1281 | | |
1112 | 1282 | | |
1113 | 1283 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
344 | 344 | | |
345 | 345 | | |
346 | 346 | | |
347 | | - | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
348 | 394 | | |
349 | 395 | | |
350 | 396 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
98 | 98 | | |
99 | 99 | | |
100 | 100 | | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
101 | 112 | | |
102 | 113 | | |
103 | 114 | | |
| |||
153 | 164 | | |
154 | 165 | | |
155 | 166 | | |
| 167 | + | |
| 168 | + | |
156 | 169 | | |
157 | 170 | | |
158 | 171 | | |
| |||
0 commit comments