@@ -660,24 +660,53 @@ end
660660
661661# byval lowering
662662#
663- # some back-ends don't support byval, or support it badly, so lower it eagerly ourselves
664- # https://reviews.llvm.org/D79744
663+ # the NVPTX back-end accesses byval kernel arguments directly in parameter space when all
664+ # uses are simple loads (possibly via GEPs), but otherwise falls back to copying the
665+ # argument to local memory during machine-code generation, too late for the optimizer to
666+ # clean up. eagerly materialize such arguments ourselves instead, leaving the directly-
667+ # accessible ones to the back-end. (we historically lowered all byval arguments, working
668+ # around bad codegen in old back-ends; see #92 and https://reviews.llvm.org/D79744.)
669+
670+ # can the back-end service all uses of this argument straight from parameter space?
671+ function loads_parameter_directly (param:: LLVM.Argument )
672+ worklist = LLVM. Value[param]
673+ while ! isempty (worklist)
674+ value = popfirst! (worklist)
675+ for use in uses (value)
676+ inst = user (use)
677+ if inst isa LLVM. LoadInst
678+ # serviced by ld.param
679+ elseif inst isa LLVM. GetElementPtrInst ||
680+ inst isa LLVM. BitCastInst ||
681+ inst isa LLVM. AddrSpaceCastInst
682+ push! (worklist, inst)
683+ else
684+ # phis, selects, stores, calls (e.g. memcpy), etc. make the back-end
685+ # fall back to a local copy
686+ return false
687+ end
688+ end
689+ end
690+ return true
691+ end
692+
665693function lower_byval (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module , f:: LLVM.Function )
666694 ft = function_type (f)
667695 @tracepoint " lower byval" begin
668696
669- # find the byval parameters
697+ # find the byval parameters that need lowering
670698 byval = BitVector (undef, length (parameters (ft)))
671699 types = Vector {LLVMType} (undef, length (parameters (ft)))
672- for i in 1 : length (byval )
700+ for (i, param) in enumerate ( parameters (f) )
673701 byval[i] = false
674702 for attr in collect (parameter_attributes (f, i))
675703 if kind (attr) == kind (TypeAttribute (" byval" , LLVM. VoidType ()))
676- byval[i] = true
704+ byval[i] = ! loads_parameter_directly (param)
677705 types[i] = value (attr)
678706 end
679707 end
680708 end
709+ any (byval) || return f
681710
682711 # fixup metadata
683712 #
@@ -778,6 +807,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
778807end
779808
780809
810+
781811# kernel state arguments
782812#
783813# to facilitate passing stateful information to kernels without having to recompile, e.g.,
@@ -1183,7 +1213,7 @@ end
11831213#
11841214# the kernel state argument is always passed by value to avoid codegen issues with byval.
11851215# some back-ends however do not support passing kernel arguments by value, so this pass
1186- # serves to convert that argument (and is conceptually the inverse of `lower_byval`) .
1216+ # serves to convert that argument.
11871217function kernel_state_to_reference! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
11881218 f:: LLVM.Function )
11891219 ft = function_type (f)
0 commit comments