Skip to content

Commit 708587b

Browse files
authored
Merge pull request #210 from omlins/memopt_shmem
Improve shared memory handling and reduce launch overhead
2 parents f7cc251 + 8b9f4fd commit 708587b

7 files changed

Lines changed: 122 additions & 31 deletions

src/memopt.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
9090
loopstart = minimum(values(loopentrys))
9191
loopend = loopsize
9292
use_any_shmem = any(values(use_shmems))
93+
shmem_optvars = tuple((A for A in optvars if use_shmems[A])...)::Tuple{Vararg{Symbol}}
9394
shmem_index_groups = define_shmem_index_groups(hx1s, hy1s, hx2s, hy2s, optvars, use_shmems, loopdim)
9495
shmem_vars = define_shmem_vars(oz_maxs, hx1s, hy1s, hx2s, hy2s, optvars, indices, use_shmems, use_shmem_xs, use_shmem_ys, shmem_index_groups, use_shmemhalos, use_shmemindices, loopdim)
9596
shmem_exprs = define_shmem_exprs(shmem_vars, loopdim)
@@ -469,7 +470,7 @@ $(( # NOTE: the if statement is not needed here as we only deal with registers
469470
else
470471
@ArgumentError("memopt: only loopdim=3 is currently supported.")
471472
end
472-
store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, loopdim, loopsize, optranges, use_shmemhalos)
473+
store_metadata(metadata_module, is_parallel_kernel, caller, offset_mins, offset_maxs, offsets, optvars, shmem_optvars, use_any_shmem, loopdim, loopsize, optranges, use_shmemhalos)
473474
# @show QuoteNode(ParallelKernel.simplify_varnames!(ParallelKernel.remove_linenumbernodes!(deepcopy(body))))
474475
return body
475476
end
@@ -1019,10 +1020,15 @@ function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false)
10191020
end
10201021
end
10211022

1022-
function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos)
1023+
function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::NTuple{M,Symbol} where M, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos)
10231024
memopt = true
10241025
nonconst_metadata = get_nonconst_metadata(caller)
10251026
stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars)
1027+
use_shmemhalos = NamedTuple(A => use_shmemhalos[A] for A in optvars)
1028+
loopsizes = (loopdim==3) ? (1, 1, loopsize) : (loopdim==2) ? (1, loopsize, 1) : (loopsize, 1, 1)
1029+
shmem_dim1 = (loopdim==3) ? 1 : (loopdim==2) ? 1 : 2
1030+
shmem_dim2 = (loopdim==3) ? 2 : (loopdim==2) ? 3 : 3
1031+
shmem_spans = NamedTuple(A => (length(stencilranges[A][shmem_dim1]) - 1, length(stencilranges[A][shmem_dim2]) - 1) for A in optvars)
10261032
if nonconst_metadata
10271033
storeexpr = quote
10281034
is_parallel_kernel = $is_parallel_kernel
@@ -1031,9 +1037,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle
10311037
stencilranges = $stencilranges
10321038
offsets = $offsets
10331039
optvars = $optvars
1040+
shmem_optvars = $shmem_optvars
1041+
shmem_spans = $shmem_spans
10341042
loopdim = $loopdim
10351043
loopsize = $loopsize
1044+
loopsizes = $loopsizes
1045+
shmem_dim1 = $shmem_dim1
1046+
shmem_dim2 = $shmem_dim2
10361047
optranges = $optranges
1048+
use_any_shmem = $use_any_shmem
10371049
use_shmemhalos = $use_shmemhalos
10381050
end
10391051
else
@@ -1044,9 +1056,15 @@ function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, calle
10441056
const stencilranges = $stencilranges
10451057
const offsets = $offsets
10461058
const optvars = $optvars
1059+
const shmem_optvars = $shmem_optvars
1060+
const shmem_spans = $shmem_spans
10471061
const loopdim = $loopdim
10481062
const loopsize = $loopsize
1063+
const loopsizes = $loopsizes
1064+
const shmem_dim1 = $shmem_dim1
1065+
const shmem_dim2 = $shmem_dim2
10491066
const optranges = $optranges
1067+
const use_any_shmem = $use_any_shmem
10501068
const use_shmemhalos = $use_shmemhalos
10511069
end
10521070
end

src/parallel.jl

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,38 @@ end
361361

362362
## @PARALLEL CALL FUNCTIONS
363363

364+
@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
365+
terms = [:(
366+
(nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) *
367+
(nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) *
368+
sizeof(T)
369+
) for A in optvars]
370+
if isempty(terms)
371+
return :(0)
372+
elseif length(terms) == 1
373+
return terms[1]
374+
else
375+
return Expr(:call, :+, terms...)
376+
end
377+
end
378+
379+
@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim}
380+
if is_parallel_kernel
381+
range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...))
382+
else
383+
range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...))
384+
end
385+
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
386+
return quote
387+
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
388+
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
389+
if nb_dims_match isa Bool
390+
nb_dims_match || $errorcall
391+
end
392+
$range_expr
393+
end
394+
end
395+
364396
function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
365397
if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
366398
package = get_package(caller)
@@ -369,20 +401,25 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel
369401
configcall_kwarg_expr = :(configcall=$configcall)
370402
metadata_call = create_metadata_call(configcall)
371403
metadata_module = metadata_call
404+
loopdim = :($(metadata_module).loopdim)
405+
loopsizes = :($(metadata_module).loopsizes)
372406
stencilranges = :($(metadata_module).stencilranges)
373407
use_shmemhalos = :($(metadata_module).use_shmemhalos)
374-
optvars = :($(metadata_module).optvars)
375-
loopdim = :($(metadata_module).loopdim)
376-
loopsize = :($(metadata_module).loopsize)
377-
loopsizes = :(($loopdim==3) ? (1, 1, $loopsize) : ($loopdim==2) ? (1, $loopsize, 1) : ($loopsize, 1, 1))
408+
use_any_shmem = :($(metadata_module).use_any_shmem)
409+
shmem_dim1 = :($(metadata_module).shmem_dim1)
410+
shmem_dim2 = :($(metadata_module).shmem_dim2)
411+
shmem_optvars = :($(metadata_module).shmem_optvars)
412+
shmem_spans = :($(metadata_module).shmem_spans)
378413
maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes))
379414
nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) )
380415
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
381416
numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked
382-
dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2.
383-
dim2 = :(($loopdim==3) ? 2 : ($loopdim==2) ? 3 : 3) # TODO: to be determined if that is what is desired for loopdim 1 and 2.
384-
A = gensym("A")
385-
shmem = :(sum(($nthreads[$dim1]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim1])-1))*($nthreads[$dim2]+$use_shmemhalos[$A]*(length($(stencilranges)[$A][$dim2])-1))*sizeof($numbertype) for $A in $optvars))
417+
if get_nonconst_metadata(caller)
418+
A = gensym("A")
419+
shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0)
420+
else
421+
shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype))
422+
end
386423
if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs
387424
else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ...
388425
end
@@ -396,7 +433,12 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e
396433
metadata_module = metadata_call
397434
loopdim = :($(metadata_module).loopdim)
398435
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
399-
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
436+
if get_nonconst_metadata(caller)
437+
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
438+
else
439+
nb_parallel_indices = :($(metadata_module).nb_parallel_indices)
440+
ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...)))
441+
end
400442
parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
401443
end
402444

@@ -552,8 +594,10 @@ function create_metadata_function(kernel::Expr, metadata_module::Module) # NOTE:
552594
kernelname = get_name(kernel)
553595
functionname = get_meta_function(kernelname)
554596
metadata_function = set_name(metadata_function, functionname)
555-
set_body!(metadata_function, :(return $metadata_module))
556-
return metadata_function
597+
set_body!(metadata_function, quote
598+
return $metadata_module
599+
end)
600+
return :(@inline $metadata_function)
557601
end
558602

559603
function create_metadata_call(configcall::Expr)

test/test_FiniteDifferences1D.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ eval(:(
3636
$(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding.
3737
@testset "(padding=$__padding__)" begin
3838
@require !@is_initialized()
39-
@init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__)
39+
@init_parallel_stencil($package, $FloatDefault, 1, padding=__padding__, nonconst_metadata=true)
4040
@require @is_initialized()
4141
nx = (9,)
4242
A = @IField(nx, @rand);

test/test_FiniteDifferences2D.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ eval(:(
3636
$(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding.
3737
@testset "(padding=$__padding__)" begin
3838
@require !@is_initialized()
39-
@init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__)
39+
@init_parallel_stencil($package, $FloatDefault, 2, padding=__padding__, nonconst_metadata=true)
4040
@require @is_initialized()
4141
nxy = (9, 7)
4242
A = @IField(nxy, @rand);

test/test_FiniteDifferences3D.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ eval(:(
3636
$(interpolate(:__padding__, (false, package!=PKG_POLYESTER), :( #TODO: this needs to be restored to (false, true) when Polyester supports padding.
3737
@testset "(padding=$__padding__)" begin
3838
@require !@is_initialized()
39-
@init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__)
39+
@init_parallel_stencil($package, $FloatDefault, 3, padding=__padding__, nonconst_metadata=true)
4040
@require @is_initialized()
4141
nxyz = (9, 7, 8)
4242
A = @IField(nxyz, @rand)

test/test_kernel_language.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Base.retry_load_extensions()
3838
eval(:(
3939
@testset "$(basename(@__FILE__)) (package: $(nameof($package)))" begin
4040
@require !@is_initialized()
41-
@init_parallel_stencil($package, $FloatDefault, 3)
41+
@init_parallel_stencil($package, $FloatDefault, 3, nonconst_metadata=true)
4242
@require @is_initialized()
4343

4444
@testset "Pass-through macro mapping" begin

0 commit comments

Comments
 (0)