Skip to content

Commit 8d86015

Browse files
authored
Merge pull request #211 from omlins/memoptcall
Precompute all memopt args
2 parents 708587b + 2f0587e commit 8d86015

3 files changed

Lines changed: 126 additions & 78 deletions

File tree

src/ParallelKernel/parallel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
1515
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.
1616
1717
!!! note "Runtime hardware selection"
18-
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
18+
When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
1919
2020
# Arguments
2121
- `kernelcall`: a call to a kernel that is declared parallel.

src/parallel.jl

Lines changed: 108 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
3838
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.
3939
4040
!!! note "Runtime hardware selection"
41-
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
41+
When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
4242
4343
# Arguments
4444
- `kernelcall`: a call to a kernel that is declared parallel.
@@ -191,7 +191,7 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
191191
parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
192192
else
193193
if isempty(posargs)
194-
ranges = add_nb_parallel_indices_check(:(ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
194+
ranges = :(ParallelStencil.compute_parallel_ranges(Val(($(create_metadata_call(configcall))).nb_parallel_indices), $(configcall.args[2:end]...)))
195195
ParallelKernel.parallel(caller, ranges, backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
196196
else
197197
ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
@@ -361,67 +361,42 @@ end
361361

362362
## @PARALLEL CALL FUNCTIONS
363363

364-
@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
365-
terms = [:(
366-
(nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) *
367-
(nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) *
368-
sizeof(T)
369-
) for A in optvars]
370-
if isempty(terms)
371-
return :(0)
372-
elseif length(terms) == 1
373-
return terms[1]
374-
else
375-
return Expr(:call, :+, terms...)
376-
end
377-
end
378-
379-
@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim}
380-
if is_parallel_kernel
381-
range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...))
382-
else
383-
range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...))
384-
end
385-
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
386-
return quote
387-
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
388-
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
389-
if nb_dims_match isa Bool
390-
nb_dims_match || $errorcall
391-
end
392-
$range_expr
393-
end
394-
end
395-
396-
function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
364+
function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
397365
if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
398366
package = get_package(caller)
399367
nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
400368
nthreads_max_memopt = determine_nthreads_max_memopt(package)
401369
configcall_kwarg_expr = :(configcall=$configcall)
402-
metadata_call = create_metadata_call(configcall)
403-
metadata_module = metadata_call
404-
loopdim = :($(metadata_module).loopdim)
405-
loopsizes = :($(metadata_module).loopsizes)
406-
stencilranges = :($(metadata_module).stencilranges)
407-
use_shmemhalos = :($(metadata_module).use_shmemhalos)
408-
use_any_shmem = :($(metadata_module).use_any_shmem)
409-
shmem_dim1 = :($(metadata_module).shmem_dim1)
410-
shmem_dim2 = :($(metadata_module).shmem_dim2)
411-
shmem_optvars = :($(metadata_module).shmem_optvars)
412-
shmem_spans = :($(metadata_module).shmem_spans)
413-
maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes))
414-
nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) )
415-
nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
416370
numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain number type properly for each array: the type of the call call arguments corresponding to the optimization variables should be checked
417-
if get_nonconst_metadata(caller)
418-
A = gensym("A")
419-
shmem = :($use_any_shmem ? sum(($nthreads[$shmem_dim1] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][1])) * ($nthreads[$shmem_dim2] + $use_shmemhalos[$A] * ($(shmem_spans)[$A][2])) * sizeof($numbertype) for $A in $shmem_optvars) : 0)
371+
nblocks_var = gensym("nblocks")
372+
nthreads_var = gensym("nthreads")
373+
shmem_var = gensym("shmem")
374+
range_setup_exprs, range_arg = precompute_parallel_memopt_arg(ranges, "ranges")
375+
nthreads_nblocks_expr = :(ParallelStencil.compute_memopt_nthreads_nblocks(Val($metadata_expr.loopsizes), Val($metadata_expr.loopdim), Val($metadata_expr.stencilranges), $nthreads_x_max, $nthreads_max_memopt, $range_arg))
376+
shmem_expr = :(ParallelStencil.compute_memopt_shmem(Val($metadata_expr.shmem_optvars), Val($metadata_expr.use_shmemhalos), Val($metadata_expr.shmem_spans), Val($metadata_expr.shmem_dim1), Val($metadata_expr.shmem_dim2), $nthreads_var, $numbertype))
377+
if async
378+
return quote
379+
$(range_setup_exprs...)
380+
local $nblocks_var, $nthreads_var = $nthreads_nblocks_expr
381+
local $shmem_var = $shmem_expr
382+
@parallel_async memopt=false $configcall_kwarg_expr $range_arg $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall
383+
end
420384
else
421-
shmem = :(ParallelStencil.compute_memopt_shmem(Val($shmem_optvars), Val($use_shmemhalos), Val($shmem_spans), Val($shmem_dim1), Val($shmem_dim2), $nthreads, $numbertype))
385+
return quote
386+
$(range_setup_exprs...)
387+
local $nblocks_var, $nthreads_var = $nthreads_nblocks_expr
388+
local $shmem_var = $shmem_expr
389+
@parallel memopt=false $configcall_kwarg_expr $range_arg $nblocks_var $nthreads_var shmem=$shmem_var $(backend_kwargs_expr...) $kernelcall
390+
end
422391
end
423-
if (async) return :(@parallel_async memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: the package and numbertype will have to be passed here further once supported as kwargs
424-
else return :(@parallel memopt=false $configcall_kwarg_expr $ranges $nblocks $nthreads shmem=$shmem $(backend_kwargs_expr...) $kernelcall) #TODO: ...
392+
end
393+
394+
function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
395+
metadata_call = create_metadata_call(configcall)
396+
metadata_var = gensym("metadata")
397+
quote
398+
local $metadata_var = $metadata_call
399+
$(parallel_call_memopt(caller, metadata_var, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall))
425400
end
426401
end
427402

@@ -430,16 +405,13 @@ function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_e
430405
nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
431406
nthreads_max_memopt = determine_nthreads_max_memopt(package)
432407
metadata_call = create_metadata_call(configcall)
433-
metadata_module = metadata_call
434-
loopdim = :($(metadata_module).loopdim)
435-
is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
436-
if get_nonconst_metadata(caller)
437-
ranges = add_nb_parallel_indices_check(:( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...))), configcall)
438-
else
439-
nb_parallel_indices = :($(metadata_module).nb_parallel_indices)
440-
ranges = :(ParallelStencil.compute_memopt_ranges(Val($is_parallel_kernel), Val($nb_parallel_indices), Val($loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...)))
408+
metadata_var = gensym("metadata")
409+
ranges_var = gensym("ranges")
410+
quote
411+
local $metadata_var = $metadata_call
412+
local $ranges_var = ParallelStencil.compute_memopt_ranges(Val($metadata_var.is_parallel_kernel), Val($metadata_var.nb_parallel_indices), Val($metadata_var.loopdim), $nthreads_x_max, $nthreads_max_memopt, $(configcall.args[2:end]...))
413+
$(parallel_call_memopt(caller, metadata_var, ranges_var, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall))
441414
end
442-
parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
443415
end
444416

445417

@@ -469,7 +441,7 @@ function compute_loopsize(package::Symbol)
469441
end
470442

471443

472-
## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS
444+
## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS, SHARED MEMORY SIZE AND RANGES
473445

474446
function compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, loopdim, stencilranges) # This is a heuristic, which results typcially in (32,4,1) threads for a 3-D case.
475447
maxsize = promote_maxsize(maxsize)
@@ -497,6 +469,76 @@ function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args...
497469
end
498470

499471

472+
@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
473+
terms = [:(
474+
(nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) *
475+
(nthreads[$shmem_dim2] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[2])) *
476+
sizeof(T)
477+
) for A in optvars]
478+
if isempty(terms)
479+
return :(0)
480+
elseif length(terms) == 1
481+
return terms[1]
482+
else
483+
return Expr(:call, :+, terms...)
484+
end
485+
end
486+
487+
@generated function compute_memopt_nthreads_nblocks(::Val{loopsizes}, ::Val{loopdim}, ::Val{stencilranges}, nthreads_x_max, nthreads_max_memopt, ranges) where {loopsizes, loopdim, stencilranges}
488+
return quote
489+
maxsize = cld.(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), $loopsizes)
490+
nthreads = ParallelStencil.compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, $loopdim, $stencilranges)
491+
nblocks = ParallelStencil.ParallelKernel.compute_nblocks(maxsize, nthreads)
492+
(nblocks, nthreads)
493+
end
494+
end
495+
496+
@generated function check_nb_parallel_indices(::Val{nb_parallel_indices}, args...) where {nb_parallel_indices}
497+
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
498+
return quote
499+
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
500+
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
501+
if nb_dims_match isa Bool
502+
nb_dims_match || $errorcall
503+
end
504+
nothing
505+
end
506+
end
507+
508+
@generated function compute_parallel_ranges(::Val{nb_parallel_indices}, args...) where {nb_parallel_indices}
509+
return quote
510+
ParallelStencil.check_nb_parallel_indices(Val($nb_parallel_indices), args...)
511+
ParallelStencil.ParallelKernel.get_ranges(args...)
512+
end
513+
end
514+
515+
@generated function compute_memopt_ranges(::Val{is_parallel_kernel}, ::Val{nb_parallel_indices}, ::Val{loopdim}, nthreads_x_max, nthreads_max_memopt, args...) where {is_parallel_kernel, nb_parallel_indices, loopdim}
516+
if is_parallel_kernel
517+
range_expr = :(ParallelStencil.get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, $loopdim, args...))
518+
else
519+
range_expr = :(ParallelStencil.ParallelKernel.get_ranges(args...))
520+
end
521+
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
522+
return quote
523+
nb_input_dims = ParallelStencil.get_nb_input_dims(args...)
524+
nb_dims_match = (nb_input_dims == $nb_parallel_indices)
525+
if nb_dims_match isa Bool
526+
nb_dims_match || $errorcall
527+
end
528+
$range_expr
529+
end
530+
end
531+
532+
function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString)
533+
if isa(arg, Symbol)
534+
return Expr[], arg
535+
else
536+
arg_var = gensym(prefix)
537+
return [:(local $arg_var = $arg)], arg_var
538+
end
539+
end
540+
541+
500542
## FUNCTIONS TO DEAL WITH MASKS (@WITHIN) AND INDICES
501543

502544
is_splatarg(x) = isa(x,Expr) && (x.head == :...)
@@ -705,10 +747,7 @@ end
705747

706748
function add_nb_parallel_indices_check(ranges::Union{Symbol,Expr}, configcall::Expr)
707749
metadata_call = create_metadata_call(configcall)
708-
nb_parallel_indices = :(($metadata_call).nb_parallel_indices)
709-
nb_input_dims = :(ParallelStencil.get_nb_input_dims($(configcall.args[2:end]...)))
710-
errorcall = :(ParallelStencil.@ArgumentError(ParallelStencil.ERRMSG_AUTOMATIC_RANGES_PARALLEL))
711-
return :(($nb_input_dims != $nb_parallel_indices && $errorcall; $ranges))
750+
return :(ParallelStencil.check_nb_parallel_indices(Val(($metadata_call).nb_parallel_indices), $(configcall.args[2:end]...)); $ranges)
712751
end
713752

714753
get_nb_input_dims(args...) = maximum((get_nb_input_dims(arg) for arg in args); init=1)

0 commit comments

Comments
 (0)