Skip to content

Commit e6a49dc

Browse files
committed
add back argtypes
1 parent 10ee857 commit e6a49dc

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

ext/CUDAExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using CUDA_Compiler_jll
99
public launch
1010

1111
# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
12-
const _compilation_cache = Dict{Any, Any}() # (method, sm_arch, opt_level) => CuFunction
12+
const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level) => CuFunction
1313

1414
"""
1515
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3)
@@ -65,7 +65,7 @@ function cuTile.launch(@nospecialize(f), grid, args...;
6565
method = which(f, argtypes)
6666

6767
# Check compilation cache - returns CuFunction directly
68-
cache_key = (method, sm_arch, opt_level)
68+
cache_key = (method, argtypes, sm_arch, opt_level)
6969
cufunc = get(_compilation_cache, cache_key, nothing)
7070
if cufunc === nothing || cuTile.compile_hook[] !== nothing
7171
cubin = compile(f, argtypes; name, sm_arch, opt_level)

0 commit comments

Comments
 (0)