Commit af26ee9

Re-enable warp-agnostic ROCm SDPA kernel
Re-enable the optimized SDPA kernel with the warp-size agnostic implementation. The kernel uses 32-thread tiles for consistent behavior across RDNA and CDNA architectures. The memory fault issue appears to be elsewhere in the inference pipeline, not in SDPA.
1 parent a6bf8cb commit af26ee9
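
The commit message describes 32-thread tiles that behave the same whether the hardware wavefront is 32 lanes wide (common on RDNA) or 64 lanes wide (CDNA). The kernel itself is not part of this diff, so the following is only a host-side C++ sketch of that idea, not the actual implementation. The names `kTileSize`, `TileCoord`, `tile_coord`, and `tile_all_reduce_sum` are hypothetical. It shows that tile membership depends only on the fixed tile width, and simulates a butterfly all-reduce confined to one tile.

```cpp
#include <array>

// Assumed tile width, taken from the commit message ("32-thread tiles").
constexpr int kTileSize = 32;

// Hypothetical helper type: the tile a thread belongs to and its lane in it.
struct TileCoord {
  int tile;
  int lane;
};

// Tile coordinates depend only on kTileSize, never on the wavefront width,
// so a 64-wide wavefront is simply treated as two independent 32-lane tiles.
constexpr TileCoord tile_coord(int thread_index) {
  return {thread_index / kTileSize, thread_index % kTileSize};
}

// CPU simulation of a butterfly all-reduce over one 32-lane tile.
// At each step, lane i adds the value held by lane (i ^ offset); after
// log2(32) = 5 steps every lane holds the sum of all 32 inputs.
inline std::array<float, kTileSize> tile_all_reduce_sum(
    std::array<float, kTileSize> lanes) {
  for (int offset = kTileSize / 2; offset > 0; offset /= 2) {
    std::array<float, kTileSize> next{};
    for (int i = 0; i < kTileSize; ++i) {
      next[i] = lanes[i] + lanes[i ^ offset];
    }
    lanes = next;
  }
  return lanes;
}
```

Because the partner lane `i ^ offset` never leaves the range 0..31, the exchange pattern stays inside a tile. On a GPU the same pattern would typically be expressed with width-limited shuffle operations, which is an assumption about the kernel rather than something this commit shows.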

1 file changed

Lines changed: 20 additions & 3 deletions

mlx/backend/rocm/scaled_dot_product_attention.hip

@@ -216,9 +216,26 @@ bool supports_sdpa_vector(
     bool has_arr_mask,
     bool do_causal,
     bool output_logsumexp) {
-  // Temporarily disable optimized SDPA to debug memory fault
-  // The memory fault occurs even with SDPA disabled, so the issue is elsewhere
-  return false;
+  if (output_logsumexp) {
+    return false;
+  }
+
+  // Check for supported dtypes
+  if (q.dtype() != float32 && q.dtype() != float16 && q.dtype() != bfloat16) {
+    return false;
+  }
+
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+
+  const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
+
+  const bool supported_vector_config =
+      sdpa_supported_head_dim && query_sequence_length < 4;
+
+  return supported_vector_config && !has_arr_mask;
 }
 
 void sdpa_vector(
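
To make the re-enabled gating logic easier to check in isolation, here is a minimal standalone sketch of the same conditions. It replaces the `array` arguments with a hypothetical plain struct (`SdpaConfig`) and a hypothetical `DType` enum, neither of which exists in MLX; only the conditions themselves are taken from the diff above.

```cpp
// Hypothetical stand-in for the dtype values; not the MLX type.
enum class DType { Float32, Float16, BFloat16, Int32 };

// Hypothetical plain-data summary of what the predicate in the diff inspects.
struct SdpaConfig {
  DType q_dtype;
  int query_head_dim;         // q.shape(-1) in the diff
  int value_head_dim;         // v.shape(-1) in the diff
  int query_sequence_length;  // q.shape(2) in the diff
  bool has_arr_mask;
  bool output_logsumexp;
};

// Mirrors the conditions added to supports_sdpa_vector in this commit.
inline bool vector_path_supported(const SdpaConfig& c) {
  // The vector path is rejected when logsumexp output is requested.
  if (c.output_logsumexp) {
    return false;
  }
  // Only float32, float16, and bfloat16 queries are accepted.
  if (c.q_dtype != DType::Float32 && c.q_dtype != DType::Float16 &&
      c.q_dtype != DType::BFloat16) {
    return false;
  }
  // Query and value head dims must match and be one of 64, 96, or 128.
  const bool supported_head_dim = c.query_head_dim == c.value_head_dim &&
      (c.query_head_dim == 64 || c.query_head_dim == 96 ||
       c.query_head_dim == 128);
  // Only short queries (fewer than 4 positions) without an array mask qualify.
  return supported_head_dim && c.query_sequence_length < 4 && !c.has_arr_mask;
}
```

For example, `vector_path_supported({DType::Float16, 128, 128, 1, false, false})` returns true, while a query sequence length of 4, a head dimension of 80, an array mask, or a logsumexp request each make it return false. Note that `do_causal` is a parameter of the real function but does not appear in the added conditions, so it is omitted here.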
