-
Notifications
You must be signed in to change notification settings - Fork 935
Expand file tree
/
Copy pathevaluator.py
More file actions
1471 lines (1232 loc) · 64.2 KB
/
evaluator.py
File metadata and controls
1471 lines (1232 loc) · 64.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
🛡️ BULLETPROOF METAL KERNEL EVALUATOR 🛡️
This evaluator provides MAXIMUM protection against Metal kernel failures during evolution:
🔧 METAL-SPECIFIC PROTECTION:
1. Pre-execution kernel parameter validation
2. Memory safety checks before GPU execution
3. Command buffer error detection and recovery
4. Thread-safe Metal kernel execution wrapping
5. Graceful fallback to standard attention on ANY Metal failure
🚀 EVOLUTION SAFETY:
- NEVER crashes the evolution process
- Handles kIOGPUCommandBufferCallbackErrorInvalidResource errors
- Catches GPU memory violations, out-of-bounds access, race conditions
- Provides detailed error classification for debugging
- Maintains evolution progress even with buggy kernel code
🎯 ROBUST ERROR RECOVERY:
- Multiple retry attempts with exponential backoff
- Automatic fallback mechanisms
- Comprehensive error statistics tracking
- Safe cleanup of GPU resources
"""
import os
import sys
import json
import time
import traceback
import threading
import subprocess
import tempfile
from typing import Dict, List, Tuple, Any, Optional
import numpy as np
# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import mlx.core as mx
import mlx.nn as nn
# Import the comprehensive benchmark suite for consistent testing
from qwen3_benchmark_suite import Qwen3BenchmarkSuite, BenchmarkConfig, BenchmarkResult
class MetalKernelSafetyError(Exception):
    """Raised when an input violates a configured Metal kernel safety limit."""
class GPUCommandBufferError(Exception):
    """Raised for failures that look like GPU command-buffer / execution errors."""
class MetalMemoryViolationError(Exception):
    """Raised when a Metal kernel is detected accessing memory it should not."""
class BulletproofMetalEvaluator:
"""Bulletproof evaluator that NEVER crashes from Metal kernel failures"""
def __init__(self):
    """Set up retry policy, safety limits, error counters and the benchmark suite.

    Ends by printing a short banner that names the model under evaluation.
    """
    self.model_path = "mlx-community/Qwen3-0.6B-bf16"

    # Retry / timeout policy used by the protected helpers.
    self.max_retry_attempts = 3
    self.retry_base_delay = 1.0  # seconds; the retry loops scale it exponentially
    self.kernel_validation_timeout = 30  # seconds; soft limit for one forward pass

    # Per-evaluation error counters (reset again at the start of every evaluation).
    self.metal_command_buffer_errors = self.metal_memory_violations = 0
    self.metal_compilation_errors = self.gpu_resource_errors = 0
    self.total_metal_errors = self.successful_fallbacks = self.retry_attempts_used = 0

    # Conservative input limits enforced before running a custom kernel.
    self.max_sequence_length_safe = 512
    self.max_batch_size_safe = 1
    self.max_head_dimension_safe = 128

    # Populated once the baseline has been measured.
    self.baseline_metrics = self.baseline_results = None

    # Shared benchmark suite so baseline and custom runs use identical configs.
    self.benchmark_suite = Qwen3BenchmarkSuite(self.model_path)

    banner = (
        "🛡️ BULLETPROOF METAL KERNEL EVALUATOR INITIALIZED",
        f"📱 Model: {self.model_path}",
        f"🔁 Max retry attempts: {self.max_retry_attempts}",
        "⚡ GPU error protection: MAXIMUM",
        "🧠 Memory safety validation: ENABLED",
        "🎯 Command buffer error handling: ACTIVE",
    )
    for banner_line in banner:
        print(banner_line)
def evaluate(self, program_text: str) -> Dict[str, Any]:
    """
    Run the full protected evaluation pipeline for one evolved program.

    Steps (each failing step short-circuits to a failure result):
      1. Extract ``CustomGQAAttention`` from the program (source text or path).
      2. Instantiate it and sanity-check its head counts (failure only warns).
      3. Correctness-test it on small inputs; a score below 0.90 fails.
      4. Measure baseline performance with the standard attention.
      5. Benchmark the program with the custom attention installed.
      6-8. Compare with the baseline, compute the final score and build the
         result dictionary (printed before being returned).

    Unexpected exceptions inside the pipeline are caught, counted in
    ``total_metal_errors`` and converted into a failure result.
    """
    rule = "🛡️ " * 50
    for banner_line in (
        "\n" + rule,
        "🛡️ BULLETPROOF METAL KERNEL EVALUATION STARTING",
        rule,
        "✅ GPU Command Buffer Error Protection: ACTIVE",
        "✅ Metal Memory Violation Detection: ENABLED",
        "✅ Automatic Fallback Mechanisms: READY",
        "✅ Multi-layer Error Recovery: ARMED",
        "✅ Evolution Process Protection: MAXIMUM",
        rule,
    ):
        print(banner_line)

    try:
        self._reset_error_counters()

        print("\n🔧 STEP 1: Enhanced Program Extraction with Metal Validation")
        extraction = self._bulletproof_extract_custom_attention(program_text)
        if not extraction["success"]:
            return self._create_comprehensive_failure_result(
                f"Program extraction failed: {extraction['error']}"
            )
        attention_cls = extraction["class"]
        source_text = extraction["program_text"]

        print("\n🔍 STEP 2: Pre-execution Metal Kernel Safety Validation")
        safety = self._validate_metal_kernel_safety(attention_cls)
        if not safety["success"]:
            # Not fatal: later steps run under their own protection.
            print(f"⚠️ Metal kernel safety validation failed: {safety['error']}")
            print("🛡️ Proceeding with enhanced protection...")

        # Correctness runs before the expensive baseline so broken kernels fail fast.
        print("\n🔍 STEP 3: Memory-Safe Custom Attention Correctness Testing")
        correctness = self._memory_safe_correctness_test(attention_cls)
        if not correctness["success"]:
            return self._create_comprehensive_failure_result(
                f"Memory-safe correctness test failed: {correctness['error']}"
            )
        correctness_score = correctness["score"]
        if correctness_score < 0.90:
            return self._create_comprehensive_failure_result(
                f"Correctness score too low: {correctness_score:.3f} (required: 0.90)"
            )

        print("\n📊 STEP 4: GPU-Protected Baseline Performance Measurement")
        baseline_results = self._gpu_protected_measure_baseline()
        if not baseline_results:
            return self._create_comprehensive_failure_result(
                "Failed to measure baseline performance with GPU protection"
            )

        print("\n🚀 STEP 5: Command-Buffer-Protected Performance Benchmarking")
        benchmark = self._command_buffer_protected_benchmark(source_text, attention_cls)
        if not benchmark["success"]:
            return self._create_comprehensive_failure_result(
                f"Command-buffer-protected benchmarking failed: {benchmark['error']}"
            )
        custom_results = benchmark["results"]

        print("\n📈 STEP 6: Enhanced Performance Analysis")
        analysis = self._analyze_performance_with_safety_metrics(
            baseline_results, custom_results
        )

        final_score = self._calculate_safety_adjusted_score(analysis, correctness_score)

        result = {
            "success": True,
            "final_score": final_score,
            "combined_score": final_score,
            "performance_metrics": analysis["aggregate_metrics"],
            "correctness_score": correctness_score,
            "benchmark_results": [self._result_to_dict(r) for r in custom_results],
            "baseline_comparison": analysis["comparison_summary"],
            "individual_comparisons": analysis["individual_comparisons"],
            "summary": self._generate_comprehensive_summary(analysis, correctness_score),
            "metal_safety_statistics": self._get_comprehensive_error_statistics(),
            "safety_validation": safety,
        }
        self._print_bulletproof_evaluation_results(result)
        return result

    except Exception as exc:
        # Last line of defence: the evolution loop must always receive a result.
        self.total_metal_errors += 1
        message = f"TOP-LEVEL BULLETPROOF CATCH: {str(exc)}"
        print(f"🛡️ {message}")
        traceback.print_exc()
        return self._create_comprehensive_failure_result(message)
def _reset_error_counters(self):
"""Reset all error tracking counters"""
self.metal_command_buffer_errors = 0
self.metal_memory_violations = 0
self.metal_compilation_errors = 0
self.gpu_resource_errors = 0
self.total_metal_errors = 0
self.successful_fallbacks = 0
self.retry_attempts_used = 0
def _bulletproof_extract_custom_attention(self, program_text: str) -> Dict[str, Any]:
"""Bulletproof extraction with comprehensive Metal kernel validation"""
try:
print(" 🔍 Bulletproof program analysis with Metal validation...")
# Handle file paths vs direct text
if (
program_text.startswith("/")
and "\n" not in program_text
and len(program_text) < 500
):
print(f" 📁 Reading program from file: {program_text}")
if os.path.exists(program_text):
try:
with open(program_text, "r") as f:
actual_program_text = f.read()
except Exception as e:
return {"success": False, "error": f"File read error: {e}"}
else:
return {"success": False, "error": f"Program file not found: {program_text}"}
else:
actual_program_text = program_text
# Enhanced syntax validation
try:
compile(actual_program_text, "<evolved_program>", "exec")
print(" ✅ Enhanced syntax validation passed")
except SyntaxError as e:
return {"success": False, "error": f"Syntax error: {e}"}
# Pre-validate Metal kernel syntax (static analysis)
metal_validation = self._static_validate_metal_kernel_syntax(actual_program_text)
if not metal_validation["safe"]:
print(
f" ⚠️ Metal kernel static validation warning: {metal_validation['warnings']}"
)
# Create ultra-safe execution environment
exec_globals = self._create_bulletproof_execution_environment()
# Execute program with maximum protection
print(" ⚙️ Executing program with MAXIMUM protection...")
try:
success, result = self._bulletproof_execute_with_gpu_protection(
lambda: exec(actual_program_text, exec_globals)
)
if not success:
self.total_metal_errors += 1
return {"success": False, "error": f"Protected execution failed: {result}"}
except Exception as e:
self.total_metal_errors += 1
return {"success": False, "error": f"Execution error with GPU protection: {e}"}
# Enhanced class extraction and validation
custom_class = exec_globals.get("CustomGQAAttention")
if custom_class is None:
return {
"success": False,
"error": "CustomGQAAttention class not found in executed code",
}
# Comprehensive class validation
validation_result = self._validate_custom_attention_class(custom_class)
if not validation_result["valid"]:
return {"success": False, "error": validation_result["error"]}
print(f" ✅ Successfully extracted and validated CustomGQAAttention class")
print(f" 🛡️ Metal safety pre-checks: {metal_validation['safe']}")
return {
"success": True,
"class": custom_class,
"metal_validation": metal_validation,
"program_text": actual_program_text,
}
except Exception as e:
self.total_metal_errors += 1
return {"success": False, "error": f"Bulletproof extraction failed: {str(e)}"}
def _static_validate_metal_kernel_syntax(self, program_text: str) -> Dict[str, Any]:
"""Static analysis of Metal kernel syntax for common safety issues"""
warnings = []
# Check for common Metal safety issues
dangerous_patterns = [
("buffer overflow", ["queries[", "keys[", "values[", "output[", "mask["]),
("unguarded loops", ["for (", "while ("]),
("raw pointers", ["*queries", "*keys", "*values", "*output"]),
("thread sync issues", ["threadgroup", "simdgroup"]),
]
for issue_type, patterns in dangerous_patterns:
for pattern in patterns:
if pattern in program_text:
warnings.append(f"{issue_type}: {pattern}")
# Check for bounds checking
has_bounds_checking = any(
check in program_text
for check in [
"batch_idx >= BATCH_SIZE",
"head_idx >= NUM_HEADS",
"query_pos >= SEQ_LEN",
"d < HEAD_DIM",
]
)
if not has_bounds_checking:
warnings.append("missing bounds checking")
return {
"safe": len(warnings) == 0,
"warnings": warnings,
"has_bounds_checking": has_bounds_checking,
}
def _validate_custom_attention_class(self, custom_class: Any) -> Dict[str, Any]:
"""Comprehensive validation of custom attention class"""
try:
# Basic type checking
if not isinstance(custom_class, type):
return {"valid": False, "error": "CustomGQAAttention is not a valid class"}
# Check for required methods
required_methods = ["__init__", "__call__"]
for method in required_methods:
if not hasattr(custom_class, method):
return {"valid": False, "error": f"Missing required method: {method}"}
# Check if it inherits from nn.Module (recommended)
if not issubclass(custom_class, nn.Module):
print(" ⚠️ CustomGQAAttention doesn't inherit from nn.Module")
print(" ✅ Custom attention class validation passed")
return {"valid": True}
except Exception as e:
return {"valid": False, "error": f"Class validation error: {e}"}
def _validate_metal_kernel_safety(self, custom_attention_class: Any) -> Dict[str, Any]:
"""Pre-execution validation of Metal kernel safety"""
try:
print(" 🔍 Validating Metal kernel safety parameters...")
# Mock arguments for safety testing
class MockArgs:
# NOTE: This should reflect the default model used by this evaluator:
# `mlx-community/Qwen3-0.6B-bf16` (16 Q heads : 8 KV heads, head_dim=128).
hidden_size = 2048
num_attention_heads = 16
num_key_value_heads = 8
head_dim = 128
rms_norm_eps = 1e-06
rope_theta = 1000000
rope_scaling = None
max_position_embeddings = 40960
args = MockArgs()
# Try to instantiate with safety checks
try:
instance = custom_attention_class(args)
if instance is None:
return {"success": False, "error": "Failed to instantiate custom attention"}
print(" ✅ Custom attention instantiation successful")
# Basic parameter validation (should match the args we instantiated with)
if hasattr(instance, "n_heads") and instance.n_heads != args.num_attention_heads:
return {
"success": False,
"error": f"Invalid head count: {instance.n_heads} (expected {args.num_attention_heads})",
}
if hasattr(instance, "n_kv_heads") and instance.n_kv_heads != args.num_key_value_heads:
return {
"success": False,
"error": f"Invalid KV head count: {instance.n_kv_heads}",
}
return {"success": True, "validated": True}
except Exception as e:
error_msg = str(e)
if any(keyword in error_msg.lower() for keyword in ["metal", "kernel", "gpu"]):
self.metal_compilation_errors += 1
return {"success": False, "error": f"Instantiation failed: {error_msg}"}
except Exception as e:
self.total_metal_errors += 1
return {"success": False, "error": f"Safety validation error: {e}"}
def _bulletproof_execute_with_gpu_protection(self, func) -> Tuple[bool, Any]:
    """Call ``func()`` and turn any exception into a classified error message.

    Returns:
        ``(True, func())`` on success, or ``(False, message)`` where the message
        prefix names the error category. Metal/GPU-looking failures also bump
        the evaluator's error counters. ``Exception`` subclasses are never raised.
    """
    try:
        # Trivial GPU op to confirm the device responds before running func.
        mx.eval(mx.array([1.0]))
        return True, func()
    except RuntimeError as err:
        text = str(err)
        if "kIOGPUCommandBufferCallbackErrorInvalidResource" in text:
            self.metal_command_buffer_errors += 1
            self.total_metal_errors += 1
            return False, f"GPU Command Buffer Error (memory violation): {text}"
        # Any message mentioning "metal" (any case) lands here, so the "metal"
        # keyword in the next check can never match.
        if "METAL" in text.upper():
            self.metal_memory_violations += 1
            self.total_metal_errors += 1
            return False, f"Metal Memory Violation: {text}"
        if any(word in text.lower() for word in ("gpu", "metal", "kernel")):
            self.gpu_resource_errors += 1
            self.total_metal_errors += 1
            return False, f"GPU Resource Error: {text}"
        return False, f"Runtime Error: {text}"
    except Exception as err:
        text = str(err)
        if any(word in text.lower() for word in ("metal", "kernel", "gpu", "mps", "mtl")):
            self.total_metal_errors += 1
            return False, f"General Metal Error: {text}"
        return False, f"Execution Error: {text}"
def _gpu_protected_measure_baseline(self) -> Optional[List[BenchmarkResult]]:
    """Benchmark the standard attention to obtain baseline numbers.

    Each config is tried up to ``max_retry_attempts + 1`` times, sleeping
    ``retry_base_delay * 2**retry_count`` between attempts. At least half of
    the configs, and at least two (unless fewer configs exist), must succeed.

    Returns:
        The successful benchmark results (also passed to
        ``_store_enhanced_baseline_metrics``), or ``None`` on failure.
    """
    try:
        print(" 📊 Running GPU-protected baseline benchmark...")
        self._ensure_clean_gpu_state()
        # Defined elsewhere; presumably undoes any custom attention hook — confirm.
        self._ensure_standard_attention()

        baseline_configs = self._get_safe_benchmark_configs()
        if not baseline_configs:
            print(" ❌ No safe benchmark configurations available")
            return None

        baseline_results = []
        for i, config in enumerate(baseline_configs, 1):
            print(f" [{i}/{len(baseline_configs)}] GPU-protected baseline: {config.name}")
            retry_count = 0
            while retry_count <= self.max_retry_attempts:
                try:
                    self._ensure_clean_gpu_state()
                    # The lambda is invoked immediately, so capturing the loop
                    # variable `config` is safe here.
                    success, result = self._bulletproof_execute_with_gpu_protection(
                        lambda: self.benchmark_suite.run_single_benchmark(config)
                    )
                    if success and result:
                        baseline_results.append(result)
                        print(
                            f" ✅ GPU-protected {config.name}: {result.decode_tokens_per_sec:.1f} tokens/sec"
                        )
                        break
                    if retry_count < self.max_retry_attempts:
                        print(f" 🔄 Retry {retry_count + 1}: {result}")
                        retry_count += 1
                        time.sleep(self.retry_base_delay * (2**retry_count))
                        continue
                    print(f" ❌ All retries exhausted for {config.name}: {result}")
                    break
                except Exception as e:
                    if retry_count < self.max_retry_attempts:
                        print(f" 🔄 Exception retry {retry_count + 1}: {e}")
                        retry_count += 1
                        time.sleep(self.retry_base_delay * (2**retry_count))
                        continue
                    print(f" ❌ Final exception for {config.name}: {e}")
                    break

        successful_count = len(baseline_results)
        # >= 50% of configs and at least 2 successes, but never more than the
        # number of configs available (a one-config suite could otherwise never pass).
        min_required = min(len(baseline_configs), max(2, len(baseline_configs) * 0.5))
        if successful_count < min_required:
            print(
                f" ❌ Insufficient baseline results: {successful_count}/{len(baseline_configs)}"
            )
            return None

        self._store_enhanced_baseline_metrics(baseline_results)
        print(f" ✅ GPU-protected baseline complete ({successful_count} successful)")
        return baseline_results
    except Exception as e:
        print(f" ❌ GPU-protected baseline measurement failed: {e}")
        return None
def _memory_safe_correctness_test(self, custom_attention_class: Any) -> Dict[str, Any]:
    """Score the custom attention on a few short bfloat16 causal sequences.

    Each case runs ``_test_single_sequence_memory_safe`` inside the GPU
    protection wrapper. It is retried up to ``max_retry_attempts`` times with an
    exponential sleep, and a case that never succeeds scores 0.0. A Metal
    library build failure aborts immediately because it is deterministic.

    Returns:
        ``{"success": True, "score": mean_case_score, "command_buffer_errors": n,
        "memory_violations": m}``. On a compilation failure or an unexpected
        error, a dict with ``"success": False`` and an ``"error"`` message.
    """
    print(" 🔍 Running memory-safe correctness testing...")
    try:
        class MockArgs:
            # Attention config of the default model `mlx-community/Qwen3-0.6B-bf16`:
            # hidden size 1024, 16 query heads sharing 8 KV heads, head_dim 128.
            # hidden_size != num_attention_heads * head_dim (1024 vs 2048), so the
            # projections are not square; the inputs below must use hidden_size.
            hidden_size = 1024
            num_attention_heads = 16
            num_key_value_heads = 8
            head_dim = 128
            rms_norm_eps = 1e-06
            rope_theta = 1000000
            rope_scaling = None
            max_position_embeddings = 40960

        args = MockArgs()

        # (batch, seq_len, hidden): short sequences keep this check cheap and safe.
        test_cases = [(1, seq_len, args.hidden_size) for seq_len in (8, 16, 32, 64)]

        correctness_scores = []
        local_command_buffer_errors = 0
        local_memory_violations = 0

        def compilation_failure():
            # Result returned when the kernel cannot even be built.
            return {
                "success": False,
                "score": 0.0,
                "error": "Metal kernel compilation failed - bfloat16 incompatible code",
                "command_buffer_errors": local_command_buffer_errors,
                "memory_violations": local_memory_violations,
                "compilation_error": True,
            }

        for B, L, D in test_cases:
            print(f" 🧪 Memory-safe testing sequence length {L}...")
            retry_count = 0
            while retry_count <= self.max_retry_attempts:
                try:
                    self._ensure_clean_gpu_state()
                    # Use bfloat16 like the real model: float32 inputs could compile
                    # a different kernel variant and hide bf16-only build errors.
                    x = (mx.random.normal((B, L, D)) * 0.1).astype(mx.bfloat16)
                    mask = "causal"

                    # The wrapper already bumps total_metal_errors for Metal/GPU
                    # failures (and the per-category counters for RuntimeErrors).
                    # Snapshot them so nothing is counted twice below.
                    cmd_before = self.metal_command_buffer_errors
                    viol_before = self.metal_memory_violations
                    success, result = self._bulletproof_execute_with_gpu_protection(
                        lambda: self._test_single_sequence_memory_safe(
                            custom_attention_class, args, x, mask
                        )
                    )
                    if success:
                        correctness_scores.append(result)
                        print(f" ✅ Sequence {L}: PASS (score={result:.3f})")
                        break

                    error_msg = str(result)
                    lowered = error_msg.lower()
                    if "command buffer" in lowered:
                        local_command_buffer_errors += 1
                        if self.metal_command_buffer_errors == cmd_before:
                            self.metal_command_buffer_errors += 1
                    elif "memory violation" in lowered:
                        local_memory_violations += 1
                        if self.metal_memory_violations == viol_before:
                            self.metal_memory_violations += 1

                    # Metal compilation errors are deterministic: retrying cannot help.
                    if "unable to build metal library" in lowered:
                        self.metal_compilation_errors += 1
                        print(f" ❌ Metal compilation error (no retry): {error_msg[:200]}...")
                        return compilation_failure()

                    if retry_count < self.max_retry_attempts:
                        print(f" 🔄 Retry {retry_count + 1} for length {L}: {error_msg}")
                        retry_count += 1
                        time.sleep(self.retry_base_delay * (2**retry_count))
                        continue
                    print(f" ❌ All retries failed for length {L}: {error_msg}")
                    correctness_scores.append(0.0)
                    break
                except Exception as e:
                    error_msg = str(e)
                    print(f" ❌ Exception for length {L}: {error_msg}")
                    if "unable to build metal library" in error_msg.lower():
                        self.metal_compilation_errors += 1
                        print(f" ❌ Metal compilation error (no retry): {error_msg[:200]}...")
                        return compilation_failure()
                    if retry_count < self.max_retry_attempts:
                        retry_count += 1
                        time.sleep(self.retry_base_delay * (2**retry_count))
                        continue
                    correctness_scores.append(0.0)
                    break

        # Global counters were updated as errors were observed above; adding the
        # local totals again here would double-count them.
        overall_correctness = np.mean(correctness_scores) if correctness_scores else 0.0
        print(f" 📊 Memory-safe overall correctness: {overall_correctness:.3f}")
        print(f" 🛡️ Command buffer errors: {local_command_buffer_errors}")
        print(f" 🛡️ Memory violations: {local_memory_violations}")
        return {
            "success": True,
            "score": overall_correctness,
            "command_buffer_errors": local_command_buffer_errors,
            "memory_violations": local_memory_violations,
        }
    except Exception as e:
        self.total_metal_errors += 1
        print(f" ❌ Memory-safe correctness testing failed: {e}")
        return {"success": False, "error": str(e)}
def _test_single_sequence_memory_safe(
    self, custom_attention_class: Any, args: Any, x: Any, mask: Any
) -> float:
    """Run one forward pass of the custom attention and score its output heuristically.

    Args:
        custom_attention_class: Class under test; instantiated as ``cls(args)``.
        args: Model-config object passed to the constructor.
        x: Input activations, cast to bfloat16. ``x.shape[0]`` is treated as
            batch size and ``x.shape[1]`` as sequence length.
        mask: Forwarded unchanged to the attention call (callers pass ``"causal"``).

    Returns:
        1.0 when every check passes. Partial credit is 0.5 for slow execution,
        0.7 for a few non-finite values or a large max, and 0.6 for an unusual
        mean or std.

    Raises:
        MetalKernelSafetyError: input exceeds the configured safe batch/length.
        GPUCommandBufferError: any failure whose message mentions metal, kernel,
            gpu or command buffer.
        ValueError: any other failure (bad shape, too many non-finite values, ...).
    """
    try:
        # Force bfloat16 so the same kernel template/compilation path as production
        # inference with the bf16 model is exercised.
        if x.dtype != mx.bfloat16:
            x = x.astype(mx.bfloat16)

        if x.shape[1] > self.max_sequence_length_safe:
            raise MetalKernelSafetyError(
                f"Sequence length {x.shape[1]} exceeds safe limit {self.max_sequence_length_safe}"
            )
        if x.shape[0] > self.max_batch_size_safe:
            raise MetalKernelSafetyError(
                f"Batch size {x.shape[0]} exceeds safe limit {self.max_batch_size_safe}"
            )

        custom_attn = custom_attention_class(args)
        if custom_attn is None:
            raise ValueError("Failed to instantiate custom attention")

        # Keep parameters in bfloat16 too; float32 weights would upcast Q/K/V and
        # silently avoid the bfloat16 kernel path.
        if hasattr(custom_attn, "set_dtype"):
            custom_attn.set_dtype(mx.bfloat16)

        start_time = time.perf_counter()
        output = custom_attn(x, mask=mask)
        if output is None:
            raise ValueError("Custom attention returned None")
        # MLX evaluates lazily: without an explicit eval the timer would only cover
        # graph construction, and kernel failures would surface later in the checks.
        mx.eval(output)
        elapsed_time = time.perf_counter() - start_time

        # Soft timeout: slow but working kernels still get partial credit.
        if elapsed_time > self.kernel_validation_timeout:
            print(f" ⚠️ Slow execution detected: {elapsed_time:.2f}s")
            return 0.5

        expected_shape = x.shape
        if output.shape != expected_shape:
            raise ValueError(f"Wrong output shape: {output.shape}, expected {expected_shape}")

        finite_mask = mx.isfinite(output)
        if not mx.all(finite_mask):
            finite_ratio = float(mx.mean(finite_mask.astype(mx.float32)))
            if finite_ratio < 0.9:
                raise ValueError(f"Too many non-finite values: {finite_ratio:.2%} finite")
            print(f" ⚠️ Some non-finite values: {finite_ratio:.2%} finite")
            return 0.7

        # Loose statistical sanity bounds (lenient for complex kernels).
        output_mean = float(mx.mean(output))
        output_std = float(mx.std(output))
        output_max = float(mx.max(mx.abs(output)))

        if abs(output_mean) > 10.0:
            print(f" ⚠️ Large mean: {output_mean:.6f}")
            return 0.6
        if output_std > 100.0 or output_std < 0.00001:
            print(f" ⚠️ Unusual std: {output_std:.6f}")
            return 0.6
        if output_max > 1000.0:
            print(f" ⚠️ Large max value: {output_max:.6f}")
            return 0.7

        return 1.0
    except MetalKernelSafetyError:
        raise
    except Exception as e:
        error_msg = str(e)
        if any(
            keyword in error_msg.lower()
            for keyword in ["metal", "kernel", "gpu", "command buffer"]
        ):
            raise GPUCommandBufferError(f"GPU execution error: {error_msg}") from e
        raise ValueError(f"Sequence test error: {error_msg}") from e
def _command_buffer_protected_benchmark(
self, program_text: str, custom_attention_class: Any
) -> Dict[str, Any]:
"""Command-buffer-protected benchmarking with maximum safety"""
print(" 🚀 Running command-buffer-protected benchmarking...")
retry_attempt = 0
while retry_attempt <= self.max_retry_attempts:
try:
print(f" 🔄 Protected attempt {retry_attempt + 1}/{self.max_retry_attempts + 1}")
# Clean GPU state before each major attempt
self._ensure_clean_gpu_state()
# Apply custom attention hook with protection
hook_result = self._gpu_protected_apply_hook(custom_attention_class)
if not hook_result["success"]:
if retry_attempt < self.max_retry_attempts:
print(f" 🔄 Hook failed, retrying... ({hook_result['error']})")
retry_attempt += 1
time.sleep(self.retry_base_delay * (2**retry_attempt))
continue
return {
"success": False,
"error": f"Hook application failed: {hook_result['error']}",
}
original_attention = hook_result["original"]
temp_program_path = None
try:
# Ensure the evolved program is available to the subprocess that runs mlx_lm.generate.
# Monkey-patching in this evaluator process does NOT propagate across subprocess boundaries.
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f:
f.write(program_text)
temp_program_path = f.name
self.benchmark_suite.hook_program_path = temp_program_path
# Run benchmarks with command buffer protection
custom_configs = self._get_safe_benchmark_configs()
custom_results = []
successful_benchmarks = 0
for i, config in enumerate(custom_configs, 1):
print(
f" [{i}/{len(custom_configs)}] Command-buffer-protected: {config.name}"
)
benchmark_retry = 0
while benchmark_retry <= 2: # Fewer retries per benchmark
try:
# Clean state before each benchmark
self._ensure_clean_gpu_state()
# Run with maximum protection
success, result = self._bulletproof_execute_with_gpu_protection(
lambda: self.benchmark_suite.run_single_benchmark(config)
)
if success and result:
custom_results.append(result)
successful_benchmarks += 1
print(
f" ✅ Protected {config.name}: {result.decode_tokens_per_sec:.1f} tokens/sec"
)
break
else:
if benchmark_retry < 2:
print(
f" 🔄 Benchmark retry {benchmark_retry + 1}: {result}"
)
benchmark_retry += 1
time.sleep(1)
continue
else:
print(f" ❌ Benchmark failed: {result}")
break
except Exception as e:
if benchmark_retry < 2:
print(
f" 🔄 Benchmark exception retry {benchmark_retry + 1}: {e}"
)
benchmark_retry += 1
time.sleep(1)
continue
else:
print(f" ❌ Benchmark exception: {e}")
break
# Check success rate
min_required = max(2, len(custom_configs) * 0.4) # Lowered to 40% for safety
if successful_benchmarks >= min_required:
print(
f" ✅ Command-buffer-protected benchmarks complete ({successful_benchmarks} successful)"
)
self.retry_attempts_used = retry_attempt
return {"success": True, "results": custom_results}
else:
error_msg = f"Insufficient benchmarks: {successful_benchmarks}/{len(custom_configs)} succeeded"
if retry_attempt < self.max_retry_attempts:
print(f" 🔄 {error_msg}, retrying full attempt...")
retry_attempt += 1
time.sleep(self.retry_base_delay * (2**retry_attempt))
continue
return {"success": False, "error": error_msg}
finally:
# Always clear subprocess hook settings and clean up temp program
self.benchmark_suite.hook_program_path = None
if temp_program_path:
try:
os.unlink(temp_program_path)
except OSError:
pass
# Always restore original attention
self._gpu_protected_remove_hook(original_attention)
except Exception as e:
error_msg = f"Command-buffer-protected attempt failed: {str(e)}"
print(f" ❌ {error_msg}")
if retry_attempt < self.max_retry_attempts:
retry_attempt += 1
time.sleep(self.retry_base_delay * (2**retry_attempt))
continue
return {"success": False, "error": error_msg}
return {"success": False, "error": "All command-buffer-protected attempts exhausted"}
def _ensure_clean_gpu_state(self):
    """Run a trivial GPU computation, then pause briefly; failures only print a warning.

    Nothing is freed or reset here despite the name: this is a responsiveness
    probe followed by a 0.1 s settle delay.
    """
    try:
        doubled = mx.array([1.0, 2.0, 3.0]) * 2
        mx.eval(doubled)
        time.sleep(0.1)  # let the device settle before the next operation
    except Exception as exc:
        print(f" ⚠️ GPU state cleanup warning: {exc}")
def _gpu_protected_apply_hook(self, custom_attention_class: Any) -> Dict[str, Any]:
"""GPU-protected application of custom attention hook"""
try:
success, result = self._bulletproof_execute_with_gpu_protection(
lambda: self._apply_attention_hook_safely(custom_attention_class)
)
if success:
return {"success": True, "original": result}
else:
return {"success": False, "error": result}
except Exception as e:
return {"success": False, "error": f"GPU-protected hook application failed: {e}"}
def _apply_attention_hook_safely(self, custom_attention_class: Any) -> Any:
    """Replace ``mlx_lm.models.qwen3.Attention`` with ``custom_attention_class``.

    Returns:
        The previously installed ``Attention`` class, so the caller can restore it.

    Raises:
        RuntimeError: the module has no ``Attention`` attribute, or the swap
            did not take effect.
    """
    import mlx_lm.models.qwen3 as qwen3_module

    previous_attention = getattr(qwen3_module, "Attention", None)
    if previous_attention is None:
        raise RuntimeError("Could not find original Attention class")

    qwen3_module.Attention = custom_attention_class
    installed = qwen3_module.Attention
    if installed != custom_attention_class:
        raise RuntimeError("Hook application verification failed")

    print(" ✅ Custom attention hook applied with GPU protection")
    return previous_attention
def _gpu_protected_remove_hook(self, original_attention: Any):
"""GPU-protected removal of custom attention hook"""
try:
success, result = self._bulletproof_execute_with_gpu_protection(
lambda: self._remove_attention_hook_safely(original_attention)
)
if not success:
print(f" ⚠️ Hook removal warning: {result}")
except Exception as e:
print(f" ⚠️ Hook removal error (non-fatal): {e}")
def _remove_attention_hook_safely(self, original_attention: Any):
    """Reinstall ``original_attention`` as ``mlx_lm.models.qwen3.Attention``."""
    from mlx_lm.models import qwen3

    qwen3.Attention = original_attention
    print(" ✅ Hook removed with GPU protection")
def _create_bulletproof_execution_environment(self) -> Dict[str, Any]:
"""Create bulletproof execution environment with enhanced imports"""
import math
import numpy as np
import time
from typing import Optional, Tuple, Any
exec_globals = {
"__builtins__": __builtins__,
"mx": mx,
"nn": nn,
"np": np,
"math": math,
"time": time,
"Optional": Optional,
"Tuple": Tuple,
"Any": Any,
}
# Enhanced MLX-LM import with error handling
try:
exec_globals["mlx_lm"] = __import__("mlx_lm")
print(" ✅ MLX-LM imported for bulletproof execution")