diff --git a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py index 3e87ceacaa..0a6090281d 100644 --- a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py +++ b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py @@ -1,9 +1,14 @@ +import builtins import errno from pathlib import Path + +import cloudpickle import pytest +import tilelang.cache.kernel_cache as kernel_cache_mod from tilelang.cache.kernel_cache import KernelCache from tilelang.env import env +from tilelang.jit.adapter.base import CachedTextSource from tilelang.jit.adapter.nvrtc.kernel_cache import NVRTCKernelCache @@ -46,6 +51,291 @@ def _make_fake_nvrtc_kernel(tmp_path): return _FakeKernel(str(lib_path)) +def _write_complete_kernel_cache_entry( + cache: KernelCache, + key: str, + device_source: str = "// device kernel", + host_source: str = "// host kernel", + prim_func=None, +) -> Path: + cache_path = Path(cache._get_cache_path(key)) + cache_path.mkdir(parents=True) + (cache_path / cache.device_kernel_path).write_text(device_source) + (cache_path / cache.host_kernel_path).write_text(host_source) + (cache_path / cache.kernel_lib_path).write_bytes(b"fake-so") + with (cache_path / cache.params_path).open("wb") as f: + cloudpickle.dump(["param"], f) + if prim_func is not None: + with (cache_path / cache.prim_func_path).open("wb") as f: + cloudpickle.dump(prim_func, f) + return cache_path + + +def test_kernel_cache_disk_hit_defers_source_loading(cache_dirs, monkeypatch): + cache = KernelCache() + key = "lazy-source-load" + cache_path = _write_complete_kernel_cache_entry(cache, key) + + sentinel = object() + captured = {} + + def fail_source_load(*args, **kwargs): + raise AssertionError("disk cache hit should pass source paths through for lazy loading") + + def fake_from_database(cls, **kwargs): + captured.update(kwargs) + return sentinel + + monkeypatch.setattr(cache, "_load_kernel_source", fail_source_load) + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache._load_kernel_from_disk( + key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + func=None, + ) + + assert loaded is sentinel + assert captured["host_kernel_source"] == CachedTextSource(path=str(cache_path / cache.host_kernel_path)) + assert captured["device_kernel_source"] == CachedTextSource(path=str(cache_path / cache.device_kernel_path)) + assert captured["kernel_lib_path"] == str(cache_path / cache.kernel_lib_path) + assert captured["params"] == ["param"] + + +def test_kernel_cache_disk_hit_perf_skips_large_source_file_reads(cache_dirs, monkeypatch): + cache = KernelCache() + key = "lazy-source-load-perf" + large_source = "// source\n" + ("x" * (2 * 1024 * 1024)) + cache_path = _write_complete_kernel_cache_entry( + cache, + key, + device_source=large_source, + host_source=large_source, + ) + source_paths = { + (cache_path / cache.device_kernel_path).resolve(), + (cache_path / cache.host_kernel_path).resolve(), + } + source_read_count = 0 + sentinel = object() + + real_open = builtins.open + + def tracking_open(file, *args, **kwargs): + nonlocal source_read_count + mode = args[0] if args else kwargs.get("mode", "r") + try: + path = Path(file).resolve() + except TypeError: + return real_open(file, *args, **kwargs) + if "r" in mode and path in source_paths: + source_read_count += 1 + raise AssertionError("cache perf regression: source file read during disk cache hit") + return real_open(file, *args, **kwargs) + + def fake_from_database(cls, **kwargs): + return sentinel + + monkeypatch.setattr(builtins, "open", tracking_open) + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache._load_kernel_from_disk( + key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + func=None, + ) + + assert loaded is sentinel + assert source_read_count == 0 + + +def test_kernel_cache_frontend_hit_loads_serialized_prim_func(cache_dirs, monkeypatch): + cache = KernelCache() + key = "frontend-kernel-key" + prim_func = {"name": "cached_prim_func"} + cache_path = _write_complete_kernel_cache_entry(cache, key, prim_func=prim_func) + cache.store_frontend_cache("frontend-key", key) + + sentinel = object() + captured = {} + + def fake_from_database(cls, **kwargs): + captured.update(kwargs) + return sentinel + + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache.load_frontend_cached( + "frontend-key", + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + ) + + assert loaded is sentinel + assert captured["func"] == prim_func + assert captured["host_kernel_source"] == CachedTextSource(path=str(cache_path / cache.host_kernel_path)) + assert captured["device_kernel_source"] == CachedTextSource(path=str(cache_path / cache.device_kernel_path)) + + +def test_kernel_cache_frontend_hit_round_trips_real_prim_func(cache_dirs, tmp_path, monkeypatch): + import tilelang.language as T + import tvm + + @T.prim_func + def kernel(): + T.evaluate(0) + + cache = KernelCache() + key = "frontend-real-prim-func-key" + frontend_key = "frontend-real-prim-func" + cache_path = Path(cache._get_cache_path(key)) + + cache._save_kernel_to_disk(key, _make_fake_kernel(tmp_path), func=kernel) + cache.store_frontend_cache(frontend_key, key) + + sentinel = object() + captured = {} + + def fake_from_database(cls, **kwargs): + captured.update(kwargs) + return sentinel + + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache.load_frontend_cached( + frontend_key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + ) + + assert loaded is sentinel + assert (cache_path / cache.prim_func_path).exists() + assert isinstance(captured["func"], tvm.tir.PrimFunc) + assert tvm.ir.structural_equal(captured["func"], kernel) + assert str(captured["func"].attrs["global_symbol"]) == str(kernel.attrs["global_symbol"]) + + +def test_jit_frontend_cache_hit_skips_tir_elaboration(monkeypatch): + import tilelang + import tilelang.language as T + from tilelang.jit import JITImpl + + sentinel = object() + calls = [] + + @tilelang.jit + def frontend_cached_kernel(block_m: int = 16): + @T.prim_func + def kernel(): + T.evaluate(0) + + return kernel + + def fake_load_frontend_cached(frontend_key_data, **kwargs): + calls.append((frontend_key_data, kwargs)) + return sentinel + + def fail_compile(self, *args, **kwargs): + raise AssertionError("frontend cache hit should not elaborate TIR") + + monkeypatch.setattr("tilelang.cache.load_frontend_cached", fake_load_frontend_cached) + monkeypatch.setattr(JITImpl, "compile", fail_compile) + + assert frontend_cached_kernel(block_m=32) is sentinel + assert calls + assert "frontend_cached_kernel" in calls[0][0]["function"] + + +def test_kernel_cache_disk_hit_rejects_entries_missing_sources(cache_dirs, monkeypatch): + cache = KernelCache() + key = "missing-source-entry" + cache_path = Path(cache._get_cache_path(key)) + cache_path.mkdir(parents=True) + (cache_path / cache.kernel_lib_path).write_bytes(b"fake-so") + with (cache_path / cache.params_path).open("wb") as f: + cloudpickle.dump(["param"], f) + + def fail_from_database(cls, **kwargs): + raise AssertionError("incomplete cache entries should miss before rebuilding from database") + + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fail_from_database)) + + loaded = cache._load_kernel_from_disk( + key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + func=None, + ) + + assert loaded is None + + +def test_nvrtc_adapter_host_source_lazy_loads(tmp_path): + pytest.importorskip("cuda.bindings.driver", reason="NVRTC adapter requires cuda-python") + from tilelang.jit.adapter.nvrtc.adapter import NVRTCKernelAdapter + + host_source_path = tmp_path / "host_kernel.cu" + host_source_path.write_text("// nvrtc host source") + adapter = NVRTCKernelAdapter.__new__(NVRTCKernelAdapter) + adapter.host_func = None + adapter._host_kernel_source_path = str(host_source_path) + + assert adapter.get_host_source() == "// nvrtc host source" + assert adapter.host_func == "// nvrtc host source" + + +def test_cutedsl_adapter_host_source_lazy_loads(tmp_path): + from tilelang.jit.adapter.cutedsl.adapter import CuTeDSLKernelAdapter + + host_source_path = tmp_path / "kernel.py" + host_source_path.write_text("# cutedsl host source") + adapter = CuTeDSLKernelAdapter.__new__(CuTeDSLKernelAdapter) + adapter.host_kernel_source = None + adapter.host_func = None + adapter._host_kernel_source_path = str(host_source_path) + + assert adapter.get_host_source() == "# cutedsl host source" + assert adapter.host_kernel_source == "# cutedsl host source" + + +def test_tvm_ffi_source_fallback_handles_missing_runtime_module(): + from tilelang.jit.adapter.tvm_ffi import TVMFFIKernelAdapter + + adapter = TVMFFIKernelAdapter.__new__(TVMFFIKernelAdapter) + adapter.host_kernel_source = None + adapter.device_kernel_source = None + adapter._host_kernel_source_path = None + adapter._device_kernel_source_path = None + adapter.rt_mod = None + + assert adapter.get_host_source() is None + assert adapter.get_device_source() is None + assert adapter.get_kernel_source(kernel_only=True) == "" + assert adapter.get_kernel_source(kernel_only=False) == "" + + def test_kernel_cache_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): cache = KernelCache() key = "atomic-repair" diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index be5e364606..0d1c43e618 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -13,6 +13,7 @@ import errno from tilelang.jit import JITKernel +from tilelang.jit.adapter.base import CachedTextSource import cloudpickle import os import shutil @@ -358,8 +359,8 @@ def _load_kernel_from_disk( if host_kernel_source and device_kernel_source and kernel_params: return JITKernel.from_database( func=func, - host_kernel_source=host_kernel_source, - device_kernel_source=device_kernel_source, + host_kernel_source=CachedTextSource(text=host_kernel_source), + device_kernel_source=CachedTextSource(text=device_kernel_source), kernel_lib_path=kernel_lib_path, params=kernel_params, target=target, diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index d0ee6c9a44..b091a2525f 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging +import json +from hashlib import sha256 from typing import TYPE_CHECKING, Literal from tvm.target import Target from tvm.tir import PrimFunc @@ -27,22 +29,21 @@ } -def cached( - func: PrimFunc = None, - out_idx: list[int] = None, - *args, - target: str | Target | None = None, - target_host: str | Target | None = None, - execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, - verbose: bool | None = None, - pass_configs: dict | None = None, - compile_flags: list[str] | str | None = None, -) -> JITKernel: - """ - Caches and reuses compiled kernels (using KernelCache class). - """ - # Apply environment variable defaults if parameters are not explicitly set - # This is the SINGLE source of truth for env var processing +def _normalize_for_json(value): + if isinstance(value, dict): + return {str(k): _normalize_for_json(v) for k, v in sorted(value.items(), key=lambda item: str(item[0]))} + if isinstance(value, (list, tuple)): + return [_normalize_for_json(v) for v in value] + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return repr(value) + + +def _resolve_cache_dispatch( + target: str | Target | None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None, + verbose: bool | None, +): if target is None: target = env.get_default_target() if execution_backend is None: @@ -50,40 +51,136 @@ def cached( if verbose is None: verbose = env.get_default_verbose() - # Normalize target and resolve execution backend before proceeding from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target norm_target = Target(_determine_target(target)) if isinstance(target, str) else target requested_backend = execution_backend - execution_backend = resolve_execution_backend(requested_backend, norm_target) + resolved_backend = resolve_execution_backend(requested_backend, norm_target) if verbose: allowed_now = allowed_backends_for_target(norm_target, include_unavailable=False) - # Avoid duplicate logs when caller already resolved explicitly - if requested_backend in (None, "auto") or requested_backend != execution_backend: + if requested_backend in (None, "auto") or requested_backend != resolved_backend: logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.info( "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", - execution_backend, + resolved_backend, requested_backend, norm_target.kind.name, ", ".join(sorted(allowed_now)), ) - if execution_backend in _dispatch_map: - return _dispatch_map[execution_backend].cached( - func, - out_idx, - *args, - target=norm_target, - target_host=target_host, - execution_backend=execution_backend, - verbose=verbose, - pass_configs=pass_configs, - compile_flags=compile_flags, - ) - else: - raise ValueError(f'Cannot find support for execution backend "{execution_backend}"') + if resolved_backend not in _dispatch_map: + raise ValueError(f'Cannot find support for execution backend "{resolved_backend}"') + return _dispatch_map[resolved_backend], norm_target, resolved_backend, verbose + + +def _make_frontend_cache_key( + frontend_key_data: dict, + *, + target: Target, + target_host: str | Target | None, + execution_backend: str, + out_idx: list[int] | int | None, + pass_configs: dict | None, + compile_flags: list[str] | str | None, +) -> str: + key_data = { + "frontend": _normalize_for_json(frontend_key_data), + "out_idx": _normalize_for_json(out_idx), + "target": str(target), + "target_host": str(target_host) if target_host else None, + "execution_backend": execution_backend, + "pass_configs": _normalize_for_json(pass_configs), + "compile_flags": _normalize_for_json(compile_flags), + } + key_string = json.dumps(key_data, sort_keys=True) + return sha256(key_string.encode()).hexdigest() + + +def cached( + func: PrimFunc = None, + out_idx: list[int] = None, + *args, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> JITKernel: + """ + Caches and reuses compiled kernels (using KernelCache class). + """ + cache, norm_target, execution_backend, verbose = _resolve_cache_dispatch(target, execution_backend, verbose) + return cache.cached( + func, + out_idx, + *args, + target=norm_target, + target_host=target_host, + execution_backend=execution_backend, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + +def load_frontend_cached( + frontend_key_data: dict, + *, + out_idx: list[int] | int | None = None, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> JITKernel | None: + cache, norm_target, execution_backend, verbose = _resolve_cache_dispatch(target, execution_backend, verbose) + frontend_key = _make_frontend_cache_key( + frontend_key_data, + target=norm_target, + target_host=target_host, + execution_backend=execution_backend, + out_idx=out_idx, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + return cache.load_frontend_cached( + frontend_key, + target=norm_target, + target_host=target_host, + out_idx=out_idx, + execution_backend=execution_backend, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + +def store_frontend_cache( + frontend_key_data: dict, + kernel_key: str, + *, + out_idx: list[int] | int | None = None, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> None: + cache, norm_target, execution_backend, verbose = _resolve_cache_dispatch(target, execution_backend, verbose) + frontend_key = _make_frontend_cache_key( + frontend_key_data, + target=norm_target, + target_host=target_host, + execution_backend=execution_backend, + out_idx=out_idx, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + cache.store_frontend_cache(frontend_key, kernel_key, verbose=verbose) def clear_cache(): diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index c9f2836ecf..2ed4e7f2e7 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -23,6 +23,7 @@ from tilelang.utils.language import get_prim_func_name from tilelang import env from tilelang.jit import JITKernel +from tilelang.jit.adapter.base import CachedTextSource from tilelang import __version__ import platform @@ -47,7 +48,9 @@ class KernelCache: host_kernel_path = "host_kernel.cu" kernel_lib_path = "kernel_lib.so" params_path = "params.pkl" + prim_func_path = "prim_func.pkl" cache_root_dir = "kernels" + frontend_cache_root_dir = "frontend" staging_root_dir = ".staging" @staticmethod @@ -167,6 +170,7 @@ def _create_dirs(): os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) os.makedirs(KernelCache._get_namespace_root(), exist_ok=True) os.makedirs(KernelCache._get_cache_root(), exist_ok=True) + os.makedirs(KernelCache._get_frontend_cache_root(), exist_ok=True) os.makedirs(KernelCache._get_staging_root(), exist_ok=True) staging_root = KernelCache._get_staging_root() @@ -183,6 +187,10 @@ def _get_namespace_root() -> str: def _get_cache_root() -> str: return os.path.join(KernelCache._get_namespace_root(), KernelCache.cache_root_dir) + @staticmethod + def _get_frontend_cache_root() -> str: + return os.path.join(KernelCache._get_namespace_root(), KernelCache.frontend_cache_root_dir) + @staticmethod def _get_staging_root() -> str: return os.path.join(KernelCache._get_namespace_root(), KernelCache.staging_root_dir) @@ -331,19 +339,23 @@ def cached( ) return self._memory_cache[key] - if verbose: - self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '')}") + if verbose: + self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '')}") - # Then check disk cache - kernel = self._load_kernel_from_disk( - key, target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose - ) - if kernel is not None: - if verbose: - self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '')}") - # Populate memory cache with disk result + # Disk loads can be expensive for large kernel sets; keep them outside + # the global cache lock so independent cache hits can proceed in parallel. + kernel = self._load_kernel_from_disk( + key, target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose + ) + if kernel is not None: + if verbose: + self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '')}") + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing self._memory_cache[key] = kernel - return kernel + return kernel if verbose: self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '')}") @@ -366,7 +378,12 @@ def cached( self._set_adapter_cache_path(kernel, cache_path) # Store in memory cache after compilation - self._memory_cache[key] = kernel + self._tag_kernel_cache_entry(kernel, key, self._get_cache_path(key)) + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing + self._memory_cache[key] = kernel return kernel def clear_cache(self): @@ -389,6 +406,25 @@ def _get_cache_path(self, key: str) -> str: """ return os.path.join(self._get_cache_root(), key) + def _get_frontend_cache_path(self, frontend_key: str) -> str: + return os.path.join( + self._get_frontend_cache_root(), + f"{self._sanitize_path_component(frontend_key)}.json", + ) + + @staticmethod + def _tag_kernel_cache_entry(kernel: JITKernel, key: str, cache_path: str) -> None: + try: + kernel._tilelang_cache_key = key + kernel._tilelang_cache_path = cache_path + except (AttributeError, TypeError): + logging.getLogger(__name__).debug( + "Could not tag kernel cache entry for key %s at %s", + key, + cache_path, + exc_info=True, + ) + @staticmethod def _load_binary(path: str): with open(path, "rb") as file: @@ -460,6 +496,19 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non self.logger.debug(f"Saving kernel parameters to disk: {params_path}") KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) + # Save the PrimFunc so frontend cache hits can rebuild the adapter + # without re-elaborating the Python DSL in a fresh process. This is + # optional for backward compatibility with existing cache entries. + if func is not None: + prim_func_path = os.path.join(staging_path, self.prim_func_path) + if verbose: + self.logger.debug(f"Saving PrimFunc to disk: {prim_func_path}") + try: + KernelCache._safe_write_file(prim_func_path, "wb", lambda file: cloudpickle.dump(func, file)) + except Exception: + if verbose: + self.logger.exception("Error saving optional PrimFunc cache metadata") + missing_files = self._get_missing_complete_cache_files(staging_path) if missing_files: missing_names = ", ".join(os.path.basename(path) for path in missing_files) @@ -514,14 +563,12 @@ def _load_kernel_from_disk( kernel_lib_path = os.path.join(cache_path, self.kernel_lib_path) params_path = os.path.join(cache_path, self.params_path) - required_files = self._get_required_files(cache_path) - - if not all([os.path.exists(file) for file in required_files]): + missing_files = self._get_missing_complete_cache_files(cache_path) + if missing_files: + if verbose: + self.logger.debug("Disk cache entry is incomplete; missing files: %s", missing_files) return None - # Load the kernel source file (optional) - device_kernel_source, host_kernel_source = self._load_kernel_source(device_kernel_path, host_kernel_path, verbose) - # Load kernel parameters kernel_params: list[KernelParam] | None = None try: @@ -532,10 +579,10 @@ def _load_kernel_from_disk( except Exception: self.logger.exception("Error loading kernel parameters from disk") - return self._build_kernel( + kernel = self._build_kernel( func=func, - host_kernel_source=host_kernel_source, - device_kernel_source=device_kernel_source, + host_kernel_source=CachedTextSource(path=host_kernel_path), + device_kernel_source=CachedTextSource(path=device_kernel_path), kernel_lib_path=kernel_lib_path, kernel_params=kernel_params, target=target, @@ -545,6 +592,95 @@ def _load_kernel_from_disk( pass_configs=pass_configs, compile_flags=compile_flags, ) + if kernel is not None: + prim_func_path = os.path.join(cache_path, self.prim_func_path) + if func is not None and not os.path.exists(prim_func_path): + try: + KernelCache._safe_write_file(prim_func_path, "wb", lambda file: cloudpickle.dump(func, file)) + except Exception: + if verbose: + self.logger.exception("Error upgrading cache entry with PrimFunc") + self._tag_kernel_cache_entry(kernel, key, cache_path) + return kernel + + def load_frontend_cached( + self, + frontend_key: str, + *, + target: str | Target = "auto", + target_host: str | Target | None = None, + out_idx: list[int] | None = None, + execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, + verbose: bool = False, + ) -> JITKernel | None: + if not env.is_cache_enabled(): + return None + + frontend_path = self._get_frontend_cache_path(frontend_key) + try: + with open(frontend_path, encoding="utf-8") as file: + frontend_entry = json.load(file) + except OSError: + return None + except Exception: + self.logger.exception("Error loading frontend cache entry") + return None + + key = frontend_entry.get("kernel_key") + if not isinstance(key, str) or not key: + return None + + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing + + cache_path = self._get_cache_path(key) + prim_func_path = os.path.join(cache_path, self.prim_func_path) + try: + with open(prim_func_path, "rb") as file: + func = cloudpickle.load(file) + except OSError: + return None + except Exception: + self.logger.exception("Error loading PrimFunc from frontend cache entry") + return None + + kernel = self._load_kernel_from_disk( + key, + target=target, + target_host=target_host, + out_idx=out_idx, + execution_backend=execution_backend, + pass_configs=pass_configs, + compile_flags=compile_flags, + func=func, + verbose=verbose, + ) + if kernel is None: + return None + + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing + self._memory_cache[key] = kernel + return kernel + + def store_frontend_cache(self, frontend_key: str, kernel_key: str, *, verbose: bool = False) -> None: + if not env.is_cache_enabled(): + return + + KernelCache._create_dirs() + frontend_path = self._get_frontend_cache_path(frontend_key) + payload = {"kernel_key": kernel_key} + try: + KernelCache._safe_write_file(frontend_path, "w", lambda file: json.dump(payload, file, sort_keys=True)) + except Exception: + if verbose: + self.logger.exception("Error saving frontend cache entry") def _clear_disk_cache(self): """ @@ -639,8 +775,8 @@ def _set_adapter_cache_path(self, kernel: JITKernel, cache_path: str): def _build_kernel( self, func: Callable | None, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, kernel_params: list[KernelParam] | None, target: str | Target, @@ -652,10 +788,6 @@ def _build_kernel( ) -> JITKernel | None: # Check all required components and report specific failures missing_components = [] - if not host_kernel_source: - missing_components.append("host_kernel_source") - if not device_kernel_source: - missing_components.append("device_kernel_source") if not kernel_params: missing_components.append("kernel_params") diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index e5a7f10950..c067099cbd 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -427,6 +427,20 @@ def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs): key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) return key + def _frontend_cache_key_data(self, key: tuple) -> dict[str, Any]: + func_name = getattr(getattr(self.func, "orig_func", self.func), "__name__", "jit_kernel") + func_qualname = getattr(getattr(self.func, "orig_func", self.func), "__qualname__", func_name) + func_module = getattr(getattr(self.func, "orig_func", self.func), "__module__", None) + return { + "function": func_name, + "qualname": func_qualname, + "module": func_module, + "source": self.func_source, + "signature": str(self.signature), + "key": repr(key), + "mode": self.mode, + } + def get_kernel_source(self, *args: _P.args, **kwargs: _P.kwargs) -> str: kernel = self.compile(*args, **kwargs) return kernel.get_kernel_source() @@ -458,7 +472,42 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: key, kernel_args = self.func.parse_args(*args, **kwargs) kernel = self._kernel_cache.get(key, None) if kernel is None: - kernel = self.compile(*args, **kwargs) + frontend_key_data = None + # Frontend cache is only safe when lazy-mode parse_args leaves no + # runtime kernel_args; then _frontend_cache_key_data fully identifies + # the compiled kernel, assuming compile-time values have stable reprs. + if self.mode == "lazy" and not kernel_args: + frontend_key_data = self._frontend_cache_key_data(key) + from tilelang.cache import load_frontend_cached + + kernel = load_frontend_cached( + frontend_key_data, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) + if kernel is None: + kernel = self.compile(*args, **kwargs) + if frontend_key_data is not None: + kernel_key = getattr(kernel, "_tilelang_cache_key", None) + if kernel_key: + from tilelang.cache import store_frontend_cache + + store_frontend_cache( + frontend_key_data, + kernel_key, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) self._kernel_cache[key] = kernel # eager mode: execute kernel immediately and return result diff --git a/tilelang/jit/adapter/__init__.py b/tilelang/jit/adapter/__init__.py index 0d99452855..eb48af408d 100644 --- a/tilelang/jit/adapter/__init__.py +++ b/tilelang/jit/adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseKernelAdapter # noqa: F401 +from .base import BaseKernelAdapter, CachedTextSource # noqa: F401 from .tvm_ffi import TVMFFIKernelAdapter # noqa: F401 from .cython import CythonKernelAdapter # noqa: F401 from .nvrtc import NVRTCKernelAdapter # noqa: F401 diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index afe55d7ec8..8691f4f74b 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -3,11 +3,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any from collections.abc import Callable -from tilelang.engine.param import KernelParam +from dataclasses import dataclass +from typing import Any + import torch +from tilelang.engine.param import KernelParam + + +@dataclass(frozen=True) +class CachedTextSource: + text: str | None = None + path: str | None = None + class BaseKernelAdapter(ABC): func: Callable | None = None @@ -95,3 +104,24 @@ def get_kernel_source(self, kernel_only: bool = True) -> str: def _post_init(self): self.func = self._convert_torch_func() + + def _set_cached_text_source(self, source_attr: str, path_attr: str, source: CachedTextSource) -> None: + setattr(self, source_attr, source.text) + setattr(self, path_attr, source.path) + + def _load_cached_text_source(self, source_attr: str, path_attr: str) -> str | None: + source = getattr(self, source_attr, None) + if source is not None: + return source + + path = getattr(self, path_attr, None) + if path is None: + return None + + try: + with open(path, encoding="utf-8") as file: + source = file.read() + except OSError: + return None + setattr(self, source_attr, source) + return source diff --git a/tilelang/jit/adapter/cutedsl/adapter.py b/tilelang/jit/adapter/cutedsl/adapter.py index 23f51f9ffb..40057f2ffa 100644 --- a/tilelang/jit/adapter/cutedsl/adapter.py +++ b/tilelang/jit/adapter/cutedsl/adapter.py @@ -15,7 +15,7 @@ from tilelang.jit.adapter.cutedsl.libgen import CuTeDSLLibraryGenerator from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.target import determine_target -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource logger = logging.getLogger(__name__) @@ -105,8 +105,8 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -115,8 +115,9 @@ def from_database( adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source + adapter._set_cached_text_source("host_kernel_source", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) + adapter.host_func = host_kernel_source.text if isinstance(func_or_mod, tir.PrimFunc): gsym = func_or_mod.attrs.get("global_symbol") @@ -221,6 +222,13 @@ def _lookup_dynamic_symbolic(self, v: tir.Var) -> tuple[int, int, int]: return self._dynamic_symbolic_name_map[v.name] raise KeyError(f"Dynamic symbolic variable '{v.name}' not found in symbolic map") + def get_host_source(self) -> str | None: + """Get the cached host-side source code.""" + source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") + if source is not None: + return source + return getattr(self, "host_func", None) + def get_kernel_source(self, kernel_only: bool = True) -> str | None: """Get the CUDA kernel source code. @@ -229,7 +237,13 @@ def get_kernel_source(self, kernel_only: bool = True) -> str | None: str | None The kernel source code, or None if not available """ - return self.device_kernel_source + if not kernel_only: + return self.get_host_source() + + source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") + if source is not None: + self.kernel_global_source = source + return source def _forward_from_prebuild_lib(self, *args, stream: int | None = None, device_id: int = 0): """Low-level function to call the compiled CUDA kernel. diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index ff82bf4d34..0c5b7b2a3d 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -13,7 +13,7 @@ from tvm import tir from tvm.relax import TensorType -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource from tilelang.jit.adapter.wrapper import TLWrapper from tilelang.jit.adapter.libgen import LibraryGenerator from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target, is_metal_target @@ -156,8 +156,8 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -166,9 +166,9 @@ def from_database( adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source - adapter.kernel_global_source = device_kernel_source # Set alias for compatibility + adapter._set_cached_text_source("host_kernel_source", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) + adapter.kernel_global_source = device_kernel_source.text # Set alias for compatibility adapter.pass_configs = pass_configs if isinstance(func_or_mod, tir.PrimFunc): @@ -387,12 +387,16 @@ def is_dynamic(self): def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" if kernel_only: - return self.device_kernel_source + source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") + if source is not None: + self.kernel_global_source = source + return source else: # Wrapper only has host kernel source - assert self.host_kernel_source is not None, "Wrapped source is not available" - return self.host_kernel_source + source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") + assert source is not None, "Wrapped source is not available" + return source def get_host_source(self): """Returns the source code of the host function.""" - return self.host_kernel_source + return self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index 73d7be50b9..ac7b59b974 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -12,7 +12,7 @@ from tilelang.jit.adapter.wrapper import TLPyWrapper from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.target import determine_target -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available from .libgen import NVRTCLibraryGenerator @@ -103,8 +103,8 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -113,8 +113,9 @@ def from_database( adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source + adapter.host_kernel_source = host_kernel_source.text + adapter._set_cached_text_source("host_func", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) if isinstance(func_or_mod, tir.PrimFunc): adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) @@ -202,9 +203,13 @@ def get_kernel_source(self, kernel_only: bool = True) -> str | None: The kernel source code, or None if not available """ if kernel_only: - return self.device_kernel_source + return self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") else: - return self.host_func + return self._load_cached_text_source("host_func", "_host_kernel_source_path") + + def get_host_source(self) -> str | None: + """Get the cached host-side source code.""" + return self._load_cached_text_source("host_func", "_host_kernel_source_path") def _forward_from_prebuild_lib(self, *args, stream: int | None = None): """Low-level function to call the compiled CUDA kernel.""" diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 3939f78e81..045a315164 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -18,7 +18,7 @@ from tvm.target import Target from tvm.relax import TensorType from tilelang.utils.target import determine_target -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource from tilelang.utils.language import retrieve_func_from_module from tilelang.engine.param import KernelParam from tilelang.language.dtypes import dtype @@ -266,8 +266,8 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, @@ -276,9 +276,13 @@ def from_database( adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source - adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source + adapter._set_cached_text_source("host_kernel_source", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) + adapter.wrapped_source = ( + device_kernel_source.text + "\n\n" + host_kernel_source.text + if device_kernel_source.text is not None and host_kernel_source.text is not None + else None + ) adapter.pass_configs = pass_configs if isinstance(func_or_mod, tir.PrimFunc): @@ -291,29 +295,43 @@ def from_database( adapter.verbose = verbose adapter.libpath = kernel_lib_path - adapter.kernel_global_source = device_kernel_source + adapter.kernel_global_source = device_kernel_source.text + adapter.rt_mod = None adapter.executable = runtime.load_module(kernel_lib_path) adapter._post_init() return adapter - def get_host_source(self): + def get_host_source(self) -> str | None: """Returns the source code of the host module.""" - if self.host_kernel_source is not None: - return self.host_kernel_source - return self.rt_mod.inspect_source() - - def get_device_source(self): + source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") + if source is not None: + return source + rt_mod = getattr(self, "rt_mod", None) + if rt_mod is None: + return None + return rt_mod.inspect_source() + + def get_device_source(self) -> str | None: """Returns the source code of the device module.""" - if self.device_kernel_source is not None: - return self.device_kernel_source - return self.rt_mod.imports[0].inspect_source() + source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") + if source is not None: + self.kernel_global_source = source + return source + rt_mod = getattr(self, "rt_mod", None) + if rt_mod is None: + return None + return rt_mod.imports[0].inspect_source() def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" + device_source = self.get_device_source() or "" if kernel_only: - return self.get_device_source() - else: - return self.get_device_source() + "\n\n" + self.get_host_source() + return device_source + + host_source = self.get_host_source() or "" + if device_source and host_source: + return device_source + "\n\n" + host_source + return device_source or host_source @property def prim_func(self) -> tir.PrimFunc: diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 3ae56c537b..925b554407 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -12,6 +12,7 @@ from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.jit.adapter import ( BaseKernelAdapter, + CachedTextSource, CythonKernelAdapter, CuTeDSLKernelAdapter, TVMFFIKernelAdapter, @@ -143,8 +144,8 @@ def __init__( def from_database( cls, func: PrimFunc, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, params: list[KernelParam], target: str | Target, @@ -335,8 +336,8 @@ def _create_adapter_from_database( result_idx: list[int] | int, target: str | Target, func_or_mod: PrimFunc | tvm.runtime.Module, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, @@ -615,11 +616,21 @@ def params(self) -> list[KernelParam]: @property def kernel_source(self) -> str: - return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source + if self.artifact: + return self.artifact.kernel_source + source = getattr(self.adapter, "kernel_global_source", None) + if source is not None: + return source + return self.adapter.get_kernel_source(kernel_only=True) or "" @property def host_source(self) -> str: - return str(self.artifact.host_mod) if self.artifact else "" + if self.artifact: + return str(self.artifact.host_mod) + get_host_source = getattr(self.adapter, "get_host_source", None) + if get_host_source is None: + return "" + return get_host_source() or "" def export_library(self, kernel_file: str) -> None: """