Skip to content

Commit 10ee857

Browse files
committed
Allow redefinition of kernel methods
1 parent b419422 commit 10ee857

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

ext/CUDAExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using CUDA_Compiler_jll
99
public launch
1010

1111
# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
12-
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level) => CuFunction
12+
const _compilation_cache = Dict{Any, Any}() # (method, sm_arch, opt_level) => CuFunction
1313

1414
"""
1515
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3)
@@ -61,8 +61,11 @@ function cuTile.launch(@nospecialize(f), grid, args...;
6161
# Determine kernel name
6262
kernel_name = name !== nothing ? name : string(nameof(f))
6363

64+
# Key the cache on the Method resolved by `which`, so a redefined kernel gets a fresh entry
65+
method = which(f, argtypes)
66+
6467
# Check compilation cache - returns CuFunction directly
65-
cache_key = (f, argtypes, sm_arch, opt_level)
68+
cache_key = (method, sm_arch, opt_level)
6669
cufunc = get(_compilation_cache, cache_key, nothing)
6770
if cufunc === nothing || cuTile.compile_hook[] !== nothing
6871
cubin = compile(f, argtypes; name, sm_arch, opt_level)

test/execution.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,3 +1589,33 @@ end
15891589
end
15901590

15911591
end
1592+
1593+
@testset "redefinition of kernel" begin
1594+
function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1})
1595+
pid = ct.bid(1)
1596+
ta = ct.load(a, (pid,), (16,))
1597+
tb = ct.load(b, (pid,), (16,))
1598+
ct.store(c, (pid,), ta + tb)
1599+
return
1600+
end
1601+
1602+
a = CUDA.ones(Float32, 1024)
1603+
b = CUDA.ones(Float32, 1024)
1604+
c = CUDA.zeros(Float32, 1024)
1605+
1606+
ct.launch(vadd_kernel, 64, a, b, c)
1607+
1608+
@test Array(c) ≈ Array(a) + Array(b)
1609+
1610+
function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1})
1611+
pid = ct.bid(1)
1612+
ta = ct.load(a, (pid,), (16,))
1613+
tb = ct.load(b, (pid,), (16,))
1614+
ct.store(c, (pid,), ta + tb * 2)
1615+
return
1616+
end
1617+
1618+
ct.launch(vadd_kernel, 64, a, b, c)
1619+
1620+
@test Array(c) ≈ Array(a) + Array(b) * 2
1621+
end

0 commit comments

Comments
 (0)