From 30c28edcdf4a216a202a153617d9f2c17d78ef44 Mon Sep 17 00:00:00 2001 From: sepcnt <30561671+sepcnt@users.noreply.github.com> Date: Mon, 11 May 2026 13:08:01 +0800 Subject: [PATCH 1/3] Optimize disk cache source loading --- .../test_tilelang_kernel_cache_atomic_save.py | 251 ++++++++++++++++++ tilelang/cache/__init__.py | 167 +++++++++--- tilelang/cache/kernel_cache.py | 180 +++++++++++-- tilelang/jit/__init__.py | 48 +++- tilelang/jit/adapter/base.py | 17 ++ tilelang/jit/adapter/cutedsl/adapter.py | 24 +- tilelang/jit/adapter/cython/adapter.py | 20 +- tilelang/jit/adapter/nvrtc/adapter.py | 17 +- tilelang/jit/adapter/tvm_ffi.py | 52 ++-- tilelang/jit/kernel.py | 36 ++- 10 files changed, 716 insertions(+), 96 deletions(-) diff --git a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py index 3e87ceacaa..b15f1f73e3 100644 --- a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py +++ b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py @@ -1,7 +1,11 @@ +import builtins import errno from pathlib import Path + +import cloudpickle import pytest +import tilelang.cache.kernel_cache as kernel_cache_mod from tilelang.cache.kernel_cache import KernelCache from tilelang.env import env from tilelang.jit.adapter.nvrtc.kernel_cache import NVRTCKernelCache @@ -46,6 +50,253 @@ def _make_fake_nvrtc_kernel(tmp_path): return _FakeKernel(str(lib_path)) +def _write_complete_kernel_cache_entry( + cache: KernelCache, + key: str, + device_source: str = "// device kernel", + host_source: str = "// host kernel", + prim_func=None, +) -> Path: + cache_path = Path(cache._get_cache_path(key)) + cache_path.mkdir(parents=True) + (cache_path / cache.device_kernel_path).write_text(device_source) + (cache_path / cache.host_kernel_path).write_text(host_source) + (cache_path / cache.kernel_lib_path).write_bytes(b"fake-so") + with (cache_path / cache.params_path).open("wb") as f: + cloudpickle.dump(["param"], f) + if prim_func is not None: + with (cache_path / cache.prim_func_path).open("wb") as f: + cloudpickle.dump(prim_func, f) + return cache_path + + +def test_kernel_cache_disk_hit_defers_source_loading(cache_dirs, monkeypatch): + cache = KernelCache() + key = "lazy-source-load" + cache_path = _write_complete_kernel_cache_entry(cache, key) + + sentinel = object() + captured = {} + + def fail_source_load(*args, **kwargs): + raise AssertionError("disk cache hit should pass source paths through for lazy loading") + + def fake_from_database(cls, **kwargs): + captured.update(kwargs) + return sentinel + + monkeypatch.setattr(cache, "_load_kernel_source", fail_source_load) + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache._load_kernel_from_disk( + key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + func=None, + ) + + assert loaded is sentinel + assert captured["host_kernel_source"] is None + assert captured["device_kernel_source"] is None + assert captured["host_kernel_source_path"] == str(cache_path / cache.host_kernel_path) + assert captured["device_kernel_source_path"] == str(cache_path / cache.device_kernel_path) + assert captured["kernel_lib_path"] == str(cache_path / cache.kernel_lib_path) + assert captured["params"] == ["param"] + + +def test_kernel_cache_disk_hit_perf_skips_large_source_file_reads(cache_dirs, monkeypatch): + cache = KernelCache() + key = "lazy-source-load-perf" + large_source = "// source\n" + ("x" * (2 * 1024 * 1024)) + cache_path = _write_complete_kernel_cache_entry( + cache, + key, + device_source=large_source, + host_source=large_source, + ) + source_paths = { + (cache_path / cache.device_kernel_path).resolve(), + (cache_path / cache.host_kernel_path).resolve(), + } + source_read_count = 0 + sentinel = object() + + real_open = builtins.open + + def tracking_open(file, *args, **kwargs): + nonlocal source_read_count + mode = args[0] if args else kwargs.get("mode", "r") + try: + path = Path(file).resolve() + except TypeError: + return real_open(file, *args, **kwargs) + if "r" in mode and path in source_paths: + source_read_count += 1 + raise AssertionError("cache perf regression: source file read during disk cache hit") + return real_open(file, *args, **kwargs) + + def fake_from_database(cls, **kwargs): + return sentinel + + monkeypatch.setattr(builtins, "open", tracking_open) + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache._load_kernel_from_disk( + key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + func=None, + ) + + assert loaded is sentinel + assert source_read_count == 0 + + +def test_kernel_cache_frontend_hit_loads_serialized_prim_func(cache_dirs, monkeypatch): + cache = KernelCache() + key = "frontend-kernel-key" + prim_func = {"name": "cached_prim_func"} + cache_path = _write_complete_kernel_cache_entry(cache, key, prim_func=prim_func) + cache.store_frontend_cache("frontend-key", key) + + sentinel = object() + captured = {} + + def fake_from_database(cls, **kwargs): + captured.update(kwargs) + return sentinel + + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache.load_frontend_cached( + "frontend-key", + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + ) + + assert loaded is sentinel + assert captured["func"] == prim_func + assert captured["host_kernel_source"] is None + assert captured["device_kernel_source"] is None + assert captured["host_kernel_source_path"] == str(cache_path / cache.host_kernel_path) + assert captured["device_kernel_source_path"] == str(cache_path / cache.device_kernel_path) + + +def test_jit_frontend_cache_hit_skips_tir_elaboration(monkeypatch): + import tilelang + import tilelang.language as T + from tilelang.jit import JITImpl + + sentinel = object() + calls = [] + + @tilelang.jit + def frontend_cached_kernel(block_m: int = 16): + @T.prim_func + def kernel(): + T.evaluate(0) + + return kernel + + def fake_load_frontend_cached(frontend_key_data, **kwargs): + calls.append((frontend_key_data, kwargs)) + return sentinel + + def fail_compile(self, *args, **kwargs): + raise AssertionError("frontend cache hit should not elaborate TIR") + + monkeypatch.setattr("tilelang.cache.load_frontend_cached", fake_load_frontend_cached) + monkeypatch.setattr(JITImpl, "compile", fail_compile) + + assert frontend_cached_kernel(block_m=32) is sentinel + assert calls + assert "frontend_cached_kernel" in calls[0][0]["function"] + + +def test_kernel_cache_disk_hit_rejects_entries_missing_sources(cache_dirs, monkeypatch): + cache = KernelCache() + key = "missing-source-entry" + cache_path = Path(cache._get_cache_path(key)) + cache_path.mkdir(parents=True) + (cache_path / cache.kernel_lib_path).write_bytes(b"fake-so") + with (cache_path / cache.params_path).open("wb") as f: + cloudpickle.dump(["param"], f) + + def fail_from_database(cls, **kwargs): + raise AssertionError("incomplete cache entries should miss before rebuilding from database") + + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fail_from_database)) + + loaded = cache._load_kernel_from_disk( + key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + func=None, + ) + + assert loaded is None + + +def test_nvrtc_adapter_host_source_lazy_loads(tmp_path): + pytest.importorskip("cuda.bindings.driver", reason="NVRTC adapter requires cuda-python") + from tilelang.jit.adapter.nvrtc.adapter import NVRTCKernelAdapter + + host_source_path = tmp_path / "host_kernel.cu" + host_source_path.write_text("// nvrtc host source") + adapter = NVRTCKernelAdapter.__new__(NVRTCKernelAdapter) + adapter.host_func = None + adapter._host_kernel_source_path = str(host_source_path) + + assert adapter.get_host_source() == "// nvrtc host source" + assert adapter.host_func == "// nvrtc host source" + + +def test_cutedsl_adapter_host_source_lazy_loads(tmp_path): + from tilelang.jit.adapter.cutedsl.adapter import CuTeDSLKernelAdapter + + host_source_path = tmp_path / "kernel.py" + host_source_path.write_text("# cutedsl host source") + adapter = CuTeDSLKernelAdapter.__new__(CuTeDSLKernelAdapter) + adapter.host_kernel_source = None + adapter.host_func = None + adapter._host_kernel_source_path = str(host_source_path) + + assert adapter.get_host_source() == "# cutedsl host source" + assert adapter.host_kernel_source == "# cutedsl host source" + + +def test_tvm_ffi_source_fallback_handles_missing_runtime_module(): + from tilelang.jit.adapter.tvm_ffi import TVMFFIKernelAdapter + + adapter = TVMFFIKernelAdapter.__new__(TVMFFIKernelAdapter) + adapter.host_kernel_source = None + adapter.device_kernel_source = None + adapter._host_kernel_source_path = None + adapter._device_kernel_source_path = None + adapter.rt_mod = None + + assert adapter.get_host_source() is None + assert adapter.get_device_source() is None + assert adapter.get_kernel_source(kernel_only=True) == "" + assert adapter.get_kernel_source(kernel_only=False) == "" + + def test_kernel_cache_rewrites_incomplete_cache_dir(cache_dirs, tmp_path): cache = KernelCache() key = "atomic-repair" diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index d0ee6c9a44..b091a2525f 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging +import json +from hashlib import sha256 from typing import TYPE_CHECKING, Literal from tvm.target import Target from tvm.tir import PrimFunc @@ -27,22 +29,21 @@ } -def cached( - func: PrimFunc = None, - out_idx: list[int] = None, - *args, - target: str | Target | None = None, - target_host: str | Target | None = None, - execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, - verbose: bool | None = None, - pass_configs: dict | None = None, - compile_flags: list[str] | str | None = None, -) -> JITKernel: - """ - Caches and reuses compiled kernels (using KernelCache class). - """ - # Apply environment variable defaults if parameters are not explicitly set - # This is the SINGLE source of truth for env var processing +def _normalize_for_json(value): + if isinstance(value, dict): + return {str(k): _normalize_for_json(v) for k, v in sorted(value.items(), key=lambda item: str(item[0]))} + if isinstance(value, (list, tuple)): + return [_normalize_for_json(v) for v in value] + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return repr(value) + + +def _resolve_cache_dispatch( + target: str | Target | None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None, + verbose: bool | None, +): if target is None: target = env.get_default_target() if execution_backend is None: @@ -50,40 +51,136 @@ def cached( if verbose is None: verbose = env.get_default_verbose() - # Normalize target and resolve execution backend before proceeding from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target norm_target = Target(_determine_target(target)) if isinstance(target, str) else target requested_backend = execution_backend - execution_backend = resolve_execution_backend(requested_backend, norm_target) + resolved_backend = resolve_execution_backend(requested_backend, norm_target) if verbose: allowed_now = allowed_backends_for_target(norm_target, include_unavailable=False) - # Avoid duplicate logs when caller already resolved explicitly - if requested_backend in (None, "auto") or requested_backend != execution_backend: + if requested_backend in (None, "auto") or requested_backend != resolved_backend: logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.info( "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", - execution_backend, + resolved_backend, requested_backend, norm_target.kind.name, ", ".join(sorted(allowed_now)), ) - if execution_backend in _dispatch_map: - return _dispatch_map[execution_backend].cached( - func, - out_idx, - *args, - target=norm_target, - target_host=target_host, - execution_backend=execution_backend, - verbose=verbose, - pass_configs=pass_configs, - compile_flags=compile_flags, - ) - else: - raise ValueError(f'Cannot find support for execution backend "{execution_backend}"') + if resolved_backend not in _dispatch_map: + raise ValueError(f'Cannot find support for execution backend "{resolved_backend}"') + return _dispatch_map[resolved_backend], norm_target, resolved_backend, verbose + + +def _make_frontend_cache_key( + frontend_key_data: dict, + *, + target: Target, + target_host: str | Target | None, + execution_backend: str, + out_idx: list[int] | int | None, + pass_configs: dict | None, + compile_flags: list[str] | str | None, +) -> str: + key_data = { + "frontend": _normalize_for_json(frontend_key_data), + "out_idx": _normalize_for_json(out_idx), + "target": str(target), + "target_host": str(target_host) if target_host else None, + "execution_backend": execution_backend, + "pass_configs": _normalize_for_json(pass_configs), + "compile_flags": _normalize_for_json(compile_flags), + } + key_string = json.dumps(key_data, sort_keys=True) + return sha256(key_string.encode()).hexdigest() + + +def cached( + func: PrimFunc = None, + out_idx: list[int] = None, + *args, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> JITKernel: + """ + Caches and reuses compiled kernels (using KernelCache class). + """ + cache, norm_target, execution_backend, verbose = _resolve_cache_dispatch(target, execution_backend, verbose) + return cache.cached( + func, + out_idx, + *args, + target=norm_target, + target_host=target_host, + execution_backend=execution_backend, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + +def load_frontend_cached( + frontend_key_data: dict, + *, + out_idx: list[int] | int | None = None, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> JITKernel | None: + cache, norm_target, execution_backend, verbose = _resolve_cache_dispatch(target, execution_backend, verbose) + frontend_key = _make_frontend_cache_key( + frontend_key_data, + target=norm_target, + target_host=target_host, + execution_backend=execution_backend, + out_idx=out_idx, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + return cache.load_frontend_cached( + frontend_key, + target=norm_target, + target_host=target_host, + out_idx=out_idx, + execution_backend=execution_backend, + verbose=verbose, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + + +def store_frontend_cache( + frontend_key_data: dict, + kernel_key: str, + *, + out_idx: list[int] | int | None = None, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, +) -> None: + cache, norm_target, execution_backend, verbose = _resolve_cache_dispatch(target, execution_backend, verbose) + frontend_key = _make_frontend_cache_key( + frontend_key_data, + target=norm_target, + target_host=target_host, + execution_backend=execution_backend, + out_idx=out_idx, + pass_configs=pass_configs, + compile_flags=compile_flags, + ) + cache.store_frontend_cache(frontend_key, kernel_key, verbose=verbose) def clear_cache(): diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 1fd9e68b27..36bc3b4e09 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -46,7 +46,9 @@ class KernelCache: host_kernel_path = "host_kernel.cu" kernel_lib_path = "kernel_lib.so" params_path = "params.pkl" + prim_func_path = "prim_func.pkl" cache_root_dir = "kernels" + frontend_cache_root_dir = "frontend" staging_root_dir = ".staging" @staticmethod @@ -166,6 +168,7 @@ def _create_dirs(): os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True) os.makedirs(KernelCache._get_namespace_root(), exist_ok=True) os.makedirs(KernelCache._get_cache_root(), exist_ok=True) + os.makedirs(KernelCache._get_frontend_cache_root(), exist_ok=True) os.makedirs(KernelCache._get_staging_root(), exist_ok=True) staging_root = KernelCache._get_staging_root() @@ -182,6 +185,10 @@ def _get_namespace_root() -> str: def _get_cache_root() -> str: return os.path.join(KernelCache._get_namespace_root(), KernelCache.cache_root_dir) + @staticmethod + def _get_frontend_cache_root() -> str: + return os.path.join(KernelCache._get_namespace_root(), KernelCache.frontend_cache_root_dir) + @staticmethod def _get_staging_root() -> str: return os.path.join(KernelCache._get_namespace_root(), KernelCache.staging_root_dir) @@ -330,19 +337,23 @@ def cached( ) return self._memory_cache[key] - if verbose: - self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '')}") + if verbose: + self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '')}") - # Then check disk cache - kernel = self._load_kernel_from_disk( - key, target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose - ) - if kernel is not None: - if verbose: - self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '')}") - # Populate memory cache with disk result + # Disk loads can be expensive for large kernel sets; keep them outside + # the global cache lock so independent cache hits can proceed in parallel. + kernel = self._load_kernel_from_disk( + key, target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose + ) + if kernel is not None: + if verbose: + self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '')}") + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing self._memory_cache[key] = kernel - return kernel + return kernel if verbose: self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '')}") @@ -365,6 +376,7 @@ def cached( self._set_adapter_cache_path(kernel, cache_path) # Store in memory cache after compilation + self._tag_kernel_cache_entry(kernel, key, self._get_cache_path(key)) self._memory_cache[key] = kernel return kernel @@ -388,6 +400,20 @@ def _get_cache_path(self, key: str) -> str: """ return os.path.join(self._get_cache_root(), key) + def _get_frontend_cache_path(self, frontend_key: str) -> str: + return os.path.join( + self._get_frontend_cache_root(), + f"{self._sanitize_path_component(frontend_key)}.json", + ) + + @staticmethod + def _tag_kernel_cache_entry(kernel: JITKernel, key: str, cache_path: str) -> None: + try: + kernel._tilelang_cache_key = key + kernel._tilelang_cache_path = cache_path + except Exception: + pass + @staticmethod def _load_binary(path: str): with open(path, "rb") as file: @@ -459,6 +485,19 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non self.logger.debug(f"Saving kernel parameters to disk: {params_path}") KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) + # Save the PrimFunc so frontend cache hits can rebuild the adapter + # without re-elaborating the Python DSL in a fresh process. This is + # optional for backward compatibility with existing cache entries. + if func is not None: + prim_func_path = os.path.join(staging_path, self.prim_func_path) + if verbose: + self.logger.debug(f"Saving PrimFunc to disk: {prim_func_path}") + try: + KernelCache._safe_write_file(prim_func_path, "wb", lambda file: cloudpickle.dump(func, file)) + except Exception: + if verbose: + self.logger.exception("Error saving optional PrimFunc cache metadata") + missing_files = self._get_missing_complete_cache_files(staging_path) if missing_files: missing_names = ", ".join(os.path.basename(path) for path in missing_files) @@ -513,14 +552,12 @@ def _load_kernel_from_disk( kernel_lib_path = os.path.join(cache_path, self.kernel_lib_path) params_path = os.path.join(cache_path, self.params_path) - required_files = self._get_required_files(cache_path) - - if not all([os.path.exists(file) for file in required_files]): + missing_files = self._get_missing_complete_cache_files(cache_path) + if missing_files: + if verbose: + self.logger.debug("Disk cache entry is incomplete; missing files: %s", missing_files) return None - # Load the kernel source file (optional) - device_kernel_source, host_kernel_source = self._load_kernel_source(device_kernel_path, host_kernel_path, verbose) - # Load kernel parameters kernel_params: list[KernelParam] | None = None try: @@ -531,10 +568,12 @@ def _load_kernel_from_disk( except Exception: self.logger.exception("Error loading kernel parameters from disk") - return self._build_kernel( + kernel = self._build_kernel( func=func, - host_kernel_source=host_kernel_source, - device_kernel_source=device_kernel_source, + host_kernel_source=None, + device_kernel_source=None, + host_kernel_path=host_kernel_path, + device_kernel_path=device_kernel_path, kernel_lib_path=kernel_lib_path, kernel_params=kernel_params, target=target, @@ -544,6 +583,95 @@ def _load_kernel_from_disk( pass_configs=pass_configs, compile_flags=compile_flags, ) + if kernel is not None: + prim_func_path = os.path.join(cache_path, self.prim_func_path) + if func is not None and not os.path.exists(prim_func_path): + try: + KernelCache._safe_write_file(prim_func_path, "wb", lambda file: cloudpickle.dump(func, file)) + except Exception: + if verbose: + self.logger.exception("Error upgrading cache entry with PrimFunc") + self._tag_kernel_cache_entry(kernel, key, cache_path) + return kernel + + def load_frontend_cached( + self, + frontend_key: str, + *, + target: str | Target = "auto", + target_host: str | Target | None = None, + out_idx: list[int] | None = None, + execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi", + pass_configs: dict | None = None, + compile_flags: list[str] | str | None = None, + verbose: bool = False, + ) -> JITKernel | None: + if not env.is_cache_enabled(): + return None + + frontend_path = self._get_frontend_cache_path(frontend_key) + try: + with open(frontend_path, encoding="utf-8") as file: + frontend_entry = json.load(file) + except OSError: + return None + except Exception: + self.logger.exception("Error loading frontend cache entry") + return None + + key = frontend_entry.get("kernel_key") + if not isinstance(key, str) or not key: + return None + + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing + + cache_path = self._get_cache_path(key) + prim_func_path = os.path.join(cache_path, self.prim_func_path) + try: + with open(prim_func_path, "rb") as file: + func = cloudpickle.load(file) + except OSError: + return None + except Exception: + self.logger.exception("Error loading PrimFunc from frontend cache entry") + return None + + kernel = self._load_kernel_from_disk( + key, + target=target, + target_host=target_host, + out_idx=out_idx, + execution_backend=execution_backend, + pass_configs=pass_configs, + compile_flags=compile_flags, + func=func, + verbose=verbose, + ) + if kernel is None: + return None + + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing + self._memory_cache[key] = kernel + return kernel + + def store_frontend_cache(self, frontend_key: str, kernel_key: str, *, verbose: bool = False) -> None: + if not env.is_cache_enabled(): + return + + KernelCache._create_dirs() + frontend_path = self._get_frontend_cache_path(frontend_key) + payload = {"kernel_key": kernel_key} + try: + KernelCache._safe_write_file(frontend_path, "w", lambda file: json.dump(payload, file, sort_keys=True)) + except Exception: + if verbose: + self.logger.exception("Error saving frontend cache entry") def _clear_disk_cache(self): """ @@ -638,8 +766,10 @@ def _set_adapter_cache_path(self, kernel: JITKernel, cache_path: str): def _build_kernel( self, func: Callable | None, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, + host_kernel_path: str | None, + device_kernel_path: str | None, kernel_lib_path: str, kernel_params: list[KernelParam] | None, target: str | Target, @@ -651,10 +781,6 @@ def _build_kernel( ) -> JITKernel | None: # Check all required components and report specific failures missing_components = [] - if not host_kernel_source: - missing_components.append("host_kernel_source") - if not device_kernel_source: - missing_components.append("device_kernel_source") if not kernel_params: missing_components.append("kernel_params") @@ -674,4 +800,6 @@ def _build_kernel( execution_backend=execution_backend, pass_configs=pass_configs, compile_flags=compile_flags, + host_kernel_source_path=host_kernel_path, + device_kernel_source_path=device_kernel_path, ) diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index f7beac7cb7..8ecec12f12 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -431,6 +431,20 @@ def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs): key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple) return key + def _frontend_cache_key_data(self, key: tuple) -> dict[str, Any]: + func_name = getattr(getattr(self.func, "orig_func", self.func), "__name__", "jit_kernel") + func_qualname = getattr(getattr(self.func, "orig_func", self.func), "__qualname__", func_name) + func_module = getattr(getattr(self.func, "orig_func", self.func), "__module__", None) + return { + "function": func_name, + "qualname": func_qualname, + "module": func_module, + "source": self.func_source, + "signature": str(self.signature), + "key": repr(key), + "mode": self.mode, + } + def get_kernel_source(self, *args: _P.args, **kwargs: _P.kwargs) -> str: kernel = self.compile(*args, **kwargs) return kernel.get_kernel_source() @@ -462,7 +476,39 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: key, kernel_args = self.func.parse_args(*args, **kwargs) kernel = self._kernel_cache.get(key, None) if kernel is None: - kernel = self.compile(*args, **kwargs) + frontend_key_data = None + if self.mode == "lazy" and not kernel_args: + frontend_key_data = self._frontend_cache_key_data(key) + from tilelang.cache import load_frontend_cached + + kernel = load_frontend_cached( + frontend_key_data, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) + if kernel is None: + kernel = self.compile(*args, **kwargs) + if frontend_key_data is not None: + kernel_key = getattr(kernel, "_tilelang_cache_key", None) + if kernel_key: + from tilelang.cache import store_frontend_cache + + store_frontend_cache( + frontend_key_data, + kernel_key, + out_idx=self.out_idx, + execution_backend=self.execution_backend, + target=self.target, + target_host=self.target_host, + verbose=self.verbose, + pass_configs=self.pass_configs, + compile_flags=self.compile_flags, + ) self._kernel_cache[key] = kernel # eager mode: execute kernel immediately and return result diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 3669f9e35c..3da55f4d8c 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -94,3 +94,20 @@ def get_kernel_source(self, kernel_only: bool = True) -> str: def _post_init(self): self.func = self._convert_torch_func() + + def _load_cached_text_source(self, source_attr: str, path_attr: str) -> str | None: + source = getattr(self, source_attr, None) + if source is not None: + return source + + path = getattr(self, path_attr, None) + if path is None: + return None + + try: + with open(path, encoding="utf-8") as file: + source = file.read() + except OSError: + return None + setattr(self, source_attr, source) + return source diff --git a/tilelang/jit/adapter/cutedsl/adapter.py b/tilelang/jit/adapter/cutedsl/adapter.py index 7dd7917a27..87dc57c7f4 100644 --- a/tilelang/jit/adapter/cutedsl/adapter.py +++ b/tilelang/jit/adapter/cutedsl/adapter.py @@ -104,18 +104,23 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, + host_kernel_source_path: str | None = None, + device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.host_kernel_source = host_kernel_source adapter.device_kernel_source = device_kernel_source + adapter.host_func = host_kernel_source + adapter._host_kernel_source_path = host_kernel_source_path + adapter._device_kernel_source_path = device_kernel_source_path if isinstance(func_or_mod, tir.PrimFunc): gsym = func_or_mod.attrs.get("global_symbol") @@ -220,6 +225,13 @@ def _lookup_dynamic_symbolic(self, v: tir.Var) -> tuple[int, int, int]: return self._dynamic_symbolic_name_map[v.name] raise KeyError(f"Dynamic symbolic variable '{v.name}' not found in symbolic map") + def get_host_source(self) -> str | None: + """Get the cached host-side source code.""" + source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") + if source is not None: + return source + return getattr(self, "host_func", None) + def get_kernel_source(self, kernel_only: bool = True) -> str | None: """Get the CUDA kernel source code. @@ -228,7 +240,13 @@ def get_kernel_source(self, kernel_only: bool = True) -> str | None: str | None The kernel source code, or None if not available """ - return self.device_kernel_source + if not kernel_only: + return self.get_host_source() + + source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") + if source is not None: + self.kernel_global_source = source + return source def _forward_from_prebuild_lib(self, *args, stream: int | None = None, device_id: int = 0): """Low-level function to call the compiled CUDA kernel. diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index 62bb0a6c33..cfaed86db3 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -155,18 +155,22 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, + host_kernel_source_path: str | None = None, + device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.host_kernel_source = host_kernel_source adapter.device_kernel_source = device_kernel_source + adapter._host_kernel_source_path = host_kernel_source_path + adapter._device_kernel_source_path = device_kernel_source_path adapter.kernel_global_source = device_kernel_source # Set alias for compatibility adapter.pass_configs = pass_configs @@ -386,12 +390,16 @@ def is_dynamic(self): def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" if kernel_only: - return self.device_kernel_source + source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") + if source is not None: + self.kernel_global_source = source + return source else: # Wrapper only has host kernel source - assert self.host_kernel_source is not None, "Wrapped source is not available" - return self.host_kernel_source + source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") + assert source is not None, "Wrapped source is not available" + return source def get_host_source(self): """Returns the source code of the host function.""" - return self.host_kernel_source + return self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index d7626393b1..c31c22f964 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -102,18 +102,23 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, + host_kernel_source_path: str | None = None, + device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.host_kernel_source = host_kernel_source adapter.device_kernel_source = device_kernel_source + adapter.host_func = host_kernel_source + adapter._host_kernel_source_path = host_kernel_source_path + adapter._device_kernel_source_path = device_kernel_source_path if isinstance(func_or_mod, tir.PrimFunc): adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) @@ -201,9 +206,13 @@ def get_kernel_source(self, kernel_only: bool = True) -> str | None: The kernel source code, or None if not available """ if kernel_only: - return self.device_kernel_source + return self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") else: - return self.host_func + return self._load_cached_text_source("host_func", "_host_kernel_source_path") + + def get_host_source(self) -> str | None: + """Get the cached host-side source code.""" + return self._load_cached_text_source("host_func", "_host_kernel_source_path") def _forward_from_prebuild_lib(self, *args, stream: int | None = None): """Low-level function to call the compiled CUDA kernel.""" diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 55a5e7ffbd..7795d40969 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -265,19 +265,27 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, + host_kernel_source_path: str | None = None, + device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) adapter.host_kernel_source = host_kernel_source adapter.device_kernel_source = device_kernel_source - adapter.wrapped_source = device_kernel_source + "\n\n" + host_kernel_source + adapter._host_kernel_source_path = host_kernel_source_path + adapter._device_kernel_source_path = device_kernel_source_path + adapter.wrapped_source = ( + device_kernel_source + "\n\n" + host_kernel_source + if device_kernel_source is not None and host_kernel_source is not None + else None + ) adapter.pass_configs = pass_configs if isinstance(func_or_mod, tir.PrimFunc): @@ -291,28 +299,42 @@ def from_database( adapter.verbose = verbose adapter.libpath = kernel_lib_path adapter.kernel_global_source = device_kernel_source + adapter.rt_mod = None adapter.executable = runtime.load_module(kernel_lib_path) adapter._post_init() return adapter - def get_host_source(self): + def get_host_source(self) -> str | None: """Returns the source code of the host module.""" - if self.host_kernel_source is not None: - return self.host_kernel_source - return self.rt_mod.inspect_source() - - def get_device_source(self): + source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path") + if source is not None: + return source + rt_mod = getattr(self, "rt_mod", None) + if rt_mod is None: + return None + return rt_mod.inspect_source() + + def get_device_source(self) -> str | None: """Returns the source code of the device module.""" - if self.device_kernel_source is not None: - return self.device_kernel_source - return self.rt_mod.imports[0].inspect_source() + source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path") + if source is not None: + self.kernel_global_source = source + return source + rt_mod = getattr(self, "rt_mod", None) + if rt_mod is None: + return None + return rt_mod.imports[0].inspect_source() def get_kernel_source(self, kernel_only: bool = False): """Returns the source code of the compiled kernel.""" + device_source = self.get_device_source() or "" if kernel_only: - return self.get_device_source() - else: - return self.get_device_source() + "\n\n" + self.get_host_source() + return device_source + + host_source = self.get_host_source() or "" + if device_source and host_source: + return device_source + "\n\n" + host_source + return device_source or host_source @property def prim_func(self) -> tir.PrimFunc: diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index b63884e307..5914ba89ce 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -148,8 +148,8 @@ def __init__( def from_database( cls, func: PrimFunc, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, kernel_lib_path: str, params: list[KernelParam], target: str | Target, @@ -158,6 +158,8 @@ def from_database( execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, + host_kernel_source_path: str | None = None, + device_kernel_source_path: str | None = None, ): """ Alternative constructor to create a TorchFunction directly from a database. @@ -183,6 +185,8 @@ def from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, + host_kernel_source_path=host_kernel_source_path, + device_kernel_source_path=device_kernel_source_path, ) instance.torch_function = instance.adapter.func return instance @@ -340,11 +344,13 @@ def _create_adapter_from_database( result_idx: list[int] | int, target: str | Target, func_or_mod: PrimFunc | tvm.runtime.Module, - host_kernel_source: str, - device_kernel_source: str, + host_kernel_source: str | None, + device_kernel_source: str | None, kernel_lib_path: str, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, + host_kernel_source_path: str | None = None, + device_kernel_source_path: str | None = None, ) -> BaseKernelAdapter: target = self.target execution_backend = self.execution_backend @@ -361,6 +367,8 @@ def _create_adapter_from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, + host_kernel_source_path=host_kernel_source_path, + device_kernel_source_path=device_kernel_source_path, ) elif execution_backend == "cython": adapter = CythonKernelAdapter.from_database( @@ -372,6 +380,8 @@ def _create_adapter_from_database( device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, + host_kernel_source_path=host_kernel_source_path, + device_kernel_source_path=device_kernel_source_path, ) elif execution_backend == "nvrtc": from tilelang.jit.adapter import NVRTCKernelAdapter @@ -386,6 +396,8 @@ def _create_adapter_from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, + host_kernel_source_path=host_kernel_source_path, + device_kernel_source_path=device_kernel_source_path, ) elif execution_backend == "cutedsl": adapter = CuTeDSLKernelAdapter.from_database( @@ -398,6 +410,8 @@ def _create_adapter_from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, + host_kernel_source_path=host_kernel_source_path, + device_kernel_source_path=device_kernel_source_path, ) else: # Handle invalid backend. @@ -620,11 +634,21 @@ def params(self) -> list[KernelParam]: @property def kernel_source(self) -> str: - return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source + if self.artifact: + return self.artifact.kernel_source + source = getattr(self.adapter, "kernel_global_source", None) + if source is not None: + return source + return self.adapter.get_kernel_source(kernel_only=True) or "" @property def host_source(self) -> str: - return str(self.artifact.host_mod) if self.artifact else "" + if self.artifact: + return str(self.artifact.host_mod) + get_host_source = getattr(self.adapter, "get_host_source", None) + if get_host_source is None: + return "" + return get_host_source() or "" def export_library(self, kernel_file: str) -> None: """ From d279b94365d9a2dce46fdc78beae147906ab614e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 19 May 2026 16:56:12 +0800 Subject: [PATCH 2/3] Add tests for kernel cache frontend hit and improve memory cache handling - Introduced a new test to validate the round-trip functionality of real primitive functions in the kernel cache. - Enhanced the memory cache logic to prevent overwriting existing entries. - Improved error handling when tagging kernel cache entries with logging for better debugging. --- .../test_tilelang_kernel_cache_atomic_save.py | 42 +++++++++++++++++++ tilelang/cache/kernel_cache.py | 15 +++++-- tilelang/jit/__init__.py | 3 ++ 3 files changed, 57 insertions(+), 3 deletions(-) diff --git a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py index b15f1f73e3..b79cef387e 100644 --- a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py +++ b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py @@ -194,6 +194,48 @@ def fake_from_database(cls, **kwargs): assert captured["device_kernel_source_path"] == str(cache_path / cache.device_kernel_path) +def test_kernel_cache_frontend_hit_round_trips_real_prim_func(cache_dirs, tmp_path, monkeypatch): + import tilelang.language as T + import tvm + + @T.prim_func + def kernel(): + T.evaluate(0) + + cache = KernelCache() + key = "frontend-real-prim-func-key" + frontend_key = "frontend-real-prim-func" + cache_path = Path(cache._get_cache_path(key)) + + cache._save_kernel_to_disk(key, _make_fake_kernel(tmp_path), func=kernel) + cache.store_frontend_cache(frontend_key, key) + + sentinel = object() + captured = {} + + def fake_from_database(cls, **kwargs): + captured.update(kwargs) + return sentinel + + monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(fake_from_database)) + + loaded = cache.load_frontend_cached( + frontend_key, + target="cuda", + target_host=None, + out_idx=[0], + execution_backend="tvm_ffi", + pass_configs=None, + compile_flags=None, + ) + + assert loaded is sentinel + assert (cache_path / cache.prim_func_path).exists() + assert isinstance(captured["func"], tvm.tir.PrimFunc) + assert tvm.ir.structural_equal(captured["func"], kernel) + assert str(captured["func"].attrs["global_symbol"]) == str(kernel.attrs["global_symbol"]) + + def test_jit_frontend_cache_hit_skips_tir_elaboration(monkeypatch): import tilelang import tilelang.language as T diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 36bc3b4e09..cb626a62ba 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -377,7 +377,11 @@ def cached( # Store in memory cache after compilation self._tag_kernel_cache_entry(kernel, key, self._get_cache_path(key)) - self._memory_cache[key] = kernel + with self._lock: + existing = self._memory_cache.get(key) + if existing is not None: + return existing + self._memory_cache[key] = kernel return kernel def clear_cache(self): @@ -411,8 +415,13 @@ def _tag_kernel_cache_entry(kernel: JITKernel, key: str, cache_path: str) -> Non try: kernel._tilelang_cache_key = key kernel._tilelang_cache_path = cache_path - except Exception: - pass + except (AttributeError, TypeError): + logging.getLogger(__name__).debug( + "Could not tag kernel cache entry for key %s at %s", + key, + cache_path, + exc_info=True, + ) @staticmethod def _load_binary(path: str): diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 8ecec12f12..f4bdbcd5e6 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -477,6 +477,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: kernel = self._kernel_cache.get(key, None) if kernel is None: frontend_key_data = None + # Frontend cache is only safe when lazy-mode parse_args leaves no + # runtime kernel_args; then _frontend_cache_key_data fully identifies + # the compiled kernel, assuming compile-time values have stable reprs. if self.mode == "lazy" and not kernel_args: frontend_key_data = self._frontend_cache_key_data(key) from tilelang.cache import load_frontend_cached From f9a54ea832f7359abd2540e8d143e365ed6895ec Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 19 May 2026 17:07:28 +0800 Subject: [PATCH 3/3] [Cache] Refine lazy source restore Guard memory-cache writes after compile, make cache tagging failures diagnosable, and document the lazy frontend-cache gate. Wrap cached source text/path in CachedTextSource and add a real PrimFunc frontend-cache round-trip regression test. --- .../test_tilelang_kernel_cache_atomic_save.py | 13 ++++------- tilelang/autotuner/param.py | 5 ++-- tilelang/cache/kernel_cache.py | 15 ++++-------- tilelang/jit/adapter/__init__.py | 2 +- tilelang/jit/adapter/base.py | 11 +++++++++ tilelang/jit/adapter/cutedsl/adapter.py | 16 +++++-------- tilelang/jit/adapter/cython/adapter.py | 16 +++++-------- tilelang/jit/adapter/nvrtc/adapter.py | 16 +++++-------- tilelang/jit/adapter/tvm_ffi.py | 20 +++++++--------- tilelang/jit/kernel.py | 23 ++++--------------- 10 files changed, 56 insertions(+), 81 deletions(-) diff --git a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py index b79cef387e..0a6090281d 100644 --- a/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py +++ b/testing/python/cache/test_tilelang_kernel_cache_atomic_save.py @@ -8,6 +8,7 @@ import tilelang.cache.kernel_cache as kernel_cache_mod from tilelang.cache.kernel_cache import KernelCache from tilelang.env import env +from tilelang.jit.adapter.base import CachedTextSource from tilelang.jit.adapter.nvrtc.kernel_cache import NVRTCKernelCache @@ -100,10 +101,8 @@ def fake_from_database(cls, **kwargs): ) assert loaded is sentinel - assert captured["host_kernel_source"] is None - assert captured["device_kernel_source"] is None - assert captured["host_kernel_source_path"] == str(cache_path / cache.host_kernel_path) - assert captured["device_kernel_source_path"] == str(cache_path / cache.device_kernel_path) + assert captured["host_kernel_source"] == CachedTextSource(path=str(cache_path / cache.host_kernel_path)) + assert captured["device_kernel_source"] == CachedTextSource(path=str(cache_path / cache.device_kernel_path)) assert captured["kernel_lib_path"] == str(cache_path / cache.kernel_lib_path) assert captured["params"] == ["param"] @@ -188,10 +187,8 @@ def fake_from_database(cls, **kwargs): assert loaded is sentinel assert captured["func"] == prim_func - assert captured["host_kernel_source"] is None - assert captured["device_kernel_source"] is None - assert captured["host_kernel_source_path"] == str(cache_path / cache.host_kernel_path) - assert captured["device_kernel_source_path"] == str(cache_path / cache.device_kernel_path) + assert captured["host_kernel_source"] == CachedTextSource(path=str(cache_path / cache.host_kernel_path)) + assert captured["device_kernel_source"] == CachedTextSource(path=str(cache_path / cache.device_kernel_path)) def test_kernel_cache_frontend_hit_round_trips_real_prim_func(cache_dirs, tmp_path, monkeypatch): diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py index ad741b4f9b..166e7b4b3f 100644 --- a/tilelang/autotuner/param.py +++ b/tilelang/autotuner/param.py @@ -12,6 +12,7 @@ import errno from tilelang.jit import JITKernel +from tilelang.jit.adapter.base import CachedTextSource import cloudpickle import os import shutil @@ -357,8 +358,8 @@ def _load_kernel_from_disk( if host_kernel_source and device_kernel_source and kernel_params: return JITKernel.from_database( func=func, - host_kernel_source=host_kernel_source, - device_kernel_source=device_kernel_source, + host_kernel_source=CachedTextSource(text=host_kernel_source), + device_kernel_source=CachedTextSource(text=device_kernel_source), kernel_lib_path=kernel_lib_path, params=kernel_params, target=target, diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index cb626a62ba..51b62d5309 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -22,6 +22,7 @@ from tilelang.utils.language import get_prim_func_name from tilelang import env from tilelang.jit import JITKernel +from tilelang.jit.adapter.base import CachedTextSource from tilelang import __version__ import platform @@ -579,10 +580,8 @@ def _load_kernel_from_disk( kernel = self._build_kernel( func=func, - host_kernel_source=None, - device_kernel_source=None, - host_kernel_path=host_kernel_path, - device_kernel_path=device_kernel_path, + host_kernel_source=CachedTextSource(path=host_kernel_path), + device_kernel_source=CachedTextSource(path=device_kernel_path), kernel_lib_path=kernel_lib_path, kernel_params=kernel_params, target=target, @@ -775,10 +774,8 @@ def _set_adapter_cache_path(self, kernel: JITKernel, cache_path: str): def _build_kernel( self, func: Callable | None, - host_kernel_source: str | None, - device_kernel_source: str | None, - host_kernel_path: str | None, - device_kernel_path: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, kernel_params: list[KernelParam] | None, target: str | Target, @@ -809,6 +806,4 @@ def _build_kernel( execution_backend=execution_backend, pass_configs=pass_configs, compile_flags=compile_flags, - host_kernel_source_path=host_kernel_path, - device_kernel_source_path=device_kernel_path, ) diff --git a/tilelang/jit/adapter/__init__.py b/tilelang/jit/adapter/__init__.py index 0d99452855..eb48af408d 100644 --- a/tilelang/jit/adapter/__init__.py +++ b/tilelang/jit/adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseKernelAdapter # noqa: F401 +from .base import BaseKernelAdapter, CachedTextSource # noqa: F401 from .tvm_ffi import TVMFFIKernelAdapter # noqa: F401 from .cython import CythonKernelAdapter # noqa: F401 from .nvrtc import NVRTCKernelAdapter # noqa: F401 diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 3da55f4d8c..d65f6bf51d 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -3,11 +3,18 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, Callable from tilelang.engine.param import KernelParam import torch +@dataclass(frozen=True) +class CachedTextSource: + text: str | None = None + path: str | None = None + + class BaseKernelAdapter(ABC): func: Callable | None = None @@ -95,6 +102,10 @@ def get_kernel_source(self, kernel_only: bool = True) -> str: def _post_init(self): self.func = self._convert_torch_func() + def _set_cached_text_source(self, source_attr: str, path_attr: str, source: CachedTextSource) -> None: + setattr(self, source_attr, source.text) + setattr(self, path_attr, source.path) + def _load_cached_text_source(self, source_attr: str, path_attr: str) -> str | None: source = getattr(self, source_attr, None) if source is not None: diff --git a/tilelang/jit/adapter/cutedsl/adapter.py b/tilelang/jit/adapter/cutedsl/adapter.py index 87dc57c7f4..7fe69d5eaf 100644 --- a/tilelang/jit/adapter/cutedsl/adapter.py +++ b/tilelang/jit/adapter/cutedsl/adapter.py @@ -14,7 +14,7 @@ from tilelang.jit.adapter.cutedsl.libgen import CuTeDSLLibraryGenerator from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.target import determine_target -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource logger = logging.getLogger(__name__) @@ -104,23 +104,19 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str | None, - device_kernel_source: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, - host_kernel_source_path: str | None = None, - device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source - adapter.host_func = host_kernel_source - adapter._host_kernel_source_path = host_kernel_source_path - adapter._device_kernel_source_path = device_kernel_source_path + adapter._set_cached_text_source("host_kernel_source", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) + adapter.host_func = host_kernel_source.text if isinstance(func_or_mod, tir.PrimFunc): gsym = func_or_mod.attrs.get("global_symbol") diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index cfaed86db3..64ecde93cf 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -12,7 +12,7 @@ from tvm import tir from tvm.relax import TensorType -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource from tilelang.jit.adapter.wrapper import TLWrapper from tilelang.jit.adapter.libgen import LibraryGenerator from tilelang.jit.adapter.utils import is_cuda_target, is_hip_target, is_cpu_target, is_metal_target @@ -155,23 +155,19 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str | None, - device_kernel_source: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, - host_kernel_source_path: str | None = None, - device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source - adapter._host_kernel_source_path = host_kernel_source_path - adapter._device_kernel_source_path = device_kernel_source_path - adapter.kernel_global_source = device_kernel_source # Set alias for compatibility + adapter._set_cached_text_source("host_kernel_source", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) + adapter.kernel_global_source = device_kernel_source.text # Set alias for compatibility adapter.pass_configs = pass_configs if isinstance(func_or_mod, tir.PrimFunc): diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index c31c22f964..f420117b05 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -11,7 +11,7 @@ from tilelang.jit.adapter.wrapper import TLPyWrapper from tilelang.utils.language import retrieve_func_from_module from tilelang.utils.target import determine_target -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource from tilelang.jit.adapter.nvrtc import is_nvrtc_available, check_nvrtc_available from .libgen import NVRTCLibraryGenerator @@ -102,23 +102,19 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str | None, - device_kernel_source: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, - host_kernel_source_path: str | None = None, - device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source - adapter.host_func = host_kernel_source - adapter._host_kernel_source_path = host_kernel_source_path - adapter._device_kernel_source_path = device_kernel_source_path + adapter.host_kernel_source = host_kernel_source.text + adapter._set_cached_text_source("host_func", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) if isinstance(func_or_mod, tir.PrimFunc): adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod}) diff --git a/tilelang/jit/adapter/tvm_ffi.py b/tilelang/jit/adapter/tvm_ffi.py index 7795d40969..6877c0d1cf 100644 --- a/tilelang/jit/adapter/tvm_ffi.py +++ b/tilelang/jit/adapter/tvm_ffi.py @@ -17,7 +17,7 @@ from tvm.target import Target from tvm.relax import TensorType from tilelang.utils.target import determine_target -from tilelang.jit.adapter.base import BaseKernelAdapter +from tilelang.jit.adapter.base import BaseKernelAdapter, CachedTextSource from tilelang.utils.language import retrieve_func_from_module from tilelang.engine.param import KernelParam from tilelang.language.dtypes import dtype @@ -265,25 +265,21 @@ def from_database( result_idx: list[int], target: str, func_or_mod: tir.PrimFunc | tvm.IRModule, - host_kernel_source: str | None, - device_kernel_source: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, verbose: bool = False, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, - host_kernel_source_path: str | None = None, - device_kernel_source_path: str | None = None, ): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) - adapter.host_kernel_source = host_kernel_source - adapter.device_kernel_source = device_kernel_source - adapter._host_kernel_source_path = host_kernel_source_path - adapter._device_kernel_source_path = device_kernel_source_path + adapter._set_cached_text_source("host_kernel_source", "_host_kernel_source_path", host_kernel_source) + adapter._set_cached_text_source("device_kernel_source", "_device_kernel_source_path", device_kernel_source) adapter.wrapped_source = ( - device_kernel_source + "\n\n" + host_kernel_source - if device_kernel_source is not None and host_kernel_source is not None + device_kernel_source.text + "\n\n" + host_kernel_source.text + if device_kernel_source.text is not None and host_kernel_source.text is not None else None ) adapter.pass_configs = pass_configs @@ -298,7 +294,7 @@ def from_database( adapter.verbose = verbose adapter.libpath = kernel_lib_path - adapter.kernel_global_source = device_kernel_source + adapter.kernel_global_source = device_kernel_source.text adapter.rt_mod = None adapter.executable = runtime.load_module(kernel_lib_path) adapter._post_init() diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 5914ba89ce..a6744f159e 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -17,6 +17,7 @@ from tilelang.engine.param import CompiledArtifact, KernelParam from tilelang.jit.adapter import ( BaseKernelAdapter, + CachedTextSource, CythonKernelAdapter, CuTeDSLKernelAdapter, TVMFFIKernelAdapter, @@ -148,8 +149,8 @@ def __init__( def from_database( cls, func: PrimFunc, - host_kernel_source: str | None, - device_kernel_source: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, params: list[KernelParam], target: str | Target, @@ -158,8 +159,6 @@ def from_database( execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, - host_kernel_source_path: str | None = None, - device_kernel_source_path: str | None = None, ): """ Alternative constructor to create a TorchFunction directly from a database. @@ -185,8 +184,6 @@ def from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, - host_kernel_source_path=host_kernel_source_path, - device_kernel_source_path=device_kernel_source_path, ) instance.torch_function = instance.adapter.func return instance @@ -344,13 +341,11 @@ def _create_adapter_from_database( result_idx: list[int] | int, target: str | Target, func_or_mod: PrimFunc | tvm.runtime.Module, - host_kernel_source: str | None, - device_kernel_source: str | None, + host_kernel_source: CachedTextSource, + device_kernel_source: CachedTextSource, kernel_lib_path: str, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | None = None, - host_kernel_source_path: str | None = None, - device_kernel_source_path: str | None = None, ) -> BaseKernelAdapter: target = self.target execution_backend = self.execution_backend @@ -367,8 +362,6 @@ def _create_adapter_from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, - host_kernel_source_path=host_kernel_source_path, - device_kernel_source_path=device_kernel_source_path, ) elif execution_backend == "cython": adapter = CythonKernelAdapter.from_database( @@ -380,8 +373,6 @@ def _create_adapter_from_database( device_kernel_source=device_kernel_source, kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, - host_kernel_source_path=host_kernel_source_path, - device_kernel_source_path=device_kernel_source_path, ) elif execution_backend == "nvrtc": from tilelang.jit.adapter import NVRTCKernelAdapter @@ -396,8 +387,6 @@ def _create_adapter_from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, - host_kernel_source_path=host_kernel_source_path, - device_kernel_source_path=device_kernel_source_path, ) elif execution_backend == "cutedsl": adapter = CuTeDSLKernelAdapter.from_database( @@ -410,8 +399,6 @@ def _create_adapter_from_database( kernel_lib_path=kernel_lib_path, pass_configs=pass_configs, compile_flags=compile_flags, - host_kernel_source_path=host_kernel_source_path, - device_kernel_source_path=device_kernel_source_path, ) else: # Handle invalid backend.