Skip to content

Commit 7b270bd

Browse files
maleadtclaude
andauthored
Metal: split multiply-used aggregate loads (#822)
Julia emits a single by-value load of a large nested aggregate (e.g. an Oceananigans RectilinearGrid passed by reference) feeding several extractvalues. InstCombine only folds extractvalue(load) -> load(gep) for single-use loads, so a multiply-used aggregate load survives to the AGX backend, which crashes lowering it. Apply the same fold without the single-use guard (a codegen heuristic, not a correctness condition). Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 1f18d87 commit 7b270bd

2 files changed

Lines changed: 141 additions & 0 deletions

File tree

src/metal.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,79 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
178178
errors
179179
end
180180

181+
# aggregate load splitting (JuliaGPU/Metal.jl#792)
182+
#
183+
# Julia can emit a single by-value `load` of a large, deeply-nested aggregate (e.g. an
184+
# Oceananigans `RectilinearGrid` passed by reference) that feeds several `extractvalue`s.
185+
# Apple's AGX back-end crashes during native-code generation when lowering such a wide
186+
# aggregate load. We rewrite each `extractvalue (load p), idxs` into a narrow field load
187+
# `load (inbounds_gep p, 0, idxs)` and delete the now-dead wide load, so only per-field
188+
# loads reach the back-end.
189+
#
190+
# This is exactly LLVM's own `extractvalue (load)` -> `load (gep)` fold in InstCombine
191+
# (visitExtractValueInst), with one difference: InstCombine guards it on the load having a
192+
# single use, declining multiply-used loads as "a struct with padding [where] we don't want
193+
# to do the transformation as it loses padding knowledge". That guard is a codegen
194+
# heuristic (one wide load can be cheaper than N field loads), not a correctness condition,
195+
# so dropping it is sound — and necessary here, since the crashing pattern is precisely a
196+
# multiply-used aggregate load that InstCombine therefore leaves intact.
197+
#
198+
# Restricted to simple (non-volatile, non-atomic) loads all of whose users are
199+
# `extractvalue` — the by-value-aggregate-argument pattern — so the wide load can be fully
200+
# eliminated. Like LLVM's fold, the field loads take their type's natural (ABI) alignment,
201+
# valid because the aggregate base load is at least that aligned, and AA metadata is copied
202+
# from the wide load (sound for the narrower field loads it subsumes).
203+
function split_aggregate_loads!(mod::LLVM.Module)
204+
aa_kinds = (LLVM.MD_tbaa, LLVM.MD_tbaa_struct, LLVM.MD_alias_scope, LLVM.MD_noalias)
205+
changed = false
206+
for f in functions(mod)
207+
isdeclaration(f) && continue
208+
worklist = LLVM.LoadInst[]
209+
for bb in blocks(f), inst in instructions(bb)
210+
inst isa LLVM.LoadInst || continue
211+
T = value_type(inst)
212+
(T isa LLVM.StructType || T isa LLVM.ArrayType) || continue
213+
iszero(LLVM.API.LLVMGetVolatile(inst)) || continue
214+
LLVM.API.LLVMGetOrdering(inst) == LLVM.API.LLVMAtomicOrderingNotAtomic || continue
215+
uselist = collect(uses(inst))
216+
isempty(uselist) && continue
217+
all(u -> user(u) isa LLVM.ExtractValueInst, uselist) || continue
218+
push!(worklist, inst)
219+
end
220+
for ld in worklist
221+
ptr = operands(ld)[1]
222+
aggty = value_type(ld)
223+
md = metadata(ld)
224+
i32 = LLVM.Int32Type()
225+
@dispose builder=IRBuilder() begin
226+
# build the field loads at the wide load's location, not the extractvalue's
227+
position!(builder, ld)
228+
for u in collect(uses(ld))
229+
ev = user(u)::LLVM.ExtractValueInst
230+
n = LLVM.API.LLVMGetNumIndices(ev)
231+
idxptr = LLVM.API.LLVMGetIndices(ev)
232+
# extractvalue has integer indices; getelementptr takes Values, prefixed
233+
# with an i32 0 to step through the pointer to the aggregate's first element.
234+
gepidx = LLVM.Value[ConstantInt(i32, 0)]
235+
for k in 1:n
236+
push!(gepidx, ConstantInt(i32, unsafe_load(idxptr, k)))
237+
end
238+
gep = inbounds_gep!(builder, aggty, ptr, gepidx)
239+
fieldload = load!(builder, value_type(ev), gep)
240+
for kind in aa_kinds
241+
haskey(md, kind) && (metadata(fieldload)[kind] = md[kind])
242+
end
243+
replace_uses!(ev, fieldload)
244+
erase!(ev)
245+
end
246+
end
247+
erase!(ld)
248+
changed = true
249+
end
250+
end
251+
return changed
252+
end
253+
181254
function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module,
182255
entry::LLVM.Function)
183256
entry_fn = LLVM.name(entry)
@@ -197,6 +270,10 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
197270
# reporters) load from the constant space rather than crashing Metal's validator.
198271
propagate_argument_address_spaces!(mod)
199272

273+
# split multiply-used by-value aggregate loads into narrow per-field loads; the AGX
274+
# back-end crashes during native codegen on wide aggregate loads (#792).
275+
split_aggregate_loads!(mod)
276+
200277
# propagate specific address spaces through addrspacecast chains introduced
201278
# by the rewrites above, so that loads/stores happen in the right address
202279
# space (e.g. constant globals in addrspace 2 rather than via a cast to 0,

test/metal.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,4 +808,68 @@ end
808808
end
809809
end
810810

811+
@testset "aggregate load splitting" begin
812+
# JuliaGPU/Metal.jl#792: Julia emits a single by-value `load` of a large, deeply-nested
813+
# aggregate (an Oceananigans `RectilinearGrid` passed by reference) feeding several
814+
# `extractvalue`s, and Apple's AGX back-end crashes lowering that wide load. The pass
815+
# rewrites each `extractvalue (load p), idxs` into a narrow field load
816+
# `load (inbounds_gep p, 0, idxs)` and deletes the wide load. InstCombine does the same
817+
# fold but only for single-use loads, so a multiply-used aggregate load reaches the
818+
# back-end intact. (The typed-pointer syntax below parses to opaque pointers too, so this
819+
# covers both pointer regimes.)
820+
is_aggregate_load(i) = i isa LLVM.LoadInst &&
821+
(value_type(i) isa LLVM.StructType || value_type(i) isa LLVM.ArrayType)
822+
823+
# a multiply-used load, with both a top-level (`,0`) and a nested (`,2,0`) index
824+
Context() do ctx
825+
ir = """
826+
define void @f({ i64, float, { float, i64 } }* %p, i64* %o1, float* %o2) {
827+
entry:
828+
%agg = load { i64, float, { float, i64 } }, { i64, float, { float, i64 } }* %p, align 8
829+
%a = extractvalue { i64, float, { float, i64 } } %agg, 0
830+
%b = extractvalue { i64, float, { float, i64 } } %agg, 2, 0
831+
store i64 %a, i64* %o1, align 8
832+
store float %b, float* %o2, align 4
833+
ret void
834+
}
835+
"""
836+
mod = parse(LLVM.Module, ir)
837+
insts() = [i for f in functions(mod) for bb in blocks(f) for i in instructions(bb)]
838+
839+
# precondition: exactly one aggregate-typed load (used by the two extractvalues)
840+
@test count(is_aggregate_load, insts()) == 1
841+
842+
@test GPUCompiler.split_aggregate_loads!(mod)
843+
844+
# the wide aggregate load and every extractvalue are gone, replaced by narrow loads
845+
@test count(is_aggregate_load, insts()) == 0
846+
@test !any(i -> i isa LLVM.ExtractValueInst, insts())
847+
loads = filter(i -> i isa LLVM.LoadInst, insts())
848+
@test length(loads) == 2
849+
@test Set(string(value_type(l)) for l in loads) == Set(["i64", "float"])
850+
# each field load is fed by an inbounds GEP off the original pointer
851+
@test all(l -> operands(l)[1] isa LLVM.GetElementPtrInst, loads)
852+
@test (verify(mod); true)
853+
end
854+
855+
# a load with a non-extractvalue use can't be fully eliminated, so it is left alone
856+
Context() do ctx
857+
ir = """
858+
define void @g({ i64, i64 }* %p, { i64, i64 }* %q, i64* %o) {
859+
entry:
860+
%agg = load { i64, i64 }, { i64, i64 }* %p, align 8
861+
%a = extractvalue { i64, i64 } %agg, 0
862+
store { i64, i64 } %agg, { i64, i64 }* %q, align 8
863+
store i64 %a, i64* %o, align 8
864+
ret void
865+
}
866+
"""
867+
mod = parse(LLVM.Module, ir)
868+
insts() = [i for f in functions(mod) for bb in blocks(f) for i in instructions(bb)]
869+
@test !GPUCompiler.split_aggregate_loads!(mod)
870+
@test count(is_aggregate_load, insts()) == 1
871+
@test (verify(mod); true)
872+
end
873+
end
874+
811875
end

0 commit comments

Comments
 (0)