|
| 1 | +using PrecompileTools: @compile_workload |
| 2 | + |
| 3 | +# Compiling a GCN kernel requires being able to initialize the AMDGPU LLVM |
| 4 | +# backend, so we only run the precompile workload when that's supported, to be |
| 5 | +# able to load this package also on systems where the backend isn't available. |
| 6 | +# |
| 7 | +# This mirrors CUDA.jl's precompile workload: it warms up the GPUCompiler -> |
| 8 | +# AMDGPU codegen pipeline during precompilation so that the first kernel launch |
| 9 | +# at runtime doesn't have to JIT-compile the entire compiler. It does NOT need a |
| 10 | +# GPU (or even the ROCm runtime to be discovered) -- it only uses LLVM. |
| 11 | +if :AMDGPU in LLVM.backends() |
| 12 | + @compile_workload begin |
| 13 | + let |
| 14 | + function _precompile_kernel(a) |
| 15 | + i = workitemIdx().x |
| 16 | + @inbounds a[i] += 1.0f0 |
| 17 | + return |
| 18 | + end |
| 19 | + |
| 20 | + # Build a device-free compiler config for a baseline GCN target. |
| 21 | + # `gfx900` (wavefront 64) is a portable baseline that exercises the |
| 22 | + # full pipeline; the cached *code* is reused regardless of the actual |
| 23 | + # device's ISA at runtime (only the kernel binary differs). |
| 24 | + target = GPUCompiler.GCNCompilerTarget(; |
| 25 | + dev_isa="gfx900", features="-wavefrontsize32,+wavefrontsize64") |
| 26 | + params = Compiler.HIPCompilerParams(true, true) |
| 27 | + config = GPUCompiler.CompilerConfig(target, params; |
| 28 | + kernel=true, name=nothing, always_inline=true) |
| 29 | + |
| 30 | + tt = Tuple{ROCDeviceArray{Float32, 1, AS.Global}} |
| 31 | + source = GPUCompiler.methodinstance(typeof(_precompile_kernel), tt) |
| 32 | + job = GPUCompiler.CompilerJob(source, config) |
| 33 | + |
| 34 | + # On Julia < 1.12, GPU compilation during precompilation leaks foreign |
| 35 | + # MIs into native compilation, causing LLVM errors. Guard like CUDA.jl. |
| 36 | + @static if VERSION >= v"1.12-" |
| 37 | + GPUCompiler.JuliaContext() do ctx |
| 38 | + GPUCompiler.compile(:obj, job) |
| 39 | + end |
| 40 | + |
| 41 | + # The compile above runs during precompilation, when ROCm |
| 42 | + # discovery (`__init__`) has NOT run, so `libdevice_libs` is |
| 43 | + # empty. That poisons these module-level caches with empty |
| 44 | + # entries (e.g. an `ocml` `DevLib` with no path), which would be |
| 45 | + # baked into the precompile image and prevent device-library |
| 46 | + # linking at runtime (`unsupported call to __ocml_*`). Reset them |
| 47 | + # so they are repopulated correctly once discovery has run. |
| 48 | + empty!(Compiler.DEVICE_LIBS) |
| 49 | + empty!(Compiler._global_hostcalls) |
| 50 | + end |
| 51 | + end |
| 52 | + end |
| 53 | +end |
| 54 | + |
| 55 | +# Kernel launch infrastructure that the workload above cannot reach, because it |
| 56 | +# requires a live device (mirrors CUDA.jl's explicit precompile directives: |
| 57 | +# `cufunction`, `link`, and `actual_compilation`). |
| 58 | +precompile(Tuple{typeof(Compiler.hipfunction), typeof(identity), Type{Tuple{Nothing}}}) |
| 59 | +precompile(Tuple{typeof(GPUCompiler.actual_compilation), |
| 60 | + Dict{Any, HIP.HIPFunction}, Core.MethodInstance, UInt64, |
| 61 | + Compiler.HIPCompilerConfig, typeof(Compiler.hipcompile), typeof(Compiler.hiplink)}) |
| 62 | +precompile(Tuple{typeof(Compiler.hiplink), Compiler.HIPCompilerJob, |
| 63 | + NamedTuple{(:obj, :entry, :global_hostcalls), |
| 64 | + Tuple{Vector{UInt8}, String, Vector{Symbol}}}}) |
| 65 | + |
| 66 | +# Hot entry points of the bundled ROCm libraries, mirroring CUDA.jl's per-library |
| 67 | +# precompile directives. These compile the (GPU-free) Julia wrappers so the first |
| 68 | +# `A * B`, factorization, FFT plan, etc. doesn't pay full first-use compilation. |
| 69 | +let RM = (T) -> ROCArray{T, 2, Mem.HIPBuffer} |
| 70 | + # rocBLAS: handle creation, GEMM and high-level matmul. |
| 71 | + precompile(Tuple{typeof(rocBLAS.create_handle)}) |
| 72 | + precompile(Tuple{typeof(rocBLAS.lib_state)}) |
| 73 | + for T in (Float32, Float64, ComplexF32, ComplexF64) |
| 74 | + precompile(Tuple{typeof(rocBLAS.gemm!), Char, Char, T, RM(T), RM(T), T, RM(T)}) |
| 75 | + end |
| 76 | + for T in (Float32, Float64) |
| 77 | + precompile(Tuple{typeof(*), RM(T), RM(T)}) |
| 78 | + precompile(Tuple{typeof(LinearAlgebra.mul!), RM(T), RM(T), RM(T)}) |
| 79 | + end |
| 80 | + |
| 81 | + # rocSOLVER: common factorizations. |
| 82 | + for T in (Float32, Float64) |
| 83 | + precompile(Tuple{typeof(rocSOLVER.getrf!), RM(T)}) |
| 84 | + precompile(Tuple{typeof(rocSOLVER.geqrf!), RM(T)}) |
| 85 | + precompile(Tuple{typeof(rocSOLVER.potrf!), Char, RM(T)}) |
| 86 | + end |
| 87 | + |
| 88 | + # rocFFT: plan creation for common types. |
| 89 | + for T in (ComplexF32, ComplexF64) |
| 90 | + precompile(Tuple{typeof(rocFFT.plan_fft!), RM(T), Int}) |
| 91 | + end |
| 92 | + for T in (Float32, Float64) |
| 93 | + precompile(Tuple{typeof(rocFFT.plan_rfft), RM(T), Int}) |
| 94 | + end |
| 95 | + |
| 96 | + # rocRAND / random. |
| 97 | + precompile(Tuple{typeof(rand), Type{Float32}, Dims{2}}) |
| 98 | + precompile(Tuple{typeof(Random.rand!), RM(Float32)}) |
| 99 | +end |
0 commit comments