@@ -178,6 +178,79 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
178178 errors
179179end
180180
# aggregate load splitting (JuliaGPU/Metal.jl#792)
#
# Julia can emit a single by-value `load` of a large, deeply-nested aggregate (e.g. an
# Oceananigans `RectilinearGrid` passed by reference) that feeds several `extractvalue`s.
# Apple's AGX back-end crashes during native-code generation when lowering such a wide
# aggregate load. We rewrite each `extractvalue (load p), idxs` into a narrow field load
# `load (inbounds_gep p, 0, idxs)` and delete the now-dead wide load, so only per-field
# loads reach the back-end.
#
# This is LLVM's own `extractvalue (load)` -> `load (gep)` fold in InstCombine
# (visitExtractValueInst), with one difference: InstCombine guards it on the load having a
# single use, declining multiply-used loads as "a struct with padding [where] we don't want
# to do the transformation as it loses padding knowledge". That guard is a codegen
# heuristic (one wide load can be cheaper than N field loads), not a correctness condition,
# so dropping it is sound — and necessary here, since the crashing pattern is precisely a
# multiply-used aggregate load that InstCombine therefore leaves intact.
#
# Restricted to simple (non-volatile, non-atomic) loads all of whose users are
# `extractvalue` — the by-value-aggregate-argument pattern — so the wide load can be fully
# eliminated. When an `extractvalue` yields a sub-aggregate, its replacement is again an
# aggregate load; if that load also satisfies the restriction above it is split in turn.
#
# Alignment: the builder gives each field load its type's natural (ABI) alignment. That is
# only valid when the wide load was at least that aligned, so a field load is clamped down
# to the wide load's alignment whenever the latter is smaller (e.g. an `align 1` load).
# NOTE(review): in a packed struct a field offset need not be a multiple of the field's
# ABI alignment, so the clamp can still over-state alignment there — confirm that packed
# aggregates cannot reach this pass.
#
# AA metadata is copied from the wide load (sound for the narrower field loads it
# subsumes).

# Return `true` when `inst` is a load this pass can eliminate completely: a simple
# (non-volatile, non-atomic) load of a struct or array that has at least one user and whose
# users are all `extractvalue` instructions.
function is_splittable_aggregate_load(inst)
    inst isa LLVM.LoadInst || return false
    T = value_type(inst)
    (T isa LLVM.StructType || T isa LLVM.ArrayType) || return false
    iszero(LLVM.API.LLVMGetVolatile(inst)) || return false
    LLVM.API.LLVMGetOrdering(inst) == LLVM.API.LLVMAtomicOrderingNotAtomic || return false
    uselist = collect(uses(inst))
    isempty(uselist) && return false
    return all(u -> user(u) isa LLVM.ExtractValueInst, uselist)
end

# Replace every `extractvalue` user of the wide load `ld` by a narrow load through an
# inbounds GEP, then erase `ld`. Replacement loads that are themselves splittable
# aggregate loads are pushed onto `worklist` for the caller to process.
function split_one_aggregate_load!(builder, ld::LLVM.LoadInst, worklist, aa_kinds, i32)
    ptr = operands(ld)[1]
    aggty = value_type(ld)
    md = metadata(ld)
    base_align = LLVM.API.LLVMGetAlignment(ld)

    # build the field loads at the wide load's location, not the extractvalue's
    position!(builder, ld)

    # snapshot the use list: it is mutated as the extractvalues are erased
    for u in collect(uses(ld))
        ev = user(u)::LLVM.ExtractValueInst
        n = LLVM.API.LLVMGetNumIndices(ev)
        idxptr = LLVM.API.LLVMGetIndices(ev)

        # extractvalue has integer indices; getelementptr takes Values, prefixed
        # with an i32 0 to step through the pointer to the aggregate's first element.
        gepidx = LLVM.Value[ConstantInt(i32, 0)]
        for k in 1:n
            push!(gepidx, ConstantInt(i32, unsafe_load(idxptr, k)))
        end

        gep = inbounds_gep!(builder, aggty, ptr, gepidx)
        fieldload = load!(builder, value_type(ev), gep)

        # never claim more alignment for a field than the wide load guaranteed for the
        # whole aggregate; only lower, and leave the load alone if no alignment is recorded
        if 0 < base_align < LLVM.API.LLVMGetAlignment(fieldload)
            LLVM.API.LLVMSetAlignment(fieldload, base_align)
        end

        for kind in aa_kinds
            haskey(md, kind) && (metadata(fieldload)[kind] = md[kind])
        end

        replace_uses!(ev, fieldload)
        erase!(ev)

        # a sub-aggregate extract turns into another aggregate load; split it as well
        # (checked after `replace_uses!` so the new load's users are in place)
        is_splittable_aggregate_load(fieldload) && push!(worklist, fieldload)
    end

    erase!(ld)
    return nothing
end

# Split every eligible aggregate load in every function definition of `mod` into per-field
# loads. Returns `true` if at least one wide load was rewritten, `false` otherwise.
function split_aggregate_loads!(mod::LLVM.Module)
    aa_kinds = (LLVM.MD_tbaa, LLVM.MD_tbaa_struct, LLVM.MD_alias_scope, LLVM.MD_noalias)
    i32 = LLVM.Int32Type()
    changed = false
    @dispose builder=IRBuilder() begin
        for f in functions(mod)
            isdeclaration(f) && continue

            # collect candidates first so the rewrite never mutates a block while it is
            # being iterated
            worklist = LLVM.LoadInst[]
            for bb in blocks(f), inst in instructions(bb)
                is_splittable_aggregate_load(inst) && push!(worklist, inst)
            end

            # terminates: every re-queued load has a strictly less nested type
            while !isempty(worklist)
                split_one_aggregate_load!(builder, pop!(worklist), worklist, aa_kinds, i32)
                changed = true
            end
        end
    end
    return changed
end
253+
181254function finish_ir! (@nospecialize (job:: CompilerJob{MetalCompilerTarget} ), mod:: LLVM.Module ,
182255 entry:: LLVM.Function )
183256 entry_fn = LLVM. name (entry)
@@ -197,6 +270,10 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
197270 # reporters) load from the constant space rather than crashing Metal's validator.
198271 propagate_argument_address_spaces! (mod)
199272
273+ # split multiply-used by-value aggregate loads into narrow per-field loads; the AGX
274+ # back-end crashes during native codegen on wide aggregate loads (#792).
275+ split_aggregate_loads! (mod)
276+
200277 # propagate specific address spaces through addrspacecast chains introduced
201278 # by the rewrites above, so that loads/stores happen in the right address
202279 # space (e.g. constant globals in addrspace 2 rather than via a cast to 0,
0 commit comments