@@ -19,6 +19,91 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1919 in (fn, known_intrinsics) ||
2020 contains (fn, " __spirv_" )
2121
22+ GPUCompiler. kernel_state_type (:: OpenCLCompilerJob ) = KernelState
23+
24+ function GPUCompiler. finish_module! (
25+ @nospecialize (job:: OpenCLCompilerJob ),
26+ mod:: LLVM.Module , entry:: LLVM.Function
27+ )
28+ entry = invoke (
29+ GPUCompiler. finish_module!,
30+ Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM. Module, LLVM. Function},
31+ job, mod, entry
32+ )
33+
34+ # if this kernel uses our RNG, we should prime the shared state.
35+ # XXX : these transformations should really happen at the Julia IR level...
36+ if haskey (functions (mod), " julia.opencl.random_keys" ) && job. config. kernel
37+ # insert call to `initialize_rng_state`
38+ f = initialize_rng_state
39+ ft = typeof (f)
40+ tt = Tuple{}
41+
42+ # create a deferred compilation job for `initialize_rng_state`
43+ src = methodinstance (ft, tt, GPUCompiler. tls_world_age ())
44+ cfg = CompilerConfig (job. config; kernel = false , name = nothing )
45+ job = CompilerJob (src, cfg, job. world)
46+ id = length (GPUCompiler. deferred_codegen_jobs) + 1
47+ GPUCompiler. deferred_codegen_jobs[id] = job
48+
49+ # generate IR for calls to `deferred_codegen` and the resulting function pointer
50+ top_bb = first (blocks (entry))
51+ bb = BasicBlock (top_bb, " initialize_rng" )
52+ @dispose builder = IRBuilder () begin
53+ position! (builder, bb)
54+ subprogram = LLVM. subprogram (entry)
55+ if subprogram != = nothing
56+ loc = DILocation (0 , 0 , subprogram)
57+ debuglocation! (builder, loc)
58+ end
59+ debuglocation! (builder, first (instructions (top_bb)))
60+
61+ # call the `deferred_codegen` marker function
62+ T_ptr = if LLVM. version () >= v " 17"
63+ LLVM. PointerType ()
64+ elseif VERSION >= v " 1.12.0-DEV.225"
65+ LLVM. PointerType (LLVM. Int8Type ())
66+ else
67+ LLVM. Int64Type ()
68+ end
69+ T_id = convert (LLVMType, Int)
70+ deferred_codegen_ft = LLVM. FunctionType (T_ptr, [T_id])
71+ deferred_codegen = if haskey (functions (mod), " deferred_codegen" )
72+ functions (mod)[" deferred_codegen" ]
73+ else
74+ LLVM. Function (mod, " deferred_codegen" , deferred_codegen_ft)
75+ end
76+ fptr = call! (builder, deferred_codegen_ft, deferred_codegen, [ConstantInt (id)])
77+
78+ # call the `initialize_rng_state` function
79+ rt = Core. Compiler. return_type (f, tt)
80+ llvm_rt = convert (LLVMType, rt)
81+ llvm_ft = LLVM. FunctionType (llvm_rt)
82+ fptr = inttoptr! (builder, fptr, LLVM. PointerType (llvm_ft))
83+ call! (builder, llvm_ft, fptr)
84+ br! (builder, top_bb)
85+
86+ # note the use of the device-side RNG in this kernel
87+ push! (function_attributes (entry), StringAttribute (" julia.opencl.rng" , " " ))
88+ end
89+
90+ # XXX : put some of the above behind GPUCompiler abstractions
91+ # (e.g., a compile-time version of `deferred_codegen`)
92+ end
93+ return entry
94+ end
95+
96+ function GPUCompiler. finish_linked_module! (@nospecialize (job:: OpenCLCompilerJob ), mod:: LLVM.Module )
97+ for f in GPUCompiler. kernels (mod)
98+ kernel_intrinsics = Dict (
99+ " julia.opencl.random_keys" => (; name = " random_keys" , typ = LLVMPtr{UInt32, AS. Workgroup}),
100+ " julia.opencl.random_counters" => (; name = " random_counters" , typ = LLVMPtr{UInt32, AS. Workgroup}),
101+ )
102+ GPUCompiler. add_input_arguments! (job, mod, f, kernel_intrinsics)
103+ end
104+ return
105+ end
106+
22107
23108# # compiler implementation (cache, configure, compile, and link)
24109
59144# compile to executable machine code
60145function compile (@nospecialize (job:: CompilerJob ))
61146 # TODO : this creates a context; cache those.
62- obj, meta = JuliaContext () do ctx
63- GPUCompiler. compile (:obj , job)
64- end
147+ return obj, meta = JuliaContext () do ctx
148+ obj, meta = GPUCompiler. compile (:obj , job)
65149
66- return (; obj, entry = LLVM. name (meta. entry))
150+ entry = LLVM. name (meta. entry)
151+ device_rng = StringAttribute (" julia.opencl.rng" , " " ) in collect (function_attributes (meta. entry))
152+
153+ (; obj, entry, device_rng)
154+ end
67155end
68156
69157# link into an executable kernel
@@ -74,5 +162,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
74162 error (" Your device does not support SPIR-V, which is currently required for native execution." )
75163 end
76164 cl. build! (prog)
77- return cl. Kernel (prog, compiled. entry)
165+ return (; kernel = cl. Kernel (prog, compiled. entry), compiled . device_rng )
78166end
0 commit comments