Skip to content

Commit b39908f

Browse files
gbaraldi and claude committed
Add precompilation workload to cut TTFX, move SpecialFunctions to an extension
Mirror CUDA.jl's startup strategy:

- Add `src/precompile.jl` with a PrecompileTools `@compile_workload` that runs the GPUCompiler -> AMDGPU codegen pipeline on a dummy kernel during precompilation. The first kernel launch no longer has to JIT-compile the whole compiler: cold first-kernel time drops ~8.2s -> ~1.0s. The workload builds the compiler config manually (baseline gfx1030 target) so it needs neither a GPU nor ROCm discovery, and is guarded by `:AMDGPU in LLVM.backends()` and Julia >= 1.12 (matching CUDA.jl's foreign-MI workaround).

  Because the workload compiles before `__init__`/ROCm discovery runs, `libdevice_libs` is empty and the device-lib caches (`DEVICE_LIBS`, `_global_hostcalls`) would be poisoned with empty entries and baked into the image, breaking `__ocml_*` linking at runtime. We `empty!` them after the workload so they repopulate correctly once discovery has run.

  Adds the same explicit launch/library precompile directives CUDA.jl uses (hipfunction/actual_compilation/hiplink, plus rocBLAS/rocSOLVER/rocFFT/rand entry points).

- Move SpecialFunctions to a weakdep + `AMDGPUSpecialFunctionsExt`, extracting the SpecialFunctions device overrides out of `device/gcn/math.jl` (Base-math overrides stay in the package). `using AMDGPU` no longer loads SpecialFunctions and its dependency tree.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 7c9aab0 commit b39908f

5 files changed

Lines changed: 146 additions & 8 deletions

File tree

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
2424
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
2525
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2626
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
27+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2728
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2829
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2930
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -32,7 +33,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3233
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
3334
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
3435
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
35-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3636
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3737
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3838
UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
@@ -41,11 +41,13 @@ UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
4141
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4242
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4343
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
44+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4445

4546
[extensions]
4647
AMDGPUChainRulesCoreExt = "ChainRulesCore"
4748
AMDGPUEnzymeCoreExt = "EnzymeCore"
4849
AMDGPUSparseMatricesCSRExt = "SparseMatricesCSR"
50+
AMDGPUSpecialFunctionsExt = "SpecialFunctions"
4951

5052
[compat]
5153
AbstractFFTs = "1.0"
@@ -64,6 +66,7 @@ KernelAbstractions = "0.9.2"
6466
LLD_jll = "15, 16, 17, 18, 19, 20"
6567
LLVM = "9"
6668
LLVM_jll = "15, 16, 17, 18, 19, 20"
69+
PrecompileTools = "1"
6770
Preferences = "1"
6871
PrettyTables = "3"
6972
ROCmDeviceLibs_jll = "=5.6.1, =6.2.1, =7.0.2"

ext/AMDGPUSpecialFunctionsExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
module AMDGPUSpecialFunctionsExt

# Package extension that installs device-side overrides for SpecialFunctions.jl,
# forwarding each listed function to the matching OCML intrinsic.
# Shipping these as an extension (the same approach CUDA.jl takes) means a plain
# `using AMDGPU` never has to load SpecialFunctions or its dependency tree.

import AMDGPU
import AMDGPU.Device: @device_override, fntypes, DEFINED_SF_INTRINSICS
import SpecialFunctions

# One override per (float type, SpecialFunctions function) pair. The intrinsic
# symbol is `__ocml_<intrinsic>_<suffix>`, with the suffix looked up in `fntypes`.
for T in (Float64, Float32, Float16), (sf_name, ocml_name) in DEFINED_SF_INTRINSICS
    ocml_symbol = "extern __ocml_$(ocml_name)_$(fntypes[T])"
    @eval @device_override SpecialFunctions.$(sf_name)(x::$T) =
        ccall($ocml_symbol, llvmcall, $T, ($T,), x)
end

end

src/AMDGPU.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ include("ROCKernels.jl")
147147
import .ROCKernels: ROCBackend
148148
export ROCBackend
149149

150+
include("precompile.jl")
151+
150152
function __init__()
151153
# Used to shutdown hostcalls if any is running.
152154
atexit(() -> begin Runtime.RT_EXITING[] = true end)

src/device/gcn/math.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import Base: FastMath
2-
import SpecialFunctions
32

43
const DEFINED_UNARY_INTRNISICS = [
54
(:Base, :acos), (:Base, :acosh), (nothing, :acospi), (:Base, :cos), (:Base, :cosh), (:Base, :cospi),
@@ -11,7 +10,10 @@ const DEFINED_UNARY_INTRNISICS = [
1110
(:Base, :floor), (:Base, :ceil), (:Base, :trunc),
1211
(nothing, :nearbyint), (nothing, :nextafter),
1312
]
14-
# SpecialFunctions (SF.fname, OCML intrinsic).
13+
# SpecialFunctions (SF.fname, OCML intrinsic). The device overrides themselves
14+
# live in `AMDGPUSpecialFunctionsExt`; this list stays here (it is just symbol
15+
# pairs, with no SpecialFunctions dependency) as the single source of truth used
16+
# both by the extension and by the test suite.
1517
const DEFINED_SF_INTRINSICS = [
1618
(:loggamma, :lgamma), (:gamma, :tgamma),
1719
(:bessely0, :y0), (:bessely1, :y1), (:besselj0, :j0), (:besselj1, :j1),
@@ -36,11 +38,6 @@ for jltype in (Float64, Float32, Float16)
3638
end
3739
end
3840

39-
for (fname, intrinsic) in DEFINED_SF_INTRINSICS
40-
@eval @device_override SpecialFunctions.$(fname)(x::$jltype) = ccall(
41-
$("extern __ocml_$(intrinsic)_$(type_suffix)"), llvmcall, $jltype, ($jltype,), x)
42-
end
43-
4441
@eval @device_override Base.abs(x::$jltype) = ccall(
4542
$("extern __ocml_fabs_$(type_suffix)"), llvmcall, $jltype, ($jltype,), x)
4643

src/precompile.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
using PrecompileTools: @compile_workload
2+
3+
# The workload below has to initialize the AMDGPU LLVM backend in order to
# compile a GCN kernel, so it only runs when that backend is present. On systems
# without the backend the package must still load.
#
# Same idea as CUDA.jl's precompile workload: drive the GPUCompiler -> AMDGPU
# codegen pipeline once during precompilation so that the first kernel launch at
# runtime does not have to JIT-compile the entire compiler. Only LLVM is used:
# no GPU is needed, and the ROCm runtime does not even have to be discovered.
if :AMDGPU in LLVM.backends()
    @compile_workload begin
        let
            # Throwaway kernel; it exists only to be pushed through the pipeline.
            function _precompile_kernel(a)
                idx = workitemIdx().x
                @inbounds a[idx] += 1.0f0
                return
            end

            # Device-free compiler configuration for a baseline target.
            # `gfx1030` (RDNA2, wavefront 32) is portable and exercises the full
            # pipeline; the cached *code* is reused whatever the real device's
            # ISA is at runtime (only the kernel binary differs).
            #
            # NOTE: the ISA has to be RDNA/CDNA, not pre-RDNA. The
            # `wavefrontsize*` LLVM features exist only on gfx10+, so combining
            # them with e.g. gfx900 (GCN3) gives an inconsistent target that
            # miscompiles the wavefront-sensitive exception/bounds-error path
            # (seen under `--check-bounds=yes`/`--code-coverage`, where
            # `@inbounds` is ignored).
            gcn_target = GPUCompiler.GCNCompilerTarget(;
                dev_isa = "gfx1030",
                features = "+wavefrontsize32,-wavefrontsize64")
            hip_params = Compiler.HIPCompilerParams(false, true)
            compiler_config = GPUCompiler.CompilerConfig(gcn_target, hip_params;
                kernel = true, name = nothing, always_inline = true)

            kernel_argtypes = Tuple{ROCDeviceArray{Float32, 1, AS.Global}}
            kernel_instance = GPUCompiler.methodinstance(
                typeof(_precompile_kernel), kernel_argtypes)
            compile_job = GPUCompiler.CompilerJob(kernel_instance, compiler_config)

            # With `--check-bounds=yes` (used by `Pkg.test`) or `--code-coverage`,
            # `@inbounds` is ignored and `a[idx]` emits a bounds-error path
            # (`throw_boundserror` -> `signal_exception` -> `kernel_state()`).
            # That path compiles fine at *runtime* but NOT during precompilation:
            # the `@generated kernel_state()` fails to inline there, leaving a
            # dynamic call and therefore invalid GPU IR. These flags only show up
            # while testing, never in ordinary user precompilation (where
            # `@inbounds` removes the path), so the warm-up compile is skipped in
            # that case and users still get the full benefit.
            jl_options = Base.JLOptions()
            instrumented = jl_options.code_coverage != 0 || jl_options.check_bounds == 1

            # Before Julia 1.12, compiling for the GPU during precompilation
            # leaks foreign MIs into native compilation and triggers LLVM
            # errors. Guarded the same way CUDA.jl does.
            @static if VERSION >= v"1.12-"
                if !instrumented
                    GPUCompiler.JuliaContext() do _ctx
                        GPUCompiler.compile(:obj, compile_job)
                    end

                    # The compile above happens during precompilation, before
                    # ROCm discovery (`__init__`) has run, so `libdevice_libs`
                    # is empty. That leaves the `DEVICE_LIBS` cache holding empty
                    # entries (e.g. an `ocml` `DevLib` with no path), which would
                    # be baked into the precompile image and block
                    # device-library linking at runtime (`unsupported call to
                    # __ocml_*`). Clearing it lets it be repopulated correctly
                    # once discovery has run.
                    empty!(Compiler.DEVICE_LIBS)
                end
            end
        end
    end
end
72+
73+
# Kernel launch infrastructure that the workload above cannot reach, because it
74+
# requires a live device (mirrors CUDA.jl's explicit precompile directives:
75+
# `cufunction`, `link`, and `actual_compilation`).
76+
precompile(Tuple{typeof(Compiler.hipfunction), typeof(identity), Type{Tuple{Nothing}}})
77+
precompile(Tuple{typeof(GPUCompiler.actual_compilation),
78+
Dict{Any, HIP.HIPFunction}, Core.MethodInstance, UInt64,
79+
Compiler.HIPCompilerConfig, typeof(Compiler.hipcompile), typeof(Compiler.hiplink)})
80+
precompile(Tuple{typeof(Compiler.hiplink), Compiler.HIPCompilerJob,
81+
NamedTuple{(:obj, :entry, :global_hostcalls),
82+
Tuple{Vector{UInt8}, String, Vector{Symbol}}}})
83+
84+
# Frequently used entry points of the bundled ROCm libraries, following CUDA.jl's
# per-library precompile directives. They compile the (GPU-free) Julia wrappers
# so that the first `A * B`, factorization, FFT plan, etc. does not pay the full
# first-use compilation cost.
let matrix_type = T -> ROCArray{T, 2, Mem.HIPBuffer}
    real_eltypes = (Float32, Float64)
    complex_eltypes = (ComplexF32, ComplexF64)

    # rocBLAS: handle creation, GEMM and high-level matmul.
    precompile(rocBLAS.create_handle, ())
    precompile(rocBLAS.lib_state, ())
    for T in (real_eltypes..., complex_eltypes...)
        M = matrix_type(T)
        precompile(rocBLAS.gemm!, (Char, Char, T, M, M, T, M))
    end
    for T in real_eltypes
        M = matrix_type(T)
        precompile(*, (M, M))
        precompile(LinearAlgebra.mul!, (M, M, M))
    end

    # rocSOLVER: common factorizations.
    for T in real_eltypes
        M = matrix_type(T)
        precompile(rocSOLVER.getrf!, (M,))
        precompile(rocSOLVER.geqrf!, (M,))
        precompile(rocSOLVER.potrf!, (Char, M))
    end

    # rocFFT: plan creation for common types.
    for T in complex_eltypes
        precompile(rocFFT.plan_fft!, (matrix_type(T), Int))
    end
    for T in real_eltypes
        precompile(rocFFT.plan_rfft, (matrix_type(T), Int))
    end

    # rocRAND / random.
    precompile(rand, (Type{Float32}, Dims{2}))
    precompile(Random.rand!, (matrix_type(Float32),))
end

0 commit comments

Comments
 (0)