Skip to content

Commit 5f76d44

Browse files
committed
Invalidate ninetoothed.build cache when inputs change
`_load_cached` previously returned the cached `.so` whenever it existed on disk, regardless of whether the source code, configs, or library internals had changed since it was produced. That made stale artifacts silently survive across `premake`, `configs`, or `ninetoothed` updates. Compute a SHA-256 fingerprint over the `premake` source (with `functools.partial` args and keywords folded in), `configs`, `meta_parameters`, `caller`, and the contents of `build.py` and `aot.py`. Persist it next to the build artifacts as `<kernel_name>.fingerprint`, and invalidate the cache when it does not match.
1 parent eea40cd commit 5f76d44

1 file changed

Lines changed: 65 additions & 2 deletions

File tree

src/ninetoothed/build.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import csv
33
import enum
44
import functools
5+
import hashlib
56
import inspect
67
import itertools
78
import multiprocessing
@@ -58,8 +59,19 @@ def build(
5859

5960
output_dir = pathlib.Path(output_dir)
6061

62+
configs = tuple(configs)
63+
64+
if meta_parameters is not None:
65+
meta_parameters = tuple(meta_parameters)
66+
67+
fingerprint = _compute_fingerprint(premake, configs, meta_parameters, caller)
68+
6169
cached = _load_cached(
62-
configs, meta_parameters, kernel_name=kernel_name, output_dir=output_dir
70+
configs,
71+
meta_parameters,
72+
kernel_name=kernel_name,
73+
output_dir=output_dir,
74+
fingerprint=fingerprint,
6375
)
6476

6577
if cached is not None:
@@ -168,6 +180,8 @@ def build(
168180

169181
kernel = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)
170182

183+
_write_fingerprint(kernel_name, output_dir, fingerprint)
184+
171185
if meta_parameters is not None:
172186
kernel_before_auto_tuning = kernel
173187

@@ -571,12 +585,15 @@ def _make(premake, config, caller, kernel_name, output_dir):
571585
return kernel_name_, param_names, combination, config, tensors
572586

573587

574-
def _load_cached(configs, meta_parameters, *, kernel_name, output_dir):
588+
def _load_cached(configs, meta_parameters, *, kernel_name, output_dir, fingerprint):
575589
so_path = output_dir / f"{kernel_name}.so"
576590

577591
if not so_path.exists():
578592
return None
579593

594+
if not _fingerprint_matches(kernel_name, output_dir, fingerprint):
595+
return None
596+
580597
if meta_parameters is None:
581598
return _load_launch_func(kernel_name=kernel_name, output_dir=output_dir)
582599

@@ -599,6 +616,52 @@ def _load_cached(configs, meta_parameters, *, kernel_name, output_dir):
599616
)
600617

601618

619+
def _compute_fingerprint(premake, configs, meta_parameters, caller):
620+
hasher = hashlib.sha256()
621+
622+
target = premake.func if isinstance(premake, functools.partial) else premake
623+
624+
try:
625+
hasher.update(inspect.getsource(target).encode())
626+
except (TypeError, OSError):
627+
hasher.update(repr(target).encode())
628+
629+
if isinstance(premake, functools.partial):
630+
hasher.update(repr(premake.args).encode())
631+
hasher.update(repr(sorted(premake.keywords.items())).encode())
632+
633+
hasher.update(repr(configs).encode())
634+
hasher.update(repr(meta_parameters).encode())
635+
hasher.update(repr(caller).encode())
636+
637+
package_dir = pathlib.Path(__file__).parent
638+
639+
for path in (package_dir / "build.py", package_dir / "aot.py"):
640+
try:
641+
hasher.update(path.read_bytes())
642+
except OSError:
643+
pass
644+
645+
return hasher.hexdigest()
646+
647+
648+
def _fingerprint_matches(kernel_name, output_dir, fingerprint):
649+
path = _fingerprint_path(kernel_name, output_dir)
650+
651+
if not path.exists():
652+
return False
653+
654+
return path.read_text().strip() == fingerprint
655+
656+
657+
def _write_fingerprint(kernel_name, output_dir, fingerprint):
658+
_fingerprint_path(kernel_name, output_dir).write_text(fingerprint)
659+
660+
661+
def _fingerprint_path(kernel_name, output_dir):
662+
return output_dir / f"{kernel_name}.fingerprint"
663+
664+
602665
def _read_auto_tuning_cache(path):
603666
if not path.exists():
604667
return None

0 commit comments

Comments
 (0)