-
Notifications
You must be signed in to change notification settings - Fork 935
Expand file tree
/
Copy pathbest_program.py
More file actions
508 lines (408 loc) · 18.9 KB
/
best_program.py
File metadata and controls
508 lines (408 loc) · 18.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
"""
Qwen3 Custom Metal Kernel for Grouped Query Attention (GQA) Optimization
This module implements a custom Metal kernel for Qwen3's 16:8 GQA pattern using
MLX's metal_kernel API. The kernel is designed to outperform mx.fast.scaled_dot_product_attention
by leveraging Apple Silicon specific optimizations and the 2:1 query-to-KV head ratio.
Target: Qwen3-0.6B with 16 query heads : 8 KV heads
Hardware: Apple M-series GPUs with unified memory
Baseline: Standard MLX-LM using mx.fast.scaled_dot_product_attention
Goal: 5-15% performance improvement through custom Metal kernel optimization
Evolution Target: The Metal kernel source code that computes GQA attention
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import math
from typing import Optional, Tuple, Any
import time
def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
    """
    Custom Metal kernel implementation for Qwen3 GQA attention.

    Computes ``softmax(scale * Q @ K^T [+ mask]) @ V`` where each group of
    ``num_heads // num_kv_heads`` consecutive query heads shares one KV head.

    Args:
        queries: [B, num_heads, L, head_dim] (16 heads, head_dim 128 for Qwen3-0.6B)
        keys: [B, num_kv_heads, S, head_dim]. S may be larger than L when a KV
            cache is in use; the queries are then the last L positions.
        values: [B, num_kv_heads, S, head_dim] (same shape as keys)
        scale: Attention scaling factor (typically 1/sqrt(head_dim)).
        mask: One of
            - None: no masking, every query attends to every key;
            - "causal": query row i attends to keys j <= i + (S - L);
            - mx.array broadcastable to [B, num_heads, L, S] (a 3-D mask is read
              as [B, L, S]). Boolean masks mark allowed positions with True;
              any other dtype is treated as an additive bias on the scores.
    Returns:
        Attention output [B, num_heads, L, head_dim] in queries.dtype.
    Raises:
        ValueError: unsupported mask, or inconsistent head / key-value shapes.
        RuntimeError: building or launching the Metal kernel failed.
    """
    B, num_heads, L, head_dim = queries.shape
    _, num_kv_heads, kv_len, _ = keys.shape
    if values.shape != keys.shape:
        raise ValueError(f"keys {keys.shape} and values {values.shape} must have the same shape")
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be a multiple of num_kv_heads ({num_kv_heads})"
        )
    heads_per_kv = num_heads // num_kv_heads  # 2 for Qwen3-0.6B (16:8)

    # Mask modes understood by the kernel (compile-time template constant, so
    # unused branches are folded away).
    MASK_NONE, MASK_CAUSAL, MASK_BOOL, MASK_ADDITIVE = 0, 1, 2, 3
    # Placeholder bound to the "mask" input when no mask tensor is read.
    mask_tensor = mx.array([False])
    if mask is None:
        mask_mode = MASK_NONE
    elif isinstance(mask, str):
        if mask != "causal":
            raise ValueError(f"Unsupported mask string: {mask!r}. Only 'causal' is supported.")
        # Causality is evaluated in-kernel: no [B, H, L, S] mask is materialised.
        mask_mode = MASK_CAUSAL
    elif isinstance(mask, mx.array):
        if mask.ndim == 3:
            mask = mask[:, None, :, :]  # [B, L, S] -> [B, 1, L, S]
        # The kernel indexes the mask as a dense [B, H, L, S] tensor.
        mask_tensor = mx.broadcast_to(mask, (B, num_heads, L, kv_len))
        if mask_tensor.dtype == mx.bool_:
            mask_mode = MASK_BOOL
        else:
            mask_mode = MASK_ADDITIVE
            mask_tensor = mask_tensor.astype(mx.float32)
    else:
        # Raise error for unsupported mask types - no fallback
        raise ValueError(
            f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask."
        )

    # EVOLVE-BLOCK-START
    # Custom Metal kernel source for Qwen3 GQA optimization
    # This kernel leverages the 16:8 head ratio and Apple Silicon architecture
    kernel_source = """
    // Qwen3 GQA Metal kernel: one thread computes the attention output for one
    // (query position, query head, batch) triple with an online (single-pass)
    // softmax, so no per-thread array sized by the key length is needed.
    const uint query_pos = thread_position_in_grid.x;
    const uint head_idx = thread_position_in_grid.y;
    const uint batch_idx = thread_position_in_grid.z;

    if (batch_idx >= BATCH_SIZE || head_idx >= NUM_HEADS || query_pos >= SEQ_LEN) {
        return;
    }

    // Runtime scalars. kv_len grows every decode step, so it is passed as data
    // rather than as a template constant (which would force a recompile).
    const float scale_val = float(scale[0]);
    const uint kv_len = uint(params[0]);
    const int q_offset = params[1];  // key position aligned with query row 0

    // GQA mapping: HEADS_PER_KV consecutive query heads share one KV head.
    const uint kv_head_idx = head_idx / HEADS_PER_KV;

    const uint q_base = ((batch_idx * NUM_HEADS + head_idx) * SEQ_LEN + query_pos) * HEAD_DIM;
    const uint kv_base = (batch_idx * NUM_KV_HEADS + kv_head_idx) * kv_len * HEAD_DIM;
    // 64-bit: B*H*L*S can exceed 2^32 for long contexts.
    const ulong mask_base =
        ((ulong(batch_idx) * NUM_HEADS + head_idx) * SEQ_LEN + query_pos) * ulong(kv_len);

    // Causal rule: query row i may see keys 0 .. i + q_offset (inclusive).
    uint key_end = kv_len;
    if (MASK_MODE == 1) {
        const int limit = int(query_pos) + q_offset + 1;
        key_end = limit <= 0 ? 0u : min(uint(limit), kv_len);
    }

    // Load the query row once; accumulate in float regardless of T.
    float q[HEAD_DIM];
    float acc[HEAD_DIM];
    for (uint d = 0; d < HEAD_DIM; d++) {
        q[d] = float(queries[q_base + d]);
        acc[d] = 0.0f;
    }

    // A finite "minus infinity" keeps exp(running_max - new_max) NaN-free even
    // when a score is -inf (e.g. from an additive mask).
    float running_max = -3.402823466e+38f;
    float running_sum = 0.0f;

    for (uint key_pos = 0; key_pos < key_end; key_pos++) {
        if (MASK_MODE == 2 && !bool(mask[mask_base + key_pos])) {
            continue;
        }
        const uint kv_row = kv_base + key_pos * HEAD_DIM;

        float score = 0.0f;
        for (uint d = 0; d < HEAD_DIM; d++) {
            score += q[d] * float(keys[kv_row + d]);
        }
        score *= scale_val;
        if (MASK_MODE == 3) {
            score += float(mask[mask_base + key_pos]);
        }

        // Online softmax: rescale the partial sums to the new running maximum.
        const float new_max = max(running_max, score);
        const float correction = exp(running_max - new_max);
        const float p = exp(score - new_max);
        running_sum = running_sum * correction + p;
        for (uint d = 0; d < HEAD_DIM; d++) {
            acc[d] = acc[d] * correction + p * float(values[kv_row + d]);
        }
        running_max = new_max;
    }

    // Rows with no visible keys produce zeros instead of NaN.
    const float inv_sum = running_sum > 0.0f ? 1.0f / running_sum : 0.0f;
    for (uint d = 0; d < HEAD_DIM; d++) {
        output[q_base + d] = T(acc[d] * inv_sum);
    }
"""
    # EVOLVE-BLOCK-END

    # Scale in float32 so half-precision inputs do not round it.
    scale_tensor = mx.array([scale], dtype=mx.float32)
    # Queries are the last L of the kv_len positions, hence the causal offset.
    params_tensor = mx.array([kv_len, kv_len - L], dtype=mx.int32)

    try:
        # Create and execute custom Metal kernel
        kernel = mx.fast.metal_kernel(
            name="qwen3_gqa_attention_kernel",
            input_names=["queries", "keys", "values", "mask", "scale", "params"],
            output_names=["output"],
            source=kernel_source,
        )
        # One SIMD-group-sized threadgroup along the query axis (at least 1).
        threadgroup_size = max(1, min(32, L))
        outputs = kernel(
            inputs=[queries, keys, values, mask_tensor, scale_tensor, params_tensor],
            output_shapes=[(B, num_heads, L, head_dim)],
            output_dtypes=[queries.dtype],
            grid=(L, num_heads, B),  # (SEQ_LEN, NUM_HEADS, BATCH_SIZE)
            threadgroup=(threadgroup_size, 1, 1),
            template=[
                ("T", queries.dtype),
                ("BATCH_SIZE", B),
                ("NUM_HEADS", num_heads),
                ("NUM_KV_HEADS", num_kv_heads),
                ("SEQ_LEN", L),
                ("HEAD_DIM", head_dim),
                ("HEADS_PER_KV", heads_per_kv),
                ("MASK_MODE", mask_mode),
            ],
        )
        return outputs[0]
    except Exception as e:
        # No fallback - let the custom kernel failure propagate for proper scoring
        print(f"❌ Custom GQA kernel failed: {e}")
        raise RuntimeError(f"Custom Metal kernel execution failed: {e}") from e
class CustomGQAAttention(nn.Module):
    """
    Qwen3 attention module with custom Metal kernel optimization.

    Uses the same projections, per-head RMSNorm on queries/keys and RoPE as the
    standard MLX-LM Qwen3 attention layer. The attention core is computed by
    ``qwen3_custom_gqa_attention`` instead of the framework SDPA.
    """

    def __init__(self, args):
        """
        Args:
            args: Model config exposing hidden_size, num_attention_heads,
                num_key_value_heads, head_dim, rms_norm_eps, rope_theta,
                rope_scaling and max_position_embeddings.
        """
        super().__init__()

        # Standard Qwen3 parameters
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads  # 16
        assert args.num_key_value_heads is not None
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads  # 8
        head_dim = args.head_dim  # 128
        self.scale = head_dim**-0.5

        # Standard MLX-LM projections
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        # Standard MLX-LM norms
        self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)

        # Standard MLX-LM RoPE, with a plain-RoPE fallback so positional
        # information is never silently dropped.
        try:
            from mlx_lm.models.rope_utils import initialize_rope
        except ImportError:
            print("⚠️ Could not import mlx_lm rope_utils, using basic RoPE")
            # NOTE: args.rope_scaling is not applied by this fallback.
            self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)
        else:
            self.rope = initialize_rope(
                head_dim,
                base=args.rope_theta,
                traditional=False,
                scaling_config=args.rope_scaling,
                max_position_embeddings=args.max_position_embeddings,
            )

        print("🔧 Initialized Custom Metal GQA Attention")
        print(f"  📊 Architecture: {n_heads}:{n_kv_heads} heads ({n_heads//n_kv_heads}:1 ratio)")
        print(f"  🎯 Head dimension: {head_dim}")
        print("  ⚡ Using custom Metal kernel for GQA optimization")

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """
        Args:
            x: Hidden states [B, L, hidden_size].
            mask: Forwarded unchanged to ``qwen3_custom_gqa_attention``
                (None, "causal" or an mx.array mask).
            cache: Optional KV cache providing ``offset`` and
                ``update_and_fetch(keys, values)``.
        Returns:
            [B, L, hidden_size] after the output projection.
        """
        B, L, _ = x.shape

        # Standard preprocessing (already optimized, don't evolve)
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose(0, 2, 1, 3)
        keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Standard RoPE application (already optimized, don't evolve)
        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            # After this, keys/values may cover more positions than the queries.
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # CORE INNOVATION: Custom Metal kernel for GQA attention
        output = qwen3_custom_gqa_attention(queries, keys, values, scale=self.scale, mask=mask)

        # Standard postprocessing (already optimized, don't evolve)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
def create_metal_qwen3_optimization_hook():
    """
    Build the install/uninstall pair for the Metal-kernel Qwen3 attention.

    Returns:
        ``(apply_optimization_hook, remove_optimization_hook)``.
        ``apply_optimization_hook()`` replaces ``mlx_lm.models.qwen3.Attention``
        with ``CustomGQAAttention``. It returns the class it replaced, or None
        when ``mlx_lm.models.qwen3`` cannot be imported.
        ``remove_optimization_hook(cls)`` assigns ``cls`` back to that attribute
        and does nothing if the module cannot be imported.
    """

    def apply_optimization_hook():
        """Install CustomGQAAttention; return the displaced class (or None)."""
        try:
            import mlx_lm.models.qwen3 as qwen3_module
        except ImportError:
            print("❌ Could not import mlx_lm.models.qwen3")
            return None

        displaced = qwen3_module.Attention
        qwen3_module.Attention = CustomGQAAttention
        print("✅ Applied Custom Metal GQA Attention hook")
        return displaced

    def remove_optimization_hook(original_attention):
        """Put ``original_attention`` back as the Qwen3 Attention class."""
        try:
            import mlx_lm.models.qwen3 as qwen3_module
        except ImportError:
            return

        qwen3_module.Attention = original_attention
        print("✅ Removed Custom Metal GQA Attention hook")

    return apply_optimization_hook, remove_optimization_hook
def _time_per_call(fn, warmup=3, iters=10):
    """Return mean wall-clock seconds per call of ``fn()``.

    Every result is forced with ``mx.eval``, and the device is synchronized
    around the timed loop, so lazily scheduled work is included in the timing.
    """
    for _ in range(warmup):
        mx.eval(fn())
    mx.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    mx.synchronize()
    return (time.perf_counter() - start) / iters


def benchmark_metal_gqa_optimization(warmup=3, iters=10):
    """
    Benchmark Metal kernel optimized GQA attention against the MLX baseline.

    For each sequence length this reports:
      * the full ``CustomGQAAttention`` layer (projections, norms, RoPE, kernel);
      * the attention core alone: ``qwen3_custom_gqa_attention`` versus
        ``mx.fast.scaled_dot_product_attention`` on identical q/k/v inputs.

    Args:
        warmup: Untimed calls before each measurement.
        iters: Timed calls averaged per measurement.
    """
    from functools import partial

    # Qwen3-0.6B configuration
    class MockArgs:
        hidden_size = 2048
        num_attention_heads = 16
        num_key_value_heads = 8
        head_dim = 128
        rms_norm_eps = 1e-06
        rope_theta = 1000000
        rope_scaling = None
        max_position_embeddings = 40960

    args = MockArgs()

    # Test configurations for Metal kernel validation
    test_configs = [
        ("short_sequence", 1, 128, 2048),
        ("medium_sequence", 1, 512, 2048),
        ("long_sequence", 1, 1024, 2048),
        ("max_sequence", 1, 2048, 2048),
    ]

    print("Benchmarking Custom Metal GQA Kernel vs MLX Baseline")
    print("=" * 70)

    # Initialize Metal optimized attention
    metal_attn = CustomGQAAttention(args)
    scale = args.head_dim**-0.5

    for config_name, batch_size, seq_len, hidden_size in test_configs:
        print(f"\nTesting {config_name}: B={batch_size}, L={seq_len}")
        mask = "causal"

        # Full attention layer with the custom kernel
        x = mx.random.normal((batch_size, seq_len, hidden_size))
        avg_time = _time_per_call(partial(metal_attn, x, mask=mask), warmup, iters)
        tokens_per_sec = seq_len / avg_time
        print(f"  Metal GQA: {avg_time*1000:.2f} ms, {tokens_per_sec:.1f} tokens/sec")

        # Attention core: custom kernel vs MLX fused SDPA on the same inputs
        q = mx.random.normal((batch_size, args.num_attention_heads, seq_len, args.head_dim))
        k = mx.random.normal((batch_size, args.num_key_value_heads, seq_len, args.head_dim))
        v = mx.random.normal((batch_size, args.num_key_value_heads, seq_len, args.head_dim))
        custom_core = _time_per_call(
            partial(qwen3_custom_gqa_attention, q, k, v, scale=scale, mask=mask), warmup, iters
        )
        baseline_core = _time_per_call(
            partial(mx.fast.scaled_dot_product_attention, q, k, v, scale=scale, mask=mask),
            warmup,
            iters,
        )
        print(
            f"  Core kernel: {custom_core*1000:.2f} ms vs MLX SDPA {baseline_core*1000:.2f} ms "
            f"(speedup {baseline_core / custom_core:.2f}x)"
        )
        print(f"  Memory: {mx.get_active_memory() / 1e9:.2f} GB")
def _reference_gqa_attention(queries, keys, values, scale, causal):
    """Plain-MLX GQA attention used as ground truth for the custom kernel.

    KV heads are repeated so that query head h reads KV head h // heads_per_kv.
    With ``causal=True``, query row i sees keys j <= i + (S - L).
    """
    heads_per_kv = queries.shape[1] // keys.shape[1]
    keys = mx.repeat(keys, heads_per_kv, axis=1)
    values = mx.repeat(values, heads_per_kv, axis=1)
    scores = (queries @ keys.transpose(0, 1, 3, 2)) * scale
    if causal:
        L, S = queries.shape[2], keys.shape[2]
        allowed = mx.arange(S)[None, :] <= mx.arange(L)[:, None] + (S - L)
        scores = mx.where(allowed, scores, float("-inf"))
    return mx.softmax(scores, axis=-1) @ values


def test_metal_gqa_correctness():
    """
    Test that Metal kernel implementation produces correct results.

    Runs the full attention module as a shape/NaN/Inf sanity check. Then it
    compares the direct kernel output with a plain-MLX reference implementation.

    Returns:
        True if the module output is finite and the kernel matches the
        reference within tolerance, otherwise False.
    """
    print("Testing Custom Metal GQA Correctness")
    print("=" * 50)
    mx.random.seed(0)  # reproducible inputs

    # Test configuration
    B, L, hidden = 1, 64, 2048

    class MockArgs:
        hidden_size = 2048
        num_attention_heads = 16
        num_key_value_heads = 8
        head_dim = 128
        rms_norm_eps = 1e-06
        rope_theta = 1000000
        rope_scaling = None
        max_position_embeddings = 40960

    args = MockArgs()

    # Create test input
    x = mx.random.normal((B, L, hidden))
    mask = "causal"

    # Test Metal optimized implementation
    metal_attn = CustomGQAAttention(args)
    output = metal_attn(x, mask=mask)
    print(f"✅ Metal GQA output shape: {output.shape}")

    # Check for valid output
    has_nan = bool(mx.any(mx.isnan(output)))
    has_inf = bool(mx.any(mx.isinf(output)))
    finite_mark = "❌" if (has_nan or has_inf) else "✅"
    print(f"{finite_mark} Has NaN: {has_nan}, Has Inf: {has_inf}")

    # Check output statistics
    output_mean = float(mx.mean(output))
    output_std = float(mx.std(output))
    print(f"✅ Output statistics - Mean: {output_mean:.6f}, Std: {output_std:.6f}")

    # Test direct kernel function against the reference
    print("\n=== Testing Direct Kernel Function ===")
    B, H, L, D = 1, 16, 128, 128
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, 8, L, D))  # 8 KV heads
    v = mx.random.normal((B, 8, L, D))
    scale = 1.0 / math.sqrt(D)

    kernel_output = qwen3_custom_gqa_attention(q, k, v, scale=scale, mask="causal")
    print(f"✅ Direct kernel output shape: {kernel_output.shape}")

    kernel_mean = float(mx.mean(kernel_output))
    kernel_std = float(mx.std(kernel_output))
    print(f"✅ Direct kernel stats - Mean: {kernel_mean:.6f}, Std: {kernel_std:.6f}")

    reference = _reference_gqa_attention(q, k, v, scale, causal=True)
    max_abs_diff = float(mx.max(mx.abs(kernel_output - reference)))
    matches = bool(mx.allclose(kernel_output, reference, rtol=1e-4, atol=1e-4))
    match_mark = "✅" if matches else "❌"
    print(f"{match_mark} Max |kernel - reference|: {max_abs_diff:.3e} (allclose: {matches})")

    return matches and not (has_nan or has_inf)
if __name__ == "__main__":
    separator = "=" * 70
    print("Custom Metal Kernel Qwen3 GQA Optimization")
    print(separator)

    # Validate correctness before spending time on benchmarks.
    test_metal_gqa_correctness()
    print("\n")
    benchmark_metal_gqa_optimization()

    print("\n" + separator)
    print("Ready for Metal Kernel Evolution")
    print("Evolution focus:")
    focus_items = (
        "1. 🔧 Metal kernel source code optimization",
        "2. 💾 Memory access pattern improvements for Apple Silicon",
        "3. 🎯 GQA-specific optimizations for 16:8 head ratio",
        "4. ⚡ Vectorization and SIMD optimization",
        "5. 🚀 Thread group and grid configuration tuning",
    )
    for item in focus_items:
        print(item)
    print("Target: 5-15% performance improvement through Metal kernel innovation")
    print(separator)