File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -2838,10 +2838,12 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
28382838 res += ggml_type_size (GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2 ));
28392839 }
28402840
2841- // TurboFlash two-pass: always reserve partial result buffer to avoid graph reallocations
2842- // partial_out: float[n_bh * n_blocks * dv]
2843- // partial_ms: float[n_bh * n_blocks * 2] (max + sum per block)
2844- {
2841+ // TurboFlash two-pass temp is only needed when the TurboFlash path is eligible.
2842+ // Reserving it unconditionally can massively inflate graph scratch usage for
2843+ // large-context models even when the normal FA path is selected.
2844+ if (ggml_metal_op_flash_attn_ext_use_turbo_flash (op)) {
2845+ // partial_out: float[n_bh * n_blocks * dv]
2846+ // partial_ms: float[n_bh * n_blocks * 2] (max + sum per block)
28452847 const int64_t n_bh = ne01 * ne02 * ne03;
28462848 const int64_t ne11 = op->src [1 ]->ne [1 ]; // T_kv
28472849 const int64_t n_blocks = (ne11 + 63 ) / 64 ; // ceil(T_kv / 64)
You can’t perform that action at this time.
0 commit comments