Skip to content

Commit 552bc6c

Browse files
committed
Allow redefinition of kernel methods
1 parent 387d870 commit 552bc6c

2 files changed

Lines changed: 40 additions & 2 deletions

File tree

ext/CUDAExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using CUDA_Compiler_jll
99
public launch
1010

1111
# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
12-
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction
12+
const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction
1313

1414
"""
1515
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing)
@@ -65,8 +65,11 @@ function cuTile.launch(@nospecialize(f), grid, args...;
6565
# Determine kernel name
6666
kernel_name = name !== nothing ? name : string(nameof(f))
6767

68+
# Use method instance in case of a redefinition
69+
method = which(f, argtypes)
70+
6871
# Check compilation cache - returns CuFunction directly
69-
cache_key = (f, argtypes, sm_arch, opt_level, num_ctas, occupancy)
72+
cache_key = (method, argtypes, sm_arch, opt_level, num_ctas, occupancy)
7073
cufunc = get(_compilation_cache, cache_key, nothing)
7174
if cufunc === nothing || cuTile.compile_hook[] !== nothing
7275
cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy)

test/execution.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,6 +1590,41 @@ end
15901590

15911591
end
15921592

1593+
@testset "redefine kernel method" begin
1594+
mod = @eval module $(gensym())
1595+
import cuTile as ct
1596+
function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1})
1597+
pid = ct.bid(1)
1598+
ta = ct.load(a, (pid,), (16,))
1599+
tb = ct.load(b, (pid,), (16,))
1600+
ct.store(c, (pid,), ta + tb)
1601+
return
1602+
end
1603+
end
1604+
1605+
a = CUDA.ones(Float32, 1024)
1606+
b = CUDA.ones(Float32, 1024)
1607+
c = CUDA.zeros(Float32, 1024)
1608+
1609+
ct.launch(mod.vadd_kernel, 64, a, b, c)
1610+
1611+
@test Array(c) Array(a) + Array(b)
1612+
1613+
@eval mod begin
1614+
function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1})
1615+
pid = ct.bid(1)
1616+
ta = ct.load(a, (pid,), (16,))
1617+
tb = ct.load(b, (pid,), (16,))
1618+
ct.store(c, (pid,), ta + tb * 2)
1619+
return
1620+
end
1621+
end
1622+
1623+
ct.launch(mod.vadd_kernel, 64, a, b, c)
1624+
1625+
@test Array(c) Array(a) + Array(b) * 2
1626+
end
1627+
15931628
@testset "Entry Hints Integration" begin
15941629

15951630
@testset "launch with num_ctas" begin

0 commit comments

Comments
 (0)