Skip to content

Commit 7833c2a

Browse files
[python] Harden bench harness for cross-compilation sweeps and improve reporting (#451)
Signed-off-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
Co-authored-by: Nicolas Vasilache <nicolasvasilache@users.noreply.github.com>
1 parent c6803be commit 7833c2a

2 files changed

Lines changed: 81 additions & 45 deletions

File tree

contrib/kittens/test/bench/bench_harness.py

Lines changed: 78 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,49 @@ def _save_tmpfile(prefix, lines):
3535
return path
3636

3737

38+
def _save_error_file(prefix, phase, errors, repro_cmd_fn=None, num_iterations=None):
    """Write a failure report bucketed by error message.

    Layout of the generated report:
    - A commented summary: total failure count, number of distinct
      messages, then one count line per message (largest bucket first).
    - One banner-delimited section per message, in the same order, listing
      every config that failed with that message.

    Args:
        prefix: Forwarded unchanged to ``_save_tmpfile``.
        phase: Phase name used in the summary line; ``"<phase>: "`` is also
            the prefix stripped from a message before comparing it with an
            entry's full error text.
        errors: Iterable of ``(config, message, full_text)`` triples; each
            config must expose a ``label`` attribute.
        repro_cmd_fn: Optional callable ``(config, num_iterations)`` whose
            result is appended to the config's line as a repro command.
        num_iterations: Second argument handed to ``repro_cmd_fn``.

    Returns:
        Whatever ``_save_tmpfile(prefix, lines)`` returns.
    """
    from collections import defaultdict

    buckets = defaultdict(list)
    for cfg, message, detail in errors:
        buckets[message].append((cfg, detail))

    # Sort once and reuse for both the summary and the detail sections.
    # sorted() is stable, so equally sized buckets keep first-seen order.
    ranked = sorted(buckets.items(), key=lambda item: -len(item[1]))

    report = [
        f"# {len(errors)} {phase} failures, {len(buckets)} unique errors",
        "#",
    ]
    report.extend(f"# {len(members):>5}x {message}" for message, members in ranked)
    report.append("")

    banner = "=" * 78
    phase_tag = f"{phase}: "
    for message, members in ranked:
        report.extend((banner, f"[{len(members)}x] {message}", banner, ""))
        for cfg, detail in members:
            suffix = ""
            if repro_cmd_fn:
                try:
                    suffix = f" | repro: {repro_cmd_fn(cfg, num_iterations)}"
                except Exception:
                    # Best effort: a failing repro builder must not prevent
                    # the error report from being written.
                    pass
            report.append(f" {cfg.label}{suffix}")
            # Only dump the full text when it adds something beyond the
            # category message itself.
            if detail and detail != message.removeprefix(phase_tag):
                report.extend(f" {part}" for part in detail.split("\n"))
        # NOTE(review): blank separator assumed to be emitted once per
        # category -- confirm against the original indentation.
        report.append("")

    return _save_tmpfile(prefix, report)
79+
80+
3881
def check_numpy_blas(label=""):
3982
import time
4083

@@ -54,12 +97,17 @@ def check_numpy_blas(label=""):
5497

5598

5699
def detect_num_gpus():
100+
"""Return the number of available GPUs, or 0 if none are present."""
57101
try:
102+
from aster.hip import system_has_gpu
103+
104+
if not system_has_gpu("gfx942"):
105+
return 0
58106
from aster.testing import hip_get_device_count
59107

60108
return max(1, hip_get_device_count())
61109
except Exception:
62-
return 1
110+
return 0
63111

64112

65113
def format_mlir_error(e):
@@ -501,14 +549,18 @@ def bench_perf_sweep(
501549

502550
exec_active = random.sample(exec_active, exec_sample)
503551

504-
print(f"\n--- Executing {len(exec_active)} configs ({num_gpus} GPU(s)) ---")
505-
results, exec_failed = run_on_gpus(
506-
exec_active,
507-
hsaco_paths,
508-
num_iterations,
509-
num_gpus,
510-
desc="Executing",
511-
)
552+
if num_gpus == 0:
553+
print("\nNo GPUs detected -- skipping execution phase.")
554+
results, exec_failed = [], []
555+
else:
556+
print(f"\n--- Executing {len(exec_active)} configs ({num_gpus} GPU(s)) ---")
557+
results, exec_failed = run_on_gpus(
558+
exec_active,
559+
hsaco_paths,
560+
num_iterations,
561+
num_gpus,
562+
desc="Executing",
563+
)
512564
failed.extend((c, e, "") for c, e in exec_failed)
513565

514566
# Summary: separate files for compile errors vs exec errors.
@@ -526,44 +578,17 @@ def bench_perf_sweep(
526578
saved_files.append(p)
527579
print(f"\nResults ({len(results)}) saved in {p}")
528580
if compile_errs:
529-
from collections import Counter
530-
531-
err_counts = Counter(e for _, e, _ in compile_errs)
532-
header = [
533-
f"# {len(compile_errs)} compile failures, {len(err_counts)} unique errors",
534-
"#",
535-
]
536-
for msg, cnt in err_counts.most_common(10):
537-
header.append(f"# {cnt:>5}x {msg}")
538-
header.append("#")
539-
detail = []
540-
for c, e, full in compile_errs:
541-
repro = ""
542-
if repro_cmd_fn:
543-
try:
544-
repro = f" | repro: {repro_cmd_fn(c, num_iterations)}"
545-
except Exception:
546-
pass
547-
detail.append(f"{c.label}: {e}{repro}")
548-
if full and full != e.removeprefix("compile: "):
549-
for line in full.split("\n"):
550-
detail.append(f" {line}")
551-
p = _save_tmpfile("bench_compile_errors_", header + detail)
581+
p = _save_error_file(
582+
"bench_compile_errors_",
583+
"compile",
584+
compile_errs,
585+
repro_cmd_fn,
586+
num_iterations,
587+
)
552588
saved_files.append(p)
553589
print(f"{len(compile_errs)} compile errors in {p}")
554590
if exec_errs:
555-
from collections import Counter
556-
557-
exec_counts = Counter(e for _, e, _ in exec_errs)
558-
header = [
559-
f"# {len(exec_errs)} exec failures, {len(exec_counts)} unique errors",
560-
"#",
561-
]
562-
for msg, cnt in exec_counts.most_common(10):
563-
header.append(f"# {cnt:>5}x {msg}")
564-
header.append("#")
565-
detail = [f"{c.label}: {e}" for c, e, _ in exec_errs]
566-
p = _save_tmpfile("bench_exec_errors_", header + detail)
591+
p = _save_error_file("bench_exec_errors_", "exec", exec_errs)
567592
saved_files.append(p)
568593
print(f"{len(exec_errs)} exec errors in {p}")
569594

@@ -631,9 +656,13 @@ def run_single(cfg, compile_fn, args, execute_fn):
631656
print(f"\n--- Assembly ---\n{asm}")
632657
return
633658

634-
A, B = make_inputs(cfg)
659+
has_gpu = detect_num_gpus() > 0
635660

636661
if args.hsaco:
662+
if not has_gpu:
663+
print("No GPUs detected -- skipping execution.")
664+
return
665+
A, B = make_inputs(cfg)
637666
print_config(cfg, args.iterations)
638667
_, times_ns = execute_fn(
639668
cfg, args.hsaco, args.iterations, A, B, skip_gpu_check=True
@@ -653,6 +682,10 @@ def run_single(cfg, compile_fn, args, execute_fn):
653682
raise SystemExit(1)
654683
if print_asm:
655684
print(f"\n--- Assembly ---\n{asm}")
685+
if not has_gpu:
686+
print("No GPUs detected -- skipping execution.")
687+
return
688+
A, B = make_inputs(cfg)
656689
_, times_ns = execute_fn(cfg, tmp.name, args.iterations, A, B)
657690

658691
measured = times_ns[WARMUP_ITERATIONS:]

contrib/kittens/test/bench/bench_perf_sweep_001_gemm_fp16_weak_scaled.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -344,6 +344,9 @@ def verify_top_configs(
344344
return
345345
if num_gpus is None:
346346
num_gpus = detect_num_gpus()
347+
if num_gpus == 0:
348+
print("\nNo GPUs detected -- skipping correctness verification.")
349+
return
347350
top = results[:num_configs]
348351
to_verify = [c for c, *_ in top if c.label in hsaco_paths]
349352
if not to_verify:

0 commit comments

Comments
 (0)