Skip to content

Commit 0632a9c

Browse files
christiangnrdmaleadtclaude
authored
Metal: add sincos parameter attributes (#762)
Co-authored-by: Tim Besard <tim.besard@gmail.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent a912d8e commit 0632a9c

2 files changed

Lines changed: 68 additions & 12 deletions

File tree

src/metal.jl

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,50 +1706,68 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod
17061706
isdeclaration(f) || continue
17071707
fn = LLVM.name(f)
17081708

1709-
attrs = function_attributes(f)
1710-
function add_attributes(names...)
1709+
fn_attrs = function_attributes(f)
1710+
function add_fn_attributes(names...)
17111711
for name in names
17121712
if LLVM.version() >= v"16" && name in ["argmemonly", "inaccessiblememonly",
17131713
"inaccessiblemem_or_argmemonly",
17141714
"readnone", "readonly", "writeonly"]
17151715
# XXX: workaround for changes from https://reviews.llvm.org/D135780
17161716
continue
17171717
end
1718-
push!(attrs, EnumAttribute(name, 0))
1718+
push!(fn_attrs, EnumAttribute(name, 0))
1719+
end
1720+
changed = true
1721+
end
1722+
1723+
function add_param_attributes(idx, names...)
1724+
param_attrs = parameter_attributes(f, idx)
1725+
for name in names
1726+
if name == "nocapture" && LLVM.version() >= v"21"
1727+
# `nocapture` was replaced by `captures(none)` in LLVM 21 (an
1728+
# integer-valued IntAttr, value 0 == CaptureInfo::none()).
1729+
push!(param_attrs, EnumAttribute("captures", 0))
1730+
else
1731+
push!(param_attrs, EnumAttribute(name, 0))
1732+
end
17191733
end
17201734
changed = true
17211735
end
17221736

17231737
# synchronization
17241738
if fn == "air.wg.barrier" || fn == "air.simdgroup.barrier"
1725-
add_attributes("nounwind", "mustprogress", "convergent", "willreturn")
1739+
add_fn_attributes("nounwind", "mustprogress", "convergent", "willreturn")
1740+
1741+
# sincos
1742+
elseif match(r"^air.(fast_)?sincos", fn) !== nothing
1743+
add_param_attributes(2, "nocapture", "writeonly")
17261744

17271745
# atomics
17281746
elseif match(r"air.atomic.(local|global).load", fn) !== nothing
17291747
# TODO: "memory(argmem: read)" on LLVM 16+
1730-
add_attributes("argmemonly", "readonly", "nounwind")
1748+
add_fn_attributes("argmemonly", "readonly", "nounwind")
17311749
elseif match(r"air.atomic.(local|global).store", fn) !== nothing
17321750
# TODO: "memory(argmem: write)" on LLVM 16+
1733-
add_attributes("argmemonly", "writeonly", "nounwind")
1751+
add_fn_attributes("argmemonly", "writeonly", "nounwind")
17341752
elseif match(r"air.atomic.(local|global).(xchg|cmpxchg)", fn) !== nothing
17351753
# TODO: "memory(argmem: readwrite)" on LLVM 16+
1736-
add_attributes("argmemonly", "nounwind")
1754+
add_fn_attributes("argmemonly", "nounwind")
17371755
elseif match(r"^air.atomic.(local|global).(add|sub|min|max|and|or|xor)", fn) !== nothing
17381756
# TODO: "memory(argmem: readwrite)" on LLVM 16+
1739-
add_attributes("argmemonly", "nounwind")
1757+
add_fn_attributes("argmemonly", "nounwind")
17401758

17411759
# simdgroup
17421760
elseif match(r"air.simdgroup_matrix_8x8_multiply_accumulate", fn) !== nothing
1743-
add_attributes("convergent", "mustprogress", "nounwind", "willreturn")
1761+
add_fn_attributes("convergent", "mustprogress", "nounwind", "willreturn")
17441762
elseif match(r"air.simdgroup_matrix_8x8_load", fn) !== nothing
1745-
add_attributes("convergent", "mustprogress", "nofree", "nounwind", "readonly", "willreturn")
1763+
add_fn_attributes("convergent", "mustprogress", "nofree", "nounwind", "readonly", "willreturn")
17461764
elseif match(r"air.simdgroup_matrix_8x8_store", fn) !== nothing
1747-
add_attributes("convergent", "mustprogress", "nounwind", "willreturn", "writeonly")
1765+
add_fn_attributes("convergent", "mustprogress", "nounwind", "willreturn", "writeonly")
17481766

17491767
# simd permute
17501768
elseif match(r"air.simd_(ballot|all|vote_all|any|vote_any|shuffle|shuffle_xor|shuffle_down|\
17511769
shuffle_up|shuffle_and_fill_down|shuffle_and_fill_up)", fn) !== nothing
1752-
add_attributes("convergent", "mustprogress", "nounwind", "willreturn")
1770+
add_fn_attributes("convergent", "mustprogress", "nounwind", "willreturn")
17531771
end
17541772
end
17551773

test/metal.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ end
159159
end
160160
end
161161

162+
@testset "sincos intrinsics" begin
163+
mod = @eval module $(gensym())
164+
using LLVM.Interop
165+
166+
function kernel(x)
167+
c = Ref{Float32}()
168+
s = @typed_ccall("air.sincos.f32", llvmcall, Float32,
169+
(Float32, Ptr{Float32}), x, c)
170+
return s + c[]
171+
end
172+
173+
function kernel_fast(x)
174+
c = Ref{Float32}()
175+
s = @typed_ccall("air.fast_sincos.f32", llvmcall, Float32,
176+
(Float32, Ptr{Float32}), x, c)
177+
return s + c[]
178+
end
179+
end
180+
181+
# the output argument of sincos should be annotated `nocapture writeonly`
182+
# (`nocapture` having been replaced by `captures(none)` in LLVM 21)
183+
@test @filecheck begin
184+
@check "declare float @air.sincos.f32"
185+
@check_same cond=typed_ptrs "(float, float* nocapture writeonly)"
186+
@check_same cond=(opaque_ptrs && LLVM.version() < v"21") "(float, ptr nocapture writeonly)"
187+
@check_same cond=(LLVM.version() >= v"21") "(float, ptr writeonly captures(none))"
188+
Metal.code_llvm(mod.kernel, Tuple{Float32}; dump_module=true)
189+
end
190+
191+
@test @filecheck begin
192+
@check "declare float @air.fast_sincos.f32"
193+
@check_same cond=typed_ptrs "(float, float* nocapture writeonly)"
194+
@check_same cond=(opaque_ptrs && LLVM.version() < v"21") "(float, ptr nocapture writeonly)"
195+
@check_same cond=(LLVM.version() >= v"21") "(float, ptr writeonly captures(none))"
196+
Metal.code_llvm(mod.kernel_fast, Tuple{Float32}; dump_module=true)
197+
end
198+
end
199+
162200
@testset "unsupported type detection" begin
163201
mod = @eval module $(gensym())
164202
function kernel(ptr)

0 commit comments

Comments
 (0)