Skip to content

Commit 10b2702

Browse files
[pocl] RNG support (#659)
Co-authored-by: Christian Guinard <28689358+christiangnrd@users.noreply.github.com>
1 parent 9bf8fd5 commit 10b2702

9 files changed

Lines changed: 677 additions & 15 deletions

File tree

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,9 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1212
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1313
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1414
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
15+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
16+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
17+
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
1518
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
1619
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
1720
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
@@ -39,6 +42,9 @@ LLVM = "9.4.1"
3942
LinearAlgebra = "1.6"
4043
MacroTools = "0.5"
4144
PrecompileTools = "1"
45+
Random = "1"
46+
Random123 = "1.7.1"
47+
RandomNumbers = "1.6.0"
4248
SPIRVIntrinsics = "0.5"
4349
SPIRV_LLVM_Backend_jll = "22"
4450
SPIRV_Tools_jll = "2024.4, 2025.1"

src/pocl/compiler/compilation.jl

Lines changed: 93 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,91 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1919
in(fn, known_intrinsics) ||
2020
contains(fn, "__spirv_")
2121

22+
# OpenCL compiler jobs use `KernelState` as their kernel state type.
function GPUCompiler.kernel_state_type(::OpenCLCompilerJob)
    return KernelState
end
23+
24+
# Finish the module via the SPIR-V target's `finish_module!`, then — for kernels whose
# module declares `julia.opencl.random_keys` — prepend an "initialize_rng" block to the
# entry function that calls `initialize_rng_state` through a deferred compilation job.
# Kernels treated this way are tagged with the `julia.opencl.rng` function attribute.
# Returns the entry function produced by the SPIR-V `finish_module!`.
function GPUCompiler.finish_module!(
        @nospecialize(job::OpenCLCompilerJob),
        mod::LLVM.Module, entry::LLVM.Function
    )
    entry = invoke(
        GPUCompiler.finish_module!,
        Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
        job, mod, entry
    )

    # if this kernel uses our RNG, we should prime the shared state; otherwise we are done.
    # XXX: these transformations should really happen at the Julia IR level...
    haskey(functions(mod), "julia.opencl.random_keys") || return entry
    job.config.kernel || return entry

    # the function whose call gets inserted, and its (empty) argument signature
    init_f = initialize_rng_state
    init_ft = typeof(init_f)
    init_tt = Tuple{}

    # register a deferred compilation job for `initialize_rng_state`
    init_src = methodinstance(init_ft, init_tt, GPUCompiler.tls_world_age())
    init_cfg = CompilerConfig(job.config; kernel = false, name = nothing)
    init_job = CompilerJob(init_src, init_cfg, job.world)
    job_id = length(GPUCompiler.deferred_codegen_jobs) + 1
    GPUCompiler.deferred_codegen_jobs[job_id] = init_job

    # emit IR for the `deferred_codegen` call and the call through the resulting pointer,
    # in a fresh block placed in front of the current first block of the entry function
    entry_bb = first(blocks(entry))
    init_bb = BasicBlock(entry_bb, "initialize_rng")
    @dispose builder = IRBuilder() begin
        position!(builder, init_bb)
        sp = LLVM.subprogram(entry)
        if sp !== nothing
            debuglocation!(builder, DILocation(0, 0, sp))
        end
        debuglocation!(builder, first(instructions(entry_bb)))

        # call the `deferred_codegen` marker function; its return type depends on the
        # LLVM / Julia version in use
        ptr_type = if LLVM.version() >= v"17"
            LLVM.PointerType()
        elseif VERSION >= v"1.12.0-DEV.225"
            LLVM.PointerType(LLVM.Int8Type())
        else
            LLVM.Int64Type()
        end
        id_type = convert(LLVMType, Int)
        marker_ft = LLVM.FunctionType(ptr_type, [id_type])
        marker = if haskey(functions(mod), "deferred_codegen")
            functions(mod)["deferred_codegen"]
        else
            LLVM.Function(mod, "deferred_codegen", marker_ft)
        end
        raw_fptr = call!(builder, marker_ft, marker, [ConstantInt(job_id)])

        # call the `initialize_rng_state` function, then fall through to the old entry
        init_rt = Core.Compiler.return_type(init_f, init_tt)
        init_llvm_rt = convert(LLVMType, init_rt)
        init_llvm_ft = LLVM.FunctionType(init_llvm_rt)
        typed_fptr = inttoptr!(builder, raw_fptr, LLVM.PointerType(init_llvm_ft))
        call!(builder, init_llvm_ft, typed_fptr)
        br!(builder, entry_bb)

        # note the use of the device-side RNG in this kernel
        push!(function_attributes(entry), StringAttribute("julia.opencl.rng", ""))
    end

    # XXX: put some of the above behind GPUCompiler abstractions
    #      (e.g., a compile-time version of `deferred_codegen`)
    return entry
end
95+
96+
# For every kernel in the linked module, hand the `julia.opencl.random_keys` and
# `julia.opencl.random_counters` intrinsics to `GPUCompiler.add_input_arguments!`, each
# described as a `UInt32` pointer in the `AS.Workgroup` address space.
function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
    # both intrinsics share the same pointer type
    rng_ptr = LLVMPtr{UInt32, AS.Workgroup}
    for kernel_fn in GPUCompiler.kernels(mod)
        rng_inputs = Dict(
            "julia.opencl.random_keys" => (name = "random_keys", typ = rng_ptr),
            "julia.opencl.random_counters" => (name = "random_counters", typ = rng_ptr),
        )
        GPUCompiler.add_input_arguments!(job, mod, kernel_fn, rng_inputs)
    end
    return nothing
end
106+
22107

23108
## compiler implementation (cache, configure, compile, and link)
24109

@@ -59,11 +144,14 @@ end
59144
# compile to executable machine code
60145
function compile(@nospecialize(job::CompilerJob))
    # Compiles `job` to an object and returns the named tuple `(; obj, entry, device_rng)`:
    # the compiled object, the name of the entry function, and whether the entry function
    # carries the `julia.opencl.rng` attribute (i.e. the kernel uses the device-side RNG).
    #
    # The result of the `do` block is returned directly. Destructuring it into outer
    # `obj, meta` bindings would assign the named tuple's fields positionally (binding
    # `meta` to the entry name), which is dead and misleading.
    #
    # TODO: this creates a context; cache those.
    return JuliaContext() do ctx
        obj, meta = GPUCompiler.compile(:obj, job)

        entry = LLVM.name(meta.entry)
        # the attribute is attached in `finish_module!` when RNG initialization is inserted
        device_rng = StringAttribute("julia.opencl.rng", "") in collect(function_attributes(meta.entry))

        (; obj, entry, device_rng)
    end
end
68156

69157
# link into an executable kernel
@@ -74,5 +162,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
74162
error("Your device does not support SPIR-V, which is currently required for native execution.")
75163
end
76164
cl.build!(prog)
77-
return cl.Kernel(prog, compiled.entry)
165+
return (; kernel = cl.Kernel(prog, compiled.entry), compiled.device_rng)
78166
end

src/pocl/compiler/execution.jl

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,8 @@ end
146146

147147
abstract type AbstractKernel{F, TT} end
148148

149+
# Decide whether an argument of type `dt` is passed to the kernel: returns `false` for
# ghost types (per `GPUCompiler.isghosttype`) and for constant `Type` arguments (per
# `Core.Compiler.isconstType`), and `true` otherwise.
function pass_arg(@nospecialize(dt))
    GPUCompiler.isghosttype(dt) && return false
    return !Core.Compiler.isconstType(dt)
end
150+
149151
@inline @generated function (kernel::AbstractKernel{F, TT})(
150152
args...;
151153
call_kwargs...
@@ -154,8 +156,7 @@ abstract type AbstractKernel{F, TT} end
154156
args = (:(kernel.f), (:(clconvert(args[$i], svm_pointers)) for i in 1:length(args))...)
155157

156158
# filter out ghost arguments that shouldn't be passed
157-
predicate = dt -> GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt)
158-
to_pass = map(!predicate, sig.parameters)
159+
to_pass = map(pass_arg, sig.parameters)
159160
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
160161
call_args = Union{Expr, Symbol}[x[1] for x in zip(args, to_pass) if x[2]]
161162

@@ -167,12 +168,15 @@ abstract type AbstractKernel{F, TT} end
167168
end
168169
end
169170

171+
pushfirst!(call_t, KernelState)
172+
pushfirst!(call_args, :(KernelState(kernel.rng_state ? Base.rand(UInt32) : UInt32(0))))
173+
170174
# finalize types
171175
call_tt = Base.to_tuple_type(call_t)
172176

173177
return quote
174178
svm_pointers = Ptr{Cvoid}[]
175-
$cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, call_kwargs...)
179+
$cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, kernel.rng_state, call_kwargs...)
176180
end
177181
end
178182

@@ -182,6 +186,7 @@ end
182186
"""
    HostKernel{F, TT} <: AbstractKernel{F, TT}

Host-side callable object for a compiled kernel.

# Fields
- `f::F`: the function instance captured by the kernel; it is passed as the first
  (converted) argument when the kernel is called.
- `fun::cl.Kernel`: the linked OpenCL kernel handed to `cl.clcall`.
- `rng_state::Bool`: whether the kernel uses the device-side RNG. When `true`, a launch
  seeds the `KernelState` with `Base.rand(UInt32)`; otherwise the seed is `UInt32(0)`.
"""
struct HostKernel{F, TT} <: AbstractKernel{F, TT}
    f::F
    fun::cl.Kernel
    rng_state::Bool
end
186191

187192

@@ -198,15 +203,15 @@ function clfunction(f::F, tt::TT = Tuple{}; kwargs...) where {F, TT}
198203
cache = compiler_cache(ctx)
199204
source = methodinstance(F, tt)
200205
config = compiler_config(dev; kwargs...)::OpenCLCompilerConfig
201-
fun = GPUCompiler.cached_compilation(cache, source, config, compile, link)
206+
linked = GPUCompiler.cached_compilation(cache, source, config, compile, link)
202207

203208
# create a callable object that captures the function instance. we don't need to think
204209
# about world age here, as GPUCompiler already does and will return a different object
205-
h = hash(fun, hash(f, hash(tt)))
210+
h = hash(linked.kernel, hash(f, hash(tt)))
206211
kernel = get(_kernel_instances, h, nothing)
207212
if kernel === nothing
208213
# create the kernel state object
209-
kernel = HostKernel{F, tt}(f, fun)
214+
kernel = HostKernel{F, tt}(f, linked.kernel, linked.device_rng)
210215
_kernel_instances[h] = kernel
211216
end
212217
return kernel::HostKernel{F, tt}

0 commit comments

Comments
 (0)