Skip to content

Commit d9f9657

Browse files
committed
Optimize disk cache source loading
1 parent c2cabd9 commit d9f9657

8 files changed

Lines changed: 230 additions & 48 deletions

File tree

testing/python/cache/test_tilelang_kernel_cache_atomic_save.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import builtins
12
import errno
23
from pathlib import Path
4+
5+
import cloudpickle
36
import pytest
47

8+
import tilelang.cache.kernel_cache as kernel_cache_mod
59
from tilelang.cache.kernel_cache import KernelCache
610
from tilelang.env import env
711
from tilelang.jit.adapter.nvrtc.kernel_cache import NVRTCKernelCache
@@ -46,6 +50,112 @@ def _make_fake_nvrtc_kernel(tmp_path):
4650
return _FakeKernel(str(lib_path))
4751

4852

53+
def _write_complete_kernel_cache_entry(
    cache: KernelCache,
    key: str,
    device_source: str = "// device kernel",
    host_source: str = "// host kernel",
) -> Path:
    """Create an on-disk cache entry for ``key`` containing every expected file.

    Writes the device source, host source, a placeholder library blob and a
    pickled parameter list into the directory that ``cache`` maps ``key`` to.

    Args:
        cache: Cache object providing ``_get_cache_path`` and the file names.
        key: Cache key whose directory is created.
        device_source: Text written to the device kernel source file.
        host_source: Text written to the host kernel source file.

    Returns:
        The directory that now holds the cache entry.
    """
    entry_dir = Path(cache._get_cache_path(key))
    # No ``exist_ok``: the helper expects to create a fresh entry.
    entry_dir.mkdir(parents=True)

    text_files = (
        (cache.device_kernel_path, device_source),
        (cache.host_kernel_path, host_source),
    )
    for file_name, text in text_files:
        (entry_dir / file_name).write_text(text)

    # The library only has to exist; its content is never loaded by these tests.
    (entry_dir / cache.kernel_lib_path).write_bytes(b"fake-so")

    with (entry_dir / cache.params_path).open("wb") as params_file:
        cloudpickle.dump(["param"], params_file)

    return entry_dir
67+
68+
69+
def test_kernel_cache_disk_hit_defers_source_loading(cache_dirs, monkeypatch):
    """A disk-cache hit forwards source *paths* and leaves source text unset."""
    cache = KernelCache()
    key = "lazy-source-load"
    entry_dir = _write_complete_kernel_cache_entry(cache, key)

    marker = object()
    seen_kwargs = {}

    def _forbidden_source_load(*args, **kwargs):
        # Any eager source read on the hit path is a regression.
        raise AssertionError("disk cache hit should pass source paths through for lazy loading")

    def _recording_from_database(cls, **kwargs):
        # Stand-in for JITKernel.from_database: remember what it was given.
        seen_kwargs.update(kwargs)
        return marker

    monkeypatch.setattr(cache, "_load_kernel_source", _forbidden_source_load)
    monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(_recording_from_database))

    result = cache._load_kernel_from_disk(
        key,
        target="cuda",
        target_host=None,
        out_idx=[0],
        execution_backend="tvm_ffi",
        pass_configs=None,
        compile_flags=None,
        func=None,
    )

    assert result is marker

    # Source text must not have been materialised.
    for source_kwarg in ("host_kernel_source", "device_kernel_source"):
        assert seen_kwargs[source_kwarg] is None

    # The file locations are passed along instead.
    expected_paths = {
        "host_kernel_source_path": cache.host_kernel_path,
        "device_kernel_source_path": cache.device_kernel_path,
        "kernel_lib_path": cache.kernel_lib_path,
    }
    for kwarg_name, file_name in expected_paths.items():
        assert seen_kwargs[kwarg_name] == str(entry_dir / file_name)

    assert seen_kwargs["params"] == ["param"]
105+
106+
107+
def test_kernel_cache_disk_hit_perf_skips_large_source_file_reads(cache_dirs, monkeypatch):
    """A disk-cache hit must not open the (large) kernel source files for reading."""
    cache = KernelCache()
    key = "lazy-source-load-perf"
    big_text = "// source\n" + ("x" * (2 * 1024 * 1024))
    entry_dir = _write_complete_kernel_cache_entry(
        cache,
        key,
        device_source=big_text,
        host_source=big_text,
    )
    guarded_paths = {
        (entry_dir / file_name).resolve()
        for file_name in (cache.device_kernel_path, cache.host_kernel_path)
    }
    source_read_count = 0
    marker = object()

    real_open = builtins.open

    # NOTE(review): only ``builtins.open`` is intercepted; reads issued through
    # ``io.open`` / ``Path.read_text`` would presumably bypass this guard — confirm
    # how the cache reads source files.
    def _guarded_open(file, *args, **kwargs):
        nonlocal source_read_count
        mode = args[0] if args else kwargs.get("mode", "r")
        try:
            resolved = Path(file).resolve()
        except TypeError:
            # Not path-like (e.g. a file descriptor): nothing to guard.
            resolved = None
        if resolved is not None and "r" in mode and resolved in guarded_paths:
            # Count first so the final assertion fails even if the raise is swallowed.
            source_read_count += 1
            raise AssertionError("cache perf regression: source file read during disk cache hit")
        return real_open(file, *args, **kwargs)

    def _stub_from_database(cls, **kwargs):
        return marker

    monkeypatch.setattr(builtins, "open", _guarded_open)
    monkeypatch.setattr(kernel_cache_mod.JITKernel, "from_database", classmethod(_stub_from_database))

    result = cache._load_kernel_from_disk(
        key,
        target="cuda",
        target_host=None,
        out_idx=[0],
        execution_backend="tvm_ffi",
        pass_configs=None,
        compile_flags=None,
        func=None,
    )

    assert result is marker
    assert source_read_count == 0
157+
158+
49159
def test_kernel_cache_rewrites_incomplete_cache_dir(cache_dirs, tmp_path):
50160
cache = KernelCache()
51161
key = "atomic-repair"

tilelang/cache/kernel_cache.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -330,19 +330,23 @@ def cached(
330330
)
331331
return self._memory_cache[key]
332332

333-
if verbose:
334-
self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '<unknown>')}")
333+
if verbose:
334+
self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '<unknown>')}")
335335

336-
# Then check disk cache
337-
kernel = self._load_kernel_from_disk(
338-
key, target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose
339-
)
340-
if kernel is not None:
341-
if verbose:
342-
self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '<unknown>')}")
343-
# Populate memory cache with disk result
336+
# Disk loads can be expensive for large kernel sets; keep them outside
337+
# the global cache lock so independent cache hits can proceed in parallel.
338+
kernel = self._load_kernel_from_disk(
339+
key, target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose
340+
)
341+
if kernel is not None:
342+
if verbose:
343+
self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '<unknown>')}")
344+
with self._lock:
345+
existing = self._memory_cache.get(key)
346+
if existing is not None:
347+
return existing
344348
self._memory_cache[key] = kernel
345-
return kernel
349+
return kernel
346350

347351
if verbose:
348352
self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '<unknown>')}")
@@ -518,9 +522,6 @@ def _load_kernel_from_disk(
518522
if not all([os.path.exists(file) for file in required_files]):
519523
return None
520524

521-
# Load the kernel source file (optional)
522-
device_kernel_source, host_kernel_source = self._load_kernel_source(device_kernel_path, host_kernel_path, verbose)
523-
524525
# Load kernel parameters
525526
kernel_params: list[KernelParam] | None = None
526527
try:
@@ -533,8 +534,10 @@ def _load_kernel_from_disk(
533534

534535
return self._build_kernel(
535536
func=func,
536-
host_kernel_source=host_kernel_source,
537-
device_kernel_source=device_kernel_source,
537+
host_kernel_source=None,
538+
device_kernel_source=None,
539+
host_kernel_path=host_kernel_path,
540+
device_kernel_path=device_kernel_path,
538541
kernel_lib_path=kernel_lib_path,
539542
kernel_params=kernel_params,
540543
target=target,
@@ -638,8 +641,10 @@ def _set_adapter_cache_path(self, kernel: JITKernel, cache_path: str):
638641
def _build_kernel(
639642
self,
640643
func: Callable | None,
641-
host_kernel_source: str,
642-
device_kernel_source: str,
644+
host_kernel_source: str | None,
645+
device_kernel_source: str | None,
646+
host_kernel_path: str | None,
647+
device_kernel_path: str | None,
643648
kernel_lib_path: str,
644649
kernel_params: list[KernelParam] | None,
645650
target: str | Target,
@@ -651,10 +656,6 @@ def _build_kernel(
651656
) -> JITKernel | None:
652657
# Check all required components and report specific failures
653658
missing_components = []
654-
if not host_kernel_source:
655-
missing_components.append("host_kernel_source")
656-
if not device_kernel_source:
657-
missing_components.append("device_kernel_source")
658659
if not kernel_params:
659660
missing_components.append("kernel_params")
660661

@@ -674,4 +675,6 @@ def _build_kernel(
674675
execution_backend=execution_backend,
675676
pass_configs=pass_configs,
676677
compile_flags=compile_flags,
678+
host_kernel_source_path=host_kernel_path,
679+
device_kernel_source_path=device_kernel_path,
677680
)

tilelang/jit/adapter/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,17 @@ def get_kernel_source(self, kernel_only: bool = True) -> str:
9494

9595
def _post_init(self):
9696
self.func = self._convert_torch_func()
97+
98+
def _load_cached_text_source(self, source_attr: str, path_attr: str) -> str | None:
99+
source = getattr(self, source_attr, None)
100+
if source is not None:
101+
return source
102+
103+
path = getattr(self, path_attr, None)
104+
if path is None:
105+
return None
106+
107+
with open(path, encoding="utf-8") as file:
108+
source = file.read()
109+
setattr(self, source_attr, source)
110+
return source

tilelang/jit/adapter/cutedsl/adapter.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,22 @@ def from_database(
104104
result_idx: list[int],
105105
target: str,
106106
func_or_mod: tir.PrimFunc | tvm.IRModule,
107-
host_kernel_source: str,
108-
device_kernel_source: str,
107+
host_kernel_source: str | None,
108+
device_kernel_source: str | None,
109109
kernel_lib_path: str,
110110
verbose: bool = False,
111111
pass_configs: dict[str, Any] | None = None,
112112
compile_flags: list[str] | None = None,
113+
host_kernel_source_path: str | None = None,
114+
device_kernel_source_path: str | None = None,
113115
):
114116
adapter = cls.__new__(cls)
115117
adapter.params = params
116118
adapter.result_idx = adapter._legalize_result_idx(result_idx)
117119
adapter.host_kernel_source = host_kernel_source
118120
adapter.device_kernel_source = device_kernel_source
121+
adapter._host_kernel_source_path = host_kernel_source_path
122+
adapter._device_kernel_source_path = device_kernel_source_path
119123

120124
if isinstance(func_or_mod, tir.PrimFunc):
121125
gsym = func_or_mod.attrs.get("global_symbol")
@@ -228,7 +232,10 @@ def get_kernel_source(self, kernel_only: bool = True) -> str | None:
228232
str | None
229233
The kernel source code, or None if not available
230234
"""
231-
return self.device_kernel_source
235+
source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path")
236+
if source is not None:
237+
self.kernel_global_source = source
238+
return source
232239

233240
def _forward_from_prebuild_lib(self, *args, stream: int | None = None, device_id: int = 0):
234241
"""Low-level function to call the compiled CUDA kernel.

tilelang/jit/adapter/cython/adapter.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,18 +155,22 @@ def from_database(
155155
result_idx: list[int],
156156
target: str,
157157
func_or_mod: tir.PrimFunc | tvm.IRModule,
158-
host_kernel_source: str,
159-
device_kernel_source: str,
158+
host_kernel_source: str | None,
159+
device_kernel_source: str | None,
160160
kernel_lib_path: str,
161161
verbose: bool = False,
162162
pass_configs: dict[str, Any] | None = None,
163163
compile_flags: list[str] | None = None,
164+
host_kernel_source_path: str | None = None,
165+
device_kernel_source_path: str | None = None,
164166
):
165167
adapter = cls.__new__(cls)
166168
adapter.params = params
167169
adapter.result_idx = adapter._legalize_result_idx(result_idx)
168170
adapter.host_kernel_source = host_kernel_source
169171
adapter.device_kernel_source = device_kernel_source
172+
adapter._host_kernel_source_path = host_kernel_source_path
173+
adapter._device_kernel_source_path = device_kernel_source_path
170174
adapter.kernel_global_source = device_kernel_source # Set alias for compatibility
171175
adapter.pass_configs = pass_configs
172176

@@ -386,12 +390,16 @@ def is_dynamic(self):
386390
def get_kernel_source(self, kernel_only: bool = False):
387391
"""Returns the source code of the compiled kernel."""
388392
if kernel_only:
389-
return self.device_kernel_source
393+
source = self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path")
394+
if source is not None:
395+
self.kernel_global_source = source
396+
return source
390397
else:
391398
# Wrapper only has host kernel source
392-
assert self.host_kernel_source is not None, "Wrapped source is not available"
393-
return self.host_kernel_source
399+
source = self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path")
400+
assert source is not None, "Wrapped source is not available"
401+
return source
394402

395403
def get_host_source(self):
396404
"""Returns the source code of the host function."""
397-
return self.host_kernel_source
405+
return self._load_cached_text_source("host_kernel_source", "_host_kernel_source_path")

tilelang/jit/adapter/nvrtc/adapter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,23 @@ def from_database(
102102
result_idx: list[int],
103103
target: str,
104104
func_or_mod: tir.PrimFunc | tvm.IRModule,
105-
host_kernel_source: str,
106-
device_kernel_source: str,
105+
host_kernel_source: str | None,
106+
device_kernel_source: str | None,
107107
kernel_lib_path: str,
108108
verbose: bool = False,
109109
pass_configs: dict[str, Any] | None = None,
110110
compile_flags: list[str] | None = None,
111+
host_kernel_source_path: str | None = None,
112+
device_kernel_source_path: str | None = None,
111113
):
112114
adapter = cls.__new__(cls)
113115
adapter.params = params
114116
adapter.result_idx = adapter._legalize_result_idx(result_idx)
115117
adapter.host_kernel_source = host_kernel_source
116118
adapter.device_kernel_source = device_kernel_source
119+
adapter.host_func = host_kernel_source
120+
adapter._host_kernel_source_path = host_kernel_source_path
121+
adapter._device_kernel_source_path = device_kernel_source_path
117122

118123
if isinstance(func_or_mod, tir.PrimFunc):
119124
adapter.ir_module = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})
@@ -201,9 +206,9 @@ def get_kernel_source(self, kernel_only: bool = True) -> str | None:
201206
The kernel source code, or None if not available
202207
"""
203208
if kernel_only:
204-
return self.device_kernel_source
209+
return self._load_cached_text_source("device_kernel_source", "_device_kernel_source_path")
205210
else:
206-
return self.host_func
211+
return self._load_cached_text_source("host_func", "_host_kernel_source_path")
207212

208213
def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
209214
"""Low-level function to call the compiled CUDA kernel."""

0 commit comments

Comments
 (0)