Skip to content

Commit 8a6350c

Browse files
authored
Generic support for ghost types in launch (#93)
1 parent 0f51b81 commit 8a6350c

3 files changed

Lines changed: 71 additions & 3 deletions

File tree

ext/CUDAExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module CUDAExt
22

33
using cuTile
44
using cuTile: TileArray, Constant, CGOpts, CuTileResults, emit_code, sanitize_name,
5-
constant_eltype, constant_value
5+
constant_eltype, constant_value, is_ghost_type
66

77
using CompilerCaching: CacheView, method_instance, results
88

@@ -187,9 +187,8 @@ return their fields in order.
187187
188188
This is used by the launch helper to splat arguments to cudacall.
189189
"""
190-
flatten(x) = (x,)
190+
flatten(x) = is_ghost_type(typeof(x)) ? () : (x,)
191191
flatten(arr::TileArray{T, N}) where {T, N} = (arr.ptr, arr.sizes..., arr.strides...)
192-
flatten(::Constant) = () # Ghost types are not passed to cudacall
193192

194193
"""
195194
to_tile_arg(x)

test/codegen/integration.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,3 +1123,39 @@ end
11231123
end
11241124
end
11251125
end
1126+
1127+
#=============================================================================
1128+
Ghost Type Arguments
1129+
=============================================================================#
1130+
1131+
@testset "Non-Constant Ghost Type Arguments" begin
1132+
spec = ct.ArraySpec{1}(16, true)
1133+
1134+
@testset "Nothing argument" begin
1135+
@test @filecheck begin
1136+
@check_label "entry"
1137+
@check "load_view_tko"
1138+
@check "store_view_tko"
1139+
code_tiled(Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, Nothing}) do a, b, _
1140+
pid = ct.bid(1)
1141+
tile = ct.load(a, pid, (16,))
1142+
ct.store(b, pid, tile)
1143+
return
1144+
end
1145+
end
1146+
end
1147+
1148+
@testset "Val argument" begin
1149+
@test @filecheck begin
1150+
@check_label "entry"
1151+
@check "load_view_tko"
1152+
@check "store_view_tko"
1153+
code_tiled(Tuple{ct.TileArray{Float32,1,spec}, ct.TileArray{Float32,1,spec}, Val{16}}) do a, b, _
1154+
pid = ct.bid(1)
1155+
tile = ct.load(a, pid, (16,))
1156+
ct.store(b, pid, tile)
1157+
return
1158+
end
1159+
end
1160+
end
1161+
end

test/execution/basic.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,3 +1065,36 @@ end
10651065
ct.launch(kernel!, 1)
10661066
end
10671067

1068+
@testset "non-Constant ghost type argument (nothing)" begin
1069+
function ghost_nothing_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, ::Nothing)
1070+
pid = ct.bid(1)
1071+
tile = ct.load(a, pid, (16,))
1072+
ct.store(b, pid, tile)
1073+
return
1074+
end
1075+
1076+
n = 256
1077+
a = CUDA.rand(Float32, n)
1078+
b = CUDA.zeros(Float32, n)
1079+
1080+
ct.launch(ghost_nothing_kernel, cld(n, 16), a, b, nothing)
1081+
1082+
@test Array(b) Array(a)
1083+
end
1084+
1085+
@testset "non-Constant ghost type argument (Val)" begin
1086+
function ghost_val_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, ::Val{n}) where n
1087+
pid = ct.bid(1)
1088+
tile = ct.load(a, pid, (n,))
1089+
ct.store(b, pid, tile)
1090+
return
1091+
end
1092+
1093+
n = 256
1094+
a = CUDA.rand(Float32, n)
1095+
b = CUDA.zeros(Float32, n)
1096+
1097+
ct.launch(ghost_val_kernel, cld(n, 16), a, b, Val(16))
1098+
1099+
@test Array(b) Array(a)
1100+
end

0 commit comments

Comments
 (0)