Skip to content

Commit 8914d30

Browse files
authored
Merge branch 'JuliaGPU:main' into IntegerReduce
2 parents 48ea44b + ac2a860 commit 8914d30

3 files changed

Lines changed: 42 additions & 3 deletions

File tree

ext/CUDAExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using CUDA_Compiler_jll
99
public launch
1010

1111
# Compilation cache - stores CuFunction directly to avoid re-loading CuModule
12-
const _compilation_cache = Dict{Any, Any}() # (f, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction
12+
const _compilation_cache = Dict{Any, Any}() # (method, argtypes, sm_arch, opt_level, num_ctas, occupancy) => CuFunction
1313

1414
"""
1515
launch(f, grid, args...; name=nothing, sm_arch=default_sm_arch(), opt_level=3, num_ctas=nothing, occupancy=nothing)
@@ -65,8 +65,11 @@ function cuTile.launch(@nospecialize(f), grid, args...;
6565
# Determine kernel name
6666
kernel_name = name !== nothing ? name : string(nameof(f))
6767

68+
# Key the cache on the resolved Method (not the function) so a redefinition misses the cache
69+
method = which(f, argtypes)
70+
6871
# Check compilation cache - returns CuFunction directly
69-
cache_key = (f, argtypes, sm_arch, opt_level, num_ctas, occupancy)
72+
cache_key = (method, argtypes, sm_arch, opt_level, num_ctas, occupancy)
7073
cufunc = get(_compilation_cache, cache_key, nothing)
7174
if cufunc === nothing || cuTile.compile_hook[] !== nothing
7275
cubin = compile(f, argtypes; name, sm_arch, opt_level, num_ctas, occupancy)

src/bytecode/types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ end
179179
function julia_to_tile_dtype!(table::TypeTable, ::Type{T}) where T
180180
if T === Bool
181181
I1(table)
182-
elseif T === Int8
182+
elseif T === Int8 || T === UInt8
183183
I8(table)
184184
elseif T === Int16 || T === UInt16
185185
I16(table)

test/execution.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,3 +2100,39 @@ end
21002100
end
21012101

21022102
end
2103+
2104+
@testset "redefine kernel method" begin
2105+
mod = @eval module $(gensym())
2106+
import cuTile as ct
2107+
function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1})
2108+
pid = ct.bid(1)
2109+
ta = ct.load(a, (pid,), (16,))
2110+
tb = ct.load(b, (pid,), (16,))
2111+
ct.store(c, (pid,), ta + tb)
2112+
return
2113+
end
2114+
end
2115+
2116+
a = CUDA.ones(Float32, 1024)
2117+
b = CUDA.ones(Float32, 1024)
2118+
c = CUDA.zeros(Float32, 1024)
2119+
2120+
ct.launch(mod.vadd_kernel, 64, a, b, c)
2121+
2122+
@test Array(c) ≈ Array(a) + Array(b)
2123+
2124+
@eval mod begin
2125+
function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, c::ct.TileArray{Float32,1})
2126+
pid = ct.bid(1)
2127+
ta = ct.load(a, (pid,), (16,))
2128+
tb = ct.load(b, (pid,), (16,))
2129+
ct.store(c, (pid,), ta + tb * 2)
2130+
return
2131+
end
2132+
end
2133+
2134+
ct.launch(mod.vadd_kernel, 64, a, b, c)
2135+
2136+
@test Array(c) ≈ Array(a) + Array(b) * 2
2137+
end
2138+

0 commit comments

Comments
 (0)