Skip to content

Commit b140e55

Browse files
fix tuning test (#3118)
* fix tuning test * update --------- Co-authored-by: Xin Huang <Xin.Huang@amd.com>
1 parent 9c50619 commit b140e55

3 files changed

Lines changed: 80 additions & 45 deletions

File tree

op_tests/tuning_tests/README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Minimal test suite for validating the aiter tuning infrastructure.
1111
| `test_compare_logic.py` | 1 | No | Compare/update_improved: `_build_compare_update_plan`, `_merge_compare_filtered_results` |
1212
| `test_mp_tuner_logic.py` | 1 | No | `mp_tuner` polling: timeout, AcceleratorError, KeyError, pool restart |
1313
| `test_online_tune.py` | 1 | No | `AITER_ONLINE_TUNE` decision logic, `mp_lock` synchronization, MainFunc CSV write, cfg_2stages reload |
14-
| `test_tune_pipeline.py` | 2 | Yes | End-to-end: run each tuner on small shapes, verify output CSV; `--compare --update_improved`; `AITER_ONLINE_TUNE` e2e |
14+
| `test_tune_pipeline.py` | 2 | Yes | End-to-end: run each tuner on small shapes (mp=1 + mp=default), verify output CSV; `--compare --update_improved`; `AITER_ONLINE_TUNE` e2e |
1515
| `test_asm_splitk_guard.py` | 1 | No | `GemmTuner.asm_gemm_all_solutions` SplitK semaphore grid guard |
1616
| `test_run_config.py` | 2 | Yes | Run --run_config on ALL existing tuned CSVs (configs + model_configs) |
1717

@@ -58,6 +58,24 @@ python3 -m unittest op_tests.tuning_tests.test_run_config -v
5858
python3 -m unittest discover -s op_tests/tuning_tests -v
5959
```
6060

61+
### Running individual tuner tests
62+
63+
Each tuner in `test_tune_pipeline.py` has two variants: `_mp1` (single GPU) and `_mp_default` (all GPUs).
64+
65+
```bash
66+
# Run a specific tuner (both mp1 and mp_default)
67+
python3 -m pytest op_tests/tuning_tests/test_tune_pipeline.py -k "gradlib_bf16" -v
68+
69+
# Run only the single-GPU variant
70+
python3 -m pytest op_tests/tuning_tests/test_tune_pipeline.py -k "gradlib_bf16_mp1" -v
71+
72+
# Run only the multi-GPU variant
73+
python3 -m pytest op_tests/tuning_tests/test_tune_pipeline.py -k "gradlib_bf16_mp_default" -v
74+
75+
# Run a specific tuner with unittest
76+
python3 -m unittest op_tests.tuning_tests.test_tune_pipeline.TestTunePipeline.test_a8w8_blockscale_mp1 -v
77+
```
78+
6179
## Reproducing with custom config
6280

6381
Use `TUNE_TEST_FAMILY` to run `--run_config` for a specific family. Config is resolved via `AITER_CONFIGS` automatically:

op_tests/tuning_tests/test_tune_pipeline.py

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
--shape_grouped with profile row count comparison.
88
"""
99

10+
import glob
1011
import os
1112
import sys
1213
import csv
@@ -52,7 +53,27 @@ def _write_csv(path, header, rows):
5253
writer.writerow(row)
5354

5455

56+
def _cleanup_stale_lock_files():
57+
"""Remove stale FileBaton lock files left by killed subprocesses."""
58+
build_dir = os.path.join(AITER_ROOT, "aiter", "jit", "build")
59+
if not os.path.isdir(build_dir):
60+
return
61+
lock_patterns = [
62+
os.path.join(build_dir, "lock_*"),
63+
os.path.join(build_dir, "*", "build", "lock"),
64+
os.path.join(build_dir, "lock_3rdparty_*"),
65+
]
66+
for pattern in lock_patterns:
67+
for lock_file in glob.glob(pattern):
68+
try:
69+
os.remove(lock_file)
70+
print(f"Cleaned up stale lock file: {lock_file}", flush=True)
71+
except OSError:
72+
pass
73+
74+
5575
def _run_tuner(script, untuned, tuned, extra_args=None, timeout=300, mp=1):
76+
_cleanup_stale_lock_files()
5677
cmd = [
5778
sys.executable,
5879
os.path.join(AITER_ROOT, script),
@@ -82,6 +103,7 @@ def _run_tuner(script, untuned, tuned, extra_args=None, timeout=300, mp=1):
82103
env=env,
83104
)
84105
except subprocess.TimeoutExpired as e:
106+
_cleanup_stale_lock_files()
85107
raise AssertionError(
86108
f"Tuner timed out after {timeout}s (likely GPU hang or infinite loop)\n"
87109
f" cmd: {' '.join(cmd)}\n"
@@ -512,65 +534,51 @@ def test_batched_bf16(self):
512534

513535
@unittest.skipUnless(_gpu_available(), "No GPU available")
514536
class TestComparePipeline(unittest.TestCase):
515-
"""Test --compare and --compare --update_improved end-to-end."""
537+
"""Test --compare --update_improved end-to-end."""
516538

517539
CONFIGS = {
518540
"a8w8_blockscale": {
519541
"script": "csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py",
520542
"header": ["M", "N", "K"],
521-
"shapes": [(1, 1024, 512), (16, 1536, 7168)],
543+
"shapes": [(1, 1024, 512)],
522544
"keys": ["cu_num", "M", "N", "K"],
545+
"timeout": 3600,
523546
},
524547
}
525548

526-
def _run_compare(self, name, update_improved=False):
527-
cfg = self.CONFIGS[name]
549+
def test_compare_and_update(self):
550+
"""--compare --update_improved: tune, compare, update tuned CSV."""
551+
cfg = self.CONFIGS["a8w8_blockscale"]
552+
timeout = cfg.get("timeout", 900)
528553
tmp = tempfile.mkdtemp()
529-
untuned = os.path.join(tmp, "untuned.csv")
530-
tuned = os.path.join(tmp, "tuned.csv")
531-
_write_csv(untuned, cfg["header"], cfg["shapes"])
532-
533-
extra = ["--compare"]
534-
if update_improved:
535-
extra.append("--update_improved")
536-
result = _run_tuner(
537-
cfg["script"], untuned, tuned, extra_args=extra, timeout=900
538-
)
539-
return result, tuned, tmp
540-
541-
def test_compare_only(self):
542-
"""--compare runs pre/post benchmark and prints comparison."""
543-
result, tuned, tmp = self._run_compare("a8w8_blockscale", update_improved=False)
544554
try:
555+
untuned = os.path.join(tmp, "untuned.csv")
556+
tuned = os.path.join(tmp, "tuned.csv")
557+
_write_csv(untuned, cfg["header"], cfg["shapes"])
558+
559+
result = _run_tuner(
560+
cfg["script"],
561+
untuned,
562+
tuned,
563+
extra_args=[
564+
"--compare",
565+
"--update_improved",
566+
"--libtype",
567+
"ck",
568+
"--batch",
569+
"1",
570+
],
571+
timeout=timeout,
572+
mp=1,
573+
)
545574
if result.returncode != 0:
546575
print(f"\n=== compare STDOUT ===\n{result.stdout[-2000:]}")
547576
print(f"\n=== compare STDERR ===\n{result.stderr[-2000:]}")
548-
self.assertEqual(result.returncode, 0, "compare tuner failed")
549-
output = result.stdout + result.stderr
550-
self.assertIn(
551-
"Compare Report", output, "Expected 'Compare Report' in output"
552-
)
553-
finally:
554-
import shutil
555-
556-
shutil.rmtree(tmp, ignore_errors=True)
557-
558-
def test_compare_update_improved(self):
559-
"""--compare --update_improved writes tuned CSV and prints comparison."""
560-
result, tuned, tmp = self._run_compare("a8w8_blockscale", update_improved=True)
561-
try:
562-
if result.returncode != 0:
563-
print(f"\n=== compare+update STDOUT ===\n{result.stdout[-2000:]}")
564-
print(f"\n=== compare+update STDERR ===\n{result.stderr[-2000:]}")
565577
self.assertEqual(result.returncode, 0, "compare+update tuner failed")
566-
self.assertTrue(os.path.exists(tuned), "tuned CSV not created")
567578
output = result.stdout + result.stderr
568579
self.assertIn(
569580
"Compare Report", output, "Expected 'Compare Report' in output"
570581
)
571-
df = pd.read_csv(tuned)
572-
df.columns = df.columns.str.strip()
573-
self.assertGreaterEqual(len(df), 1, "tuned CSV should have at least 1 row")
574582
finally:
575583
import shutil
576584

op_tests/tuning_tests/test_tuner_infra.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,8 @@ def test_two_files_merge_dedup(self):
398398
if os.path.exists(merged_path):
399399
os.unlink(merged_path)
400400

401-
def test_column_mismatch_raises(self):
402-
"""Two CSVs with different columns -> AssertionError."""
401+
def test_column_mismatch_merges(self):
402+
"""Two CSVs with different columns -> merged with missing cols filled."""
403403
tuner = _StubTuner.get()
404404
h1 = [
405405
"gfx",
@@ -438,8 +438,17 @@ def test_column_mismatch_raises(self):
438438
h2,
439439
[[TEST_GFX, 304, 1, 1024, 512, "x", 100.0, "k0", 1.0, 1.0, 0.01]],
440440
)
441-
with self.assertRaises(AssertionError):
442-
tuner.update_config_files(f"{f1}{os.pathsep}{f2}", "test_mismatch")
441+
merged_path = tuner.update_config_files(
442+
f"{f1}{os.pathsep}{f2}", "test_mismatch"
443+
)
444+
try:
445+
df = pd.read_csv(merged_path)
446+
self.assertIn("extra_col", df.columns)
447+
self.assertIn("kernelId", df.columns)
448+
self.assertIn("splitK", df.columns)
449+
finally:
450+
if os.path.exists(merged_path):
451+
os.unlink(merged_path)
443452

444453
def test_missing_second_file(self):
445454
"""Second path doesn't exist -> only first file data."""

0 commit comments

Comments
 (0)