Question about metal gemm #3380
Unanswered · Lazarus-931 asked this question in Q&A · Replies: 1 comment
For reference, I was testing MLX's SDPA like this:

```python
import mlx.core as mx, time, sys

seq = int(sys.argv[1]) if len(sys.argv) > 1 else 2048
d = int(sys.argv[2]) if len(sys.argv) > 2 else 128
iters = int(sys.argv[3]) if len(sys.argv) > 3 else 20

Q = mx.random.normal((1, 1, seq, d)).astype(mx.float16)
K = mx.random.normal((1, 1, seq, d)).astype(mx.float16)
V = mx.random.normal((1, 1, seq, d)).astype(mx.float16)
mx.eval(Q, K, V)

# Warmup: force lazy graphs to compile and run before timing
for _ in range(5):
    O = mx.fast.scaled_dot_product_attention(Q, K, V, scale=1.0 / d**0.5)
    mx.eval(O)

times = []
for _ in range(iters):
    t0 = time.perf_counter()
    O = mx.fast.scaled_dot_product_attention(Q, K, V, scale=1.0 / d**0.5)
    mx.eval(O)
    times.append((time.perf_counter() - t0) * 1e6)  # microseconds

times.sort()
t = times[len(times) // 2]  # median
flops = 4 * seq * seq * d + 5 * seq * seq
print(f"{t:.0f},{flops / (t * 1e3):.1f}")  # time (us), GFLOP/s
```
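For context, the `flops` expression in that script can be unpacked as follows. This is my own reading of the formula, not something stated in the thread: the two matmuls (Q·Kᵀ and P·V) contribute 2·seq²·d FLOPs each, and the softmax is modeled as roughly 5 elementwise ops per score entry.

```python
def sdpa_flops(seq: int, d: int) -> int:
    """Approximate FLOP count for single-head attention at (seq, d).

    Q @ K^T and P @ V are each a (seq x seq x d) matmul costing
    2 * seq * seq * d FLOPs; the softmax is modeled as ~5 elementwise
    ops (max, subtract, exp, sum, divide) per score entry.
    """
    matmuls = 2 * (2 * seq * seq * d)  # QK^T plus PV
    softmax = 5 * seq * seq
    return matmuls + softmax

print(sdpa_flops(2048, 128))  # 2168455168, i.e. ~2.17 GFLOP per call
```

Dividing this by the measured time in microseconds times 1e3 (as the script does) yields GFLOP/s.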
I've written a fused attention kernel targeting M2 and I'm benchmarking against MLX's
scaled_dot_product_attention. After several rounds of optimization I'm still ~2x slower, and I'm trying to understand what architectural choices explain the gap.

The kernel

A standard attention pipeline: tile Q into 16-row blocks, stream K/V in 128-column tiles, do an online softmax with a running max/sum, and accumulate O in registers. Key design choices: a `BlockMMA<BN>` struct.

Benchmarking setup

- threadgroups = (1, seq/16, 1), threads_per_tg = (128, 1, 1)
- timed with `cmd->GPUEndTime() - cmd->GPUStartTime()`, median of 10 runs after 3 warmup runs
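For readers less familiar with the online-softmax trick the kernel uses, here is a minimal NumPy sketch of the per-tile update for one query block. This is my own illustration; the function name, variable names, and tiling are hypothetical and not taken from the actual kernel:

```python
import numpy as np

def online_softmax_attention(Q, K, V, tile=128):
    """One query block attended against K/V streamed in tiles of rows.

    Maintains a running row max (m), a running softmax denominator (l),
    and an unnormalized output accumulator (acc), rescaling the old
    state whenever a new tile raises the row max -- the trick that lets
    a fused kernel avoid materializing the full (seq, seq) score matrix.
    """
    scale = 1.0 / Q.shape[-1] ** 0.5
    m = np.full(Q.shape[0], -np.inf)           # running row max
    l = np.zeros(Q.shape[0])                   # running denominator
    acc = np.zeros((Q.shape[0], V.shape[1]))   # unnormalized output
    for j in range(0, K.shape[0], tile):
        S = scale * (Q @ K[j:j + tile].T)      # scores for this tile
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)              # rescale factor for old state
        P = np.exp(S - m_new[:, None])
        l = alpha * l + P.sum(axis=1)
        acc = alpha[:, None] * acc + P @ V[j:j + tile]
        m = m_new
    return acc / l[:, None]
```

Against a full-matrix softmax reference this matches to numerical precision; a fused kernel does the same algebra per Q block, with the accumulators held in registers.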
Note: my kernel is specialized for seq=2048, d=128 with compile-time constants (loop bounds, strides, and tile counts are all constexpr), while MLX's SDPA is fully general. Despite this advantage I'm still 2.1x slower, which suggests the gap is architectural rather than runtime overhead.
Any thoughts or advice would be much appreciated!