Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions src/ninetoothed/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import csv
import enum
import functools
import hashlib
import inspect
import itertools
import multiprocessing
Expand Down Expand Up @@ -58,8 +59,19 @@ def build(

output_dir = pathlib.Path(output_dir)

configs = tuple(configs)

if meta_parameters is not None:
meta_parameters = tuple(meta_parameters)

fingerprint = _compute_fingerprint(premake, configs, meta_parameters, caller)

cached = _load_cached(
configs, meta_parameters, kernel_name=kernel_name, output_dir=output_dir
configs,
meta_parameters,
kernel_name=kernel_name,
output_dir=output_dir,
fingerprint=fingerprint,
)

if cached is not None:
Expand Down Expand Up @@ -168,6 +180,8 @@ def build(

kernel = _generate_launch_func(kernel_name=kernel_name, output_dir=output_dir)

_write_fingerprint(kernel_name, output_dir, fingerprint)

if meta_parameters is not None:
kernel_before_auto_tuning = kernel

Expand Down Expand Up @@ -571,12 +585,15 @@ def _make(premake, config, caller, kernel_name, output_dir):
return kernel_name_, param_names, combination, config, tensors


def _load_cached(configs, meta_parameters, *, kernel_name, output_dir):
def _load_cached(configs, meta_parameters, *, kernel_name, output_dir, fingerprint):
so_path = output_dir / f"{kernel_name}.so"

if not so_path.exists():
return None

if not _fingerprint_matches(kernel_name, output_dir, fingerprint):
return None

if meta_parameters is None:
return _load_launch_func(kernel_name=kernel_name, output_dir=output_dir)

Expand All @@ -599,6 +616,52 @@ def _load_cached(configs, meta_parameters, *, kernel_name, output_dir):
)


def _compute_fingerprint(premake, configs, meta_parameters, caller):
hasher = hashlib.sha256()

target = premake.func if isinstance(premake, functools.partial) else premake

try:
hasher.update(inspect.getsource(target).encode())
except (TypeError, OSError):
hasher.update(repr(target).encode())

if isinstance(premake, functools.partial):
hasher.update(repr(premake.args).encode())
hasher.update(repr(sorted(premake.keywords.items())).encode())

hasher.update(repr(configs).encode())
hasher.update(repr(meta_parameters).encode())
hasher.update(repr(caller).encode())

package_dir = pathlib.Path(__file__).parent

for path in (package_dir / "build.py", package_dir / "aot.py"):
try:
hasher.update(path.read_bytes())
except OSError:
pass

return hasher.hexdigest()


def _fingerprint_matches(kernel_name, output_dir, fingerprint):
path = _fingerprint_path(kernel_name, output_dir)

if not path.exists():
return False

return path.read_text().strip() == fingerprint


def _write_fingerprint(kernel_name, output_dir, fingerprint):
_fingerprint_path(kernel_name, output_dir).write_text(fingerprint)


def _fingerprint_path(kernel_name, output_dir):
return output_dir / f"{kernel_name}.fingerprint"


def _read_auto_tuning_cache(path):
if not path.exists():
return None
Expand Down
Loading