|
7 | 7 | --shape_grouped with profile row count comparison. |
8 | 8 | """ |
9 | 9 |
|
| 10 | +import glob |
10 | 11 | import os |
11 | 12 | import sys |
12 | 13 | import csv |
@@ -52,7 +53,27 @@ def _write_csv(path, header, rows): |
52 | 53 | writer.writerow(row) |
53 | 54 |
|
54 | 55 |
|
| 56 | +def _cleanup_stale_lock_files(): |
| 57 | + """Remove stale FileBaton lock files left by killed subprocesses.""" |
| 58 | + build_dir = os.path.join(AITER_ROOT, "aiter", "jit", "build") |
| 59 | + if not os.path.isdir(build_dir): |
| 60 | + return |
| 61 | + lock_patterns = [ |
| 62 | + os.path.join(build_dir, "lock_*"), |
| 63 | + os.path.join(build_dir, "*", "build", "lock"), |
| 64 | + os.path.join(build_dir, "lock_3rdparty_*"), |
| 65 | + ] |
| 66 | + for pattern in lock_patterns: |
| 67 | + for lock_file in glob.glob(pattern): |
| 68 | + try: |
| 69 | + os.remove(lock_file) |
| 70 | + print(f"Cleaned up stale lock file: {lock_file}", flush=True) |
| 71 | + except OSError: |
| 72 | + pass |
| 73 | + |
| 74 | + |
55 | 75 | def _run_tuner(script, untuned, tuned, extra_args=None, timeout=300, mp=1): |
| 76 | + _cleanup_stale_lock_files() |
56 | 77 | cmd = [ |
57 | 78 | sys.executable, |
58 | 79 | os.path.join(AITER_ROOT, script), |
@@ -82,6 +103,7 @@ def _run_tuner(script, untuned, tuned, extra_args=None, timeout=300, mp=1): |
82 | 103 | env=env, |
83 | 104 | ) |
84 | 105 | except subprocess.TimeoutExpired as e: |
| 106 | + _cleanup_stale_lock_files() |
85 | 107 | raise AssertionError( |
86 | 108 | f"Tuner timed out after {timeout}s (likely GPU hang or infinite loop)\n" |
87 | 109 | f" cmd: {' '.join(cmd)}\n" |
@@ -512,65 +534,51 @@ def test_batched_bf16(self): |
512 | 534 |
|
513 | 535 | @unittest.skipUnless(_gpu_available(), "No GPU available") |
514 | 536 | class TestComparePipeline(unittest.TestCase): |
515 | | - """Test --compare and --compare --update_improved end-to-end.""" |
| 537 | + """Test --compare --update_improved end-to-end.""" |
516 | 538 |
|
517 | 539 | CONFIGS = { |
518 | 540 | "a8w8_blockscale": { |
519 | 541 | "script": "csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.py", |
520 | 542 | "header": ["M", "N", "K"], |
521 | | - "shapes": [(1, 1024, 512), (16, 1536, 7168)], |
| 543 | + "shapes": [(1, 1024, 512)], |
522 | 544 | "keys": ["cu_num", "M", "N", "K"], |
| 545 | + "timeout": 3600, |
523 | 546 | }, |
524 | 547 | } |
525 | 548 |
|
526 | | - def _run_compare(self, name, update_improved=False): |
527 | | - cfg = self.CONFIGS[name] |
| 549 | + def test_compare_and_update(self): |
| 550 | + """--compare --update_improved: tune, compare, update tuned CSV.""" |
| 551 | + cfg = self.CONFIGS["a8w8_blockscale"] |
| 552 | + timeout = cfg.get("timeout", 900) |
528 | 553 | tmp = tempfile.mkdtemp() |
529 | | - untuned = os.path.join(tmp, "untuned.csv") |
530 | | - tuned = os.path.join(tmp, "tuned.csv") |
531 | | - _write_csv(untuned, cfg["header"], cfg["shapes"]) |
532 | | - |
533 | | - extra = ["--compare"] |
534 | | - if update_improved: |
535 | | - extra.append("--update_improved") |
536 | | - result = _run_tuner( |
537 | | - cfg["script"], untuned, tuned, extra_args=extra, timeout=900 |
538 | | - ) |
539 | | - return result, tuned, tmp |
540 | | - |
541 | | - def test_compare_only(self): |
542 | | - """--compare runs pre/post benchmark and prints comparison.""" |
543 | | - result, tuned, tmp = self._run_compare("a8w8_blockscale", update_improved=False) |
544 | 554 | try: |
| 555 | + untuned = os.path.join(tmp, "untuned.csv") |
| 556 | + tuned = os.path.join(tmp, "tuned.csv") |
| 557 | + _write_csv(untuned, cfg["header"], cfg["shapes"]) |
| 558 | + |
| 559 | + result = _run_tuner( |
| 560 | + cfg["script"], |
| 561 | + untuned, |
| 562 | + tuned, |
| 563 | + extra_args=[ |
| 564 | + "--compare", |
| 565 | + "--update_improved", |
| 566 | + "--libtype", |
| 567 | + "ck", |
| 568 | + "--batch", |
| 569 | + "1", |
| 570 | + ], |
| 571 | + timeout=timeout, |
| 572 | + mp=1, |
| 573 | + ) |
545 | 574 | if result.returncode != 0: |
546 | 575 | print(f"\n=== compare STDOUT ===\n{result.stdout[-2000:]}") |
547 | 576 | print(f"\n=== compare STDERR ===\n{result.stderr[-2000:]}") |
548 | | - self.assertEqual(result.returncode, 0, "compare tuner failed") |
549 | | - output = result.stdout + result.stderr |
550 | | - self.assertIn( |
551 | | - "Compare Report", output, "Expected 'Compare Report' in output" |
552 | | - ) |
553 | | - finally: |
554 | | - import shutil |
555 | | - |
556 | | - shutil.rmtree(tmp, ignore_errors=True) |
557 | | - |
558 | | - def test_compare_update_improved(self): |
559 | | - """--compare --update_improved writes tuned CSV and prints comparison.""" |
560 | | - result, tuned, tmp = self._run_compare("a8w8_blockscale", update_improved=True) |
561 | | - try: |
562 | | - if result.returncode != 0: |
563 | | - print(f"\n=== compare+update STDOUT ===\n{result.stdout[-2000:]}") |
564 | | - print(f"\n=== compare+update STDERR ===\n{result.stderr[-2000:]}") |
565 | 577 | self.assertEqual(result.returncode, 0, "compare+update tuner failed") |
566 | | - self.assertTrue(os.path.exists(tuned), "tuned CSV not created") |
567 | 578 | output = result.stdout + result.stderr |
568 | 579 | self.assertIn( |
569 | 580 | "Compare Report", output, "Expected 'Compare Report' in output" |
570 | 581 | ) |
571 | | - df = pd.read_csv(tuned) |
572 | | - df.columns = df.columns.str.strip() |
573 | | - self.assertGreaterEqual(len(df), 1, "tuned CSV should have at least 1 row") |
574 | 582 | finally: |
575 | 583 | import shutil |
576 | 584 |
|
|
0 commit comments