Skip to content

Commit 8adcb92

Browse files
authored
Merge pull request #836 from JuliaGPU/tb/ptx_llvm22
PTX: modernize compilation now that LLVM 22 is used
2 parents 0632a9c + 5fe5a1b commit 8adcb92

5 files changed

Lines changed: 169 additions & 343 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GPUCompiler"
22
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
3-
version = "1.18.0"
3+
version = "1.19.0"
44
authors = ["Tim Besard <tim.besard@gmail.com>"]
55

66
[workspace]

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
296296
# Does the target need to pass kernel arguments by value?
297297
pass_by_value(@nospecialize(job::CompilerJob)) = true
298298

299-
# Should the target use byref instead of byval+lower_byval for kernel arguments?
299+
# Should the target use byref instead of byval for kernel arguments?
300300
# When true, aggregate arguments are passed as pointers with the byref attribute,
301301
# allowing the backend to load fields directly from the argument memory (e.g. kernarg
302302
# segment on AMDGPU) instead of materializing the entire struct via first-class aggregates.

src/irgen.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -660,24 +660,53 @@ end
660660

661661
# byval lowering
662662
#
663-
# some back-ends don't support byval, or support it badly, so lower it eagerly ourselves
664-
# https://reviews.llvm.org/D79744
663+
# the NVPTX back-end accesses byval kernel arguments directly in parameter space when all
664+
# uses are simple loads (possibly via GEPs), but otherwise falls back to copying the
665+
# argument to local memory during machine-code generation, too late for the optimizer to
666+
# clean up. eagerly materialize such arguments ourselves instead, leaving the directly-
667+
# accessible ones to the back-end. (we historically lowered all byval arguments, working
668+
# around bad codegen in old back-ends; see #92 and https://reviews.llvm.org/D79744.)
669+
670+
# can the back-end service all uses of this argument straight from parameter space?
671+
function loads_parameter_directly(param::LLVM.Argument)
    # pointers derived from the argument whose users still have to be inspected;
    # seeded with the argument itself and processed in FIFO order
    pending = LLVM.Value[param]
    while !isempty(pending)
        pointer = popfirst!(pending)
        for use in uses(pointer)
            inst = user(use)

            # a load is serviced by ld.param; nothing further to follow
            inst isa LLVM.LoadInst && continue

            # GEPs and casts only yield another pointer into the argument, so
            # their own users need to be checked as well
            forwards_pointer = inst isa LLVM.GetElementPtrInst ||
                               inst isa LLVM.BitCastInst ||
                               inst isa LLVM.AddrSpaceCastInst

            # anything else (phis, selects, stores, calls such as memcpy, etc.)
            # makes the back-end fall back to a local copy
            forwards_pointer || return false

            push!(pending, inst)
        end
    end

    # every transitive use turned out to be a plain load
    return true
end
692+
665693
function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
666694
ft = function_type(f)
667695
@tracepoint "lower byval" begin
668696

669-
# find the byval parameters
697+
# find the byval parameters that need lowering
670698
byval = BitVector(undef, length(parameters(ft)))
671699
types = Vector{LLVMType}(undef, length(parameters(ft)))
672-
for i in 1:length(byval)
700+
for (i, param) in enumerate(parameters(f))
673701
byval[i] = false
674702
for attr in collect(parameter_attributes(f, i))
675703
if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType()))
676-
byval[i] = true
704+
byval[i] = !loads_parameter_directly(param)
677705
types[i] = value(attr)
678706
end
679707
end
680708
end
709+
any(byval) || return f
681710

682711
# fixup metadata
683712
#
@@ -778,6 +807,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
778807
end
779808

780809

810+
781811
# kernel state arguments
782812
#
783813
# to facilitate passing stateful information to kernels without having to recompile, e.g.,
@@ -1183,7 +1213,7 @@ end
11831213
#
11841214
# the kernel state argument is always passed by value to avoid codegen issues with byval.
11851215
# some back-ends however do not support passing kernel arguments by value, so this pass
1186-
# serves to convert that argument (and is conceptually the inverse of `lower_byval`).
1216+
# serves to convert that argument.
11871217
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
11881218
f::LLVM.Function)
11891219
ft = function_type(f)

0 commit comments

Comments
 (0)