Skip to content

Commit 349cf73

Browse files
gbaraldiclaude
andcommitted
Add precompilation workload to cut TTFX, move SpecialFunctions to an extension
Mirror CUDA.jl's startup strategy:

- Add `src/precompile.jl` with a PrecompileTools `@compile_workload` that runs the GPUCompiler -> AMDGPU codegen pipeline on a dummy kernel during precompilation. The first kernel launch no longer has to JIT-compile the whole compiler: cold first-kernel time drops ~8.2s -> ~1.0s.

  The workload builds the compiler config manually (baseline gfx900 target) so it needs neither a GPU nor ROCm discovery, and is guarded by `:AMDGPU in LLVM.backends()` and Julia >= 1.12 (matching CUDA.jl's foreign-MI workaround).

  Because the workload compiles before `__init__`/ROCm discovery runs, `libdevice_libs` is empty and the device-lib caches (`DEVICE_LIBS`, `_global_hostcalls`) would be poisoned with empty entries and baked into the image, breaking `__ocml_*` linking at runtime. We `empty!` them after the workload so they repopulate correctly once discovery has run.

  Also adds the same explicit launch/library precompile directives CUDA.jl uses (hipfunction/actual_compilation/hiplink, plus rocBLAS/rocSOLVER/rocFFT/rand entry points).

- Move SpecialFunctions to a weakdep + `AMDGPUSpecialFunctionsExt`, extracting the SpecialFunctions device overrides out of `device/gcn/math.jl` (Base-math overrides stay in the package). `using AMDGPU` no longer loads SpecialFunctions and its dependency tree.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 7c9aab0 commit 349cf73

5 files changed

Lines changed: 131 additions & 13 deletions

File tree

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
2424
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
2525
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2626
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
27+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2728
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2829
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2930
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -32,7 +33,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3233
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
3334
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
3435
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
35-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3636
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3737
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3838
UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
@@ -41,11 +41,13 @@ UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
4141
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4242
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4343
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
44+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
4445

4546
[extensions]
4647
AMDGPUChainRulesCoreExt = "ChainRulesCore"
4748
AMDGPUEnzymeCoreExt = "EnzymeCore"
4849
AMDGPUSparseMatricesCSRExt = "SparseMatricesCSR"
50+
AMDGPUSpecialFunctionsExt = "SpecialFunctions"
4951

5052
[compat]
5153
AbstractFFTs = "1.0"
@@ -64,6 +66,7 @@ KernelAbstractions = "0.9.2"
6466
LLD_jll = "15, 16, 17, 18, 19, 20"
6567
LLVM = "9"
6668
LLVM_jll = "15, 16, 17, 18, 19, 20"
69+
PrecompileTools = "1"
6770
Preferences = "1"
6871
PrettyTables = "3"
6972
ROCmDeviceLibs_jll = "=5.6.1, =6.2.1, =7.0.2"

ext/AMDGPUSpecialFunctionsExt.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
module AMDGPUSpecialFunctionsExt

# Package extension providing device-side methods for SpecialFunctions.jl.
# Every supported function is overridden to call the corresponding OCML
# intrinsic. Shipping this as an extension (as CUDA.jl does) means a plain
# `using AMDGPU` does not load SpecialFunctions and its dependencies.

import AMDGPU
import AMDGPU.Device: @device_override, fntypes
import SpecialFunctions

# Pairs of (SpecialFunctions function name, OCML intrinsic base name).
const DEFINED_SF_INTRINSICS = [
    (:loggamma, :lgamma), (:gamma, :tgamma),
    (:bessely0, :y0), (:bessely1, :y1), (:besselj0, :j0), (:besselj1, :j1),
    (:erf, :erf), (:erfc, :erfc), (:erfcx, :erfcx),
    (:erfinv, :erfinv), (:erfcinv, :erfcinv),
]

# Emit one unary override per (function, float type) pair. The external symbol
# is `__ocml_<intrinsic>_<suffix>`, where the suffix is looked up in `fntypes`.
for (fname, intrinsic) in DEFINED_SF_INTRINSICS, jltype in (Float64, Float32, Float16)
    ocml_name = "extern __ocml_$(intrinsic)_$(fntypes[jltype])"
    @eval @device_override SpecialFunctions.$(fname)(x::$jltype) = ccall(
        $ocml_name, llvmcall, $jltype, ($jltype,), x)
end

end

src/AMDGPU.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ include("ROCKernels.jl")
147147
import .ROCKernels: ROCBackend
148148
export ROCBackend
149149

150+
include("precompile.jl")
151+
150152
function __init__()
151153
# Used to shutdown hostcalls if any is running.
152154
atexit(() -> begin Runtime.RT_EXITING[] = true end)

src/device/gcn/math.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import Base: FastMath
2-
import SpecialFunctions
32

43
const DEFINED_UNARY_INTRNISICS = [
54
(:Base, :acos), (:Base, :acosh), (nothing, :acospi), (:Base, :cos), (:Base, :cosh), (:Base, :cospi),
@@ -11,12 +10,6 @@ const DEFINED_UNARY_INTRNISICS = [
1110
(:Base, :floor), (:Base, :ceil), (:Base, :trunc),
1211
(nothing, :nearbyint), (nothing, :nextafter),
1312
]
14-
# SpecialFunctions (SF.fname, OCML intrinsic).
15-
const DEFINED_SF_INTRINSICS = [
16-
(:loggamma, :lgamma), (:gamma, :tgamma),
17-
(:bessely0, :y0), (:bessely1, :y1), (:besselj0, :j0), (:besselj1, :j1),
18-
(:erf, :erf), (:erfc, :erfc), (:erfcx, :erfcx), (:erfinv, :erfinv), (:erfcinv, :erfcinv),
19-
]
2013

2114
for jltype in (Float64, Float32, Float16)
2215
type_suffix = fntypes[jltype]
@@ -36,11 +29,6 @@ for jltype in (Float64, Float32, Float16)
3629
end
3730
end
3831

39-
for (fname, intrinsic) in DEFINED_SF_INTRINSICS
40-
@eval @device_override SpecialFunctions.$(fname)(x::$jltype) = ccall(
41-
$("extern __ocml_$(intrinsic)_$(type_suffix)"), llvmcall, $jltype, ($jltype,), x)
42-
end
43-
4432
@eval @device_override Base.abs(x::$jltype) = ccall(
4533
$("extern __ocml_fabs_$(type_suffix)"), llvmcall, $jltype, ($jltype,), x)
4634

src/precompile.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
using PrecompileTools: @compile_workload
2+
3+
# Precompilation workload, modelled on the one in CUDA.jl.
#
# It pushes a throwaway kernel through the GPUCompiler -> AMDGPU codegen
# pipeline while the package image is being built, so the first real kernel
# launch does not have to JIT-compile the compiler itself. Only LLVM is used:
# no GPU is required and the ROCm runtime does not have to be discovered.
#
# Compiling a GCN kernel needs the AMDGPU LLVM backend, so the workload is
# skipped when that backend is missing; the package still loads on such systems.
if :AMDGPU in LLVM.backends()
    @compile_workload begin
        let
            # Dummy kernel: each work-item increments one element.
            function _precompile_kernel(buf)
                idx = workitemIdx().x
                @inbounds buf[idx] += 1.0f0
                return
            end

            # Assemble the compiler configuration by hand so that no device is
            # queried. `gfx900` (wavefront 64) is a portable baseline that
            # exercises the full pipeline; the cached *code* is reused whatever
            # the actual device's ISA is at runtime (only the kernel binary
            # differs).
            gcn_target = GPUCompiler.GCNCompilerTarget(;
                dev_isa="gfx900", features="-wavefrontsize32,+wavefrontsize64")
            hip_params = Compiler.HIPCompilerParams(true, true)
            compiler_config = GPUCompiler.CompilerConfig(gcn_target, hip_params;
                kernel=true, name=nothing, always_inline=true)

            # Compiler job for the dummy kernel on a 1-D global Float32 array.
            kernel_sig = Tuple{ROCDeviceArray{Float32, 1, AS.Global}}
            kernel_mi = GPUCompiler.methodinstance(
                typeof(_precompile_kernel), kernel_sig)
            compile_job = GPUCompiler.CompilerJob(kernel_mi, compiler_config)

            # Same guard as CUDA.jl: on Julia < 1.12, GPU compilation during
            # precompilation leaks foreign MIs into native compilation and
            # causes LLVM errors, so only compile on 1.12 and newer.
            @static if VERSION >= v"1.12-"
                GPUCompiler.JuliaContext() do _
                    GPUCompiler.compile(:obj, compile_job)
                end

                # At precompile time ROCm discovery (`__init__`) has NOT run,
                # so `libdevice_libs` is empty and the compile above fills
                # these module-level caches with empty entries (e.g. an `ocml`
                # `DevLib` with no path). Baked into the precompile image,
                # those would prevent device-library linking at runtime
                # (`unsupported call to __ocml_*`). Clear them so they are
                # repopulated correctly once discovery has run.
                empty!(Compiler.DEVICE_LIBS)
                empty!(Compiler._global_hostcalls)
            end
        end
    end
end
54+
55+
# Explicit precompile directives for the kernel launch path. The workload
# cannot reach this code because it requires a live device. These mirror
# CUDA.jl's directives for `cufunction`, `link`, and `actual_compilation`.
for launch_sig in (
        Tuple{typeof(Compiler.hipfunction), typeof(identity), Type{Tuple{Nothing}}},
        Tuple{typeof(GPUCompiler.actual_compilation),
            Dict{Any, HIP.HIPFunction}, Core.MethodInstance, UInt64,
            Compiler.HIPCompilerConfig,
            typeof(Compiler.hipcompile), typeof(Compiler.hiplink)},
        Tuple{typeof(Compiler.hiplink), Compiler.HIPCompilerJob,
            NamedTuple{(:obj, :entry, :global_hostcalls),
                Tuple{Vector{UInt8}, String, Vector{Symbol}}}},
    )
    precompile(launch_sig)
end
65+
66+
# Precompile directives for frequently used entry points of the bundled ROCm
# libraries, following CUDA.jl's per-library directives. Only the (GPU-free)
# Julia wrappers are compiled, so the first `A * B`, factorization, FFT plan,
# etc. doesn't pay full first-use compilation.
let
    # Matrix type used in every signature below, for element type `T`.
    devmatrix(T) = ROCArray{T, 2, Mem.HIPBuffer}
    real_eltypes = (Float32, Float64)
    complex_eltypes = (ComplexF32, ComplexF64)

    # rocBLAS: handle creation, GEMM and high-level matmul.
    precompile(Tuple{typeof(rocBLAS.create_handle)})
    precompile(Tuple{typeof(rocBLAS.lib_state)})
    for T in (real_eltypes..., complex_eltypes...)
        M = devmatrix(T)
        precompile(Tuple{typeof(rocBLAS.gemm!), Char, Char, T, M, M, T, M})
    end
    for T in real_eltypes
        M = devmatrix(T)
        precompile(Tuple{typeof(*), M, M})
        precompile(Tuple{typeof(LinearAlgebra.mul!), M, M, M})
    end

    # rocSOLVER: common factorizations.
    for T in real_eltypes
        M = devmatrix(T)
        precompile(Tuple{typeof(rocSOLVER.getrf!), M})
        precompile(Tuple{typeof(rocSOLVER.geqrf!), M})
        precompile(Tuple{typeof(rocSOLVER.potrf!), Char, M})
    end

    # rocFFT: plan creation for common types.
    for T in complex_eltypes
        precompile(Tuple{typeof(rocFFT.plan_fft!), devmatrix(T), Int})
    end
    for T in real_eltypes
        precompile(Tuple{typeof(rocFFT.plan_rfft), devmatrix(T), Int})
    end

    # rocRAND / random.
    precompile(Tuple{typeof(rand), Type{Float32}, Dims{2}})
    precompile(Tuple{typeof(Random.rand!), devmatrix(Float32)})
end

0 commit comments

Comments
 (0)