Skip to content

Commit f42d1cf

Browse files
maleadt and claude authored
Metal: Split pointer conversion from IPO AS inference pass (#827)
Also create a shared helper for a common rewriting operation.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 8bd3aa7 commit f42d1cf

6 files changed

Lines changed: 322 additions & 422 deletions

File tree

src/gcn.jl

Lines changed: 22 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -115,74 +115,32 @@ function add_kernarg_address_spaces!(
115115
end
116116
needs_rewrite || return f
117117

118-
# generate the new function type with constant address space on byref params
119-
new_types = LLVMType[]
120-
for (i, param) in enumerate(parameters(ft))
121-
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
122-
if supports_typed_pointers(context())
123-
push!(new_types, LLVM.PointerType(eltype(param), #=constant=# 4))
124-
else
125-
push!(new_types, LLVM.PointerType(#=constant=# 4))
126-
end
127-
else
128-
push!(new_types, param)
129-
end
130-
end
131-
new_ft = LLVM.FunctionType(return_type(ft), new_types)
132-
new_f = LLVM.Function(mod, "", new_ft)
133-
linkage!(new_f, linkage(f))
134-
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
135-
LLVM.name!(new_arg, LLVM.name(arg))
136-
end
137-
138-
# insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR
139-
# (which expects flat pointers) continues to work. The AMDGPU backend's
140-
# AMDGPULowerKernelArguments traces these casts and produces s_load.
141-
new_args = LLVM.Value[]
142-
@dispose builder=IRBuilder() begin
143-
entry_bb = BasicBlock(new_f, "conversion")
144-
position!(builder, entry_bb)
145-
146-
for (i, param) in enumerate(parameters(ft))
147-
if byref_mask[i] && param isa LLVM.PointerType && addrspace(param) == 0
148-
cast = addrspacecast!(builder, parameters(new_f)[i], param)
149-
push!(new_args, cast)
150-
else
151-
push!(new_args, parameters(new_f)[i])
152-
end
153-
end
154-
155-
# clone the original function body
156-
value_map = Dict{LLVM.Value, LLVM.Value}(
157-
param => new_args[i] for (i, param) in enumerate(parameters(f))
158-
)
159-
value_map[f] = new_f
160-
clone_into!(
161-
new_f, f; value_map,
162-
changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
163-
)
164-
165-
# fall through from conversion block to cloned entry
166-
br!(builder, blocks(new_f)[2])
167-
end
168-
169-
# copy parameter attributes AFTER clone_into!, because CloneFunctionInto
170-
# overwrites all attributes via setAttributes. For byref params, the VMap
171-
# maps old args to addrspacecast instructions (not Arguments), so LLVM's
172-
# attribute remapping silently drops them. We must re-add them here.
173-
for i in 1:length(parameters(ft))
118+
# generate the new function type with constant address space on byref flat-pointer params
119+
param_types = parameters(ft)
120+
flat_byref(i) = byref_mask[i] && param_types[i] isa LLVM.PointerType && addrspace(param_types[i]) == 0
121+
new_types = Union{Nothing,LLVMType}[
122+
flat_byref(i) ? (supports_typed_pointers(context()) ?
123+
LLVM.PointerType(eltype(param_types[i]), #=constant=# 4) :
124+
LLVM.PointerType(#=constant=# 4)) :
125+
nothing
126+
for i in 1:length(param_types)]
127+
128+
# insert addrspacecasts from kernarg (4) back to flat (0) so that the cloned IR (which expects
129+
# flat pointers) continues to work; the AMDGPU backend's AMDGPULowerKernelArguments traces these
130+
# casts and produces s_load.
131+
new_f = clone_with_converted_args!(mod, f, new_types,
132+
(builder, param, i) -> addrspacecast!(builder, param, param_types[i]))
133+
134+
# copy parameter attributes AFTER clone_into!, because CloneFunctionInto overwrites all
135+
# attributes via setAttributes. For byref params, the VMap maps old args to addrspacecast
136+
# instructions (not Arguments), so LLVM's attribute remapping silently drops them.
137+
for i in 1:length(param_types)
174138
for attr in collect(parameter_attributes(f, i))
175139
push!(parameter_attributes(new_f, i), attr)
176140
end
177141
end
178142

179-
# replace the old function
180-
fn = LLVM.name(f)
181-
prune_constexpr_uses!(f)
182-
@assert isempty(uses(f))
183-
replace_metadata_uses!(f, new_f)
184-
erase!(f)
185-
LLVM.name!(new_f, fn)
143+
replace_function!(f, new_f)
186144

187145
# clean up the extra conversion block
188146
@dispose pb=NewPMPassBuilder() begin
@@ -192,7 +150,7 @@ function add_kernarg_address_spaces!(
192150
run!(pb, mod)
193151
end
194152

195-
return functions(mod)[fn]
153+
return new_f
196154
end
197155

198156

src/irgen.jl

Lines changed: 6 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1202,52 +1202,12 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
12021202
end
12031203

12041204
@tracepoint "kernel state to reference" begin
1205-
# generate the new function type & definition
1206-
new_types = LLVM.LLVMType[]
1207-
# convert the first parameter (kernel state) to a pointer
1208-
push!(new_types, LLVM.PointerType(T_state))
1209-
# keep all other parameters as-is
1210-
for i in 2:length(parameters(ft))
1211-
push!(new_types, parameters(ft)[i])
1212-
end
1213-
1214-
new_ft = LLVM.FunctionType(return_type(ft), new_types)
1215-
new_f = LLVM.Function(mod, "", new_ft)
1216-
linkage!(new_f, linkage(f))
1217-
1218-
# name the parameters
1205+
# turn the leading kernel-state value parameter into a pointer the body loads from
1206+
new_types = Union{Nothing,LLVM.LLVMType}[
1207+
i == 1 ? LLVM.PointerType(T_state) : nothing for i in 1:length(parameters(ft))]
1208+
new_f = clone_with_converted_args!(mod, f, new_types,
1209+
(builder, param, i) -> load!(builder, T_state, param, "state"))
12191210
LLVM.name!(parameters(new_f)[1], "state_ptr")
1220-
for (i, (arg, new_arg)) in enumerate(zip(parameters(f)[2:end], parameters(new_f)[2:end]))
1221-
LLVM.name!(new_arg, LLVM.name(arg))
1222-
end
1223-
1224-
# emit IR performing the "conversions"
1225-
new_args = LLVM.Value[]
1226-
@dispose builder=IRBuilder() begin
1227-
entry = BasicBlock(new_f, "conversion")
1228-
position!(builder, entry)
1229-
1230-
# load the kernel state value from the pointer
1231-
state_val = load!(builder, T_state, parameters(new_f)[1], "state")
1232-
push!(new_args, state_val)
1233-
1234-
# all other arguments are passed through directly
1235-
for i in 2:length(parameters(new_f))
1236-
push!(new_args, parameters(new_f)[i])
1237-
end
1238-
1239-
# map the arguments
1240-
value_map = Dict{LLVM.Value, LLVM.Value}(
1241-
param => new_args[i] for (i,param) in enumerate(parameters(f))
1242-
)
1243-
value_map[f] = new_f
1244-
1245-
clone_into!(new_f, f; value_map,
1246-
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
1247-
1248-
# fall through
1249-
br!(builder, blocks(new_f)[2])
1250-
end
12511211

12521212
# set the attributes for the state pointer parameter
12531213
attrs = parameter_attributes(new_f, 1)
@@ -1262,11 +1222,7 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
12621222
push!(attrs, EnumAttribute("readonly", 0))
12631223

12641224
# remove the old function
1265-
fn = LLVM.name(f)
1266-
@assert isempty(uses(f))
1267-
replace_metadata_uses!(f, new_f)
1268-
erase!(f)
1269-
LLVM.name!(new_f, fn)
1225+
replace_function!(f, new_f)
12701226

12711227
# minimal optimization
12721228
@dispose pb=NewPMPassBuilder() begin

0 commit comments

Comments
 (0)