Skip to content

Commit 6ebe602

Browse files
committed
Refactor parallel memory optimization functions and update documentation for kernel call handling
1 parent 2833f2c commit 6ebe602

1 file changed

Lines changed: 11 additions & 13 deletions

File tree

src/parallel.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Declare the `kernelcall` parallel. The kernel will automatically be called as re
3838
Automatic computation of `ranges` for `@parallel <kernelcall>` is only possible if the number of parallel indices used by the kernel is equal to the number of dimensions of the highest-dimensional input arrays. Otherwise, specify the `ranges` manually with `@parallel ranges=... <kernelcall>`.
3939
4040
!!! note "Runtime hardware selection"
41-
When KernelAbstractions is initialized, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
41+
When KernelAbstractions is chosen as the package for parallelization, this wrapper consults [`current_hardware`](@ref) to determine the runtime hardware target. The symbol defaults to `:cpu` and can be switched to select other targets via [`select_hardware`](@ref).
4242
4343
# Arguments
4444
- `kernelcall`: a call to a kernel that is declared parallel.
@@ -361,15 +361,6 @@ end
361361

362362
## @PARALLEL CALL FUNCTIONS
363363

364-
function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString)
365-
if isa(arg, Symbol)
366-
return Expr[], arg
367-
else
368-
arg_var = gensym(prefix)
369-
return [:(local $arg_var = $arg)], arg_var
370-
end
371-
end
372-
373364
function parallel_call_memopt(caller::Module, metadata_expr::Union{Symbol,Expr}, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
374365
if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
375366
package = get_package(caller)
@@ -450,7 +441,7 @@ function compute_loopsize(package::Symbol)
450441
end
451442

452443

453-
## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS
444+
## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS, SHARED MEMORY SIZE AND RANGES
454445

455446
function compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, loopdim, stencilranges) # This is a heuristic, which results typcially in (32,4,1) threads for a 3-D case.
456447
maxsize = promote_maxsize(maxsize)
@@ -478,8 +469,6 @@ function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args...
478469
end
479470

480471

481-
## FUNCTIONS TO COMPUTE SHARED MEMORY SIZE AND RANGES FOR MEMOPT
482-
483472
@generated function compute_memopt_shmem(::Val{optvars}, ::Val{use_shmemhalos}, ::Val{shmem_spans}, ::Val{shmem_dim1}, ::Val{shmem_dim2}, nthreads, ::Type{T}) where {optvars, use_shmemhalos, shmem_spans, shmem_dim1, shmem_dim2, T}
484473
terms = [:(
485474
(nthreads[$shmem_dim1] + $(getproperty(use_shmemhalos, A)) * $(getproperty(shmem_spans, A)[1])) *
@@ -540,6 +529,15 @@ end
540529
end
541530
end
542531

532+
function precompute_parallel_memopt_arg(arg::Union{Symbol,Expr}, prefix::AbstractString)
533+
if isa(arg, Symbol)
534+
return Expr[], arg
535+
else
536+
arg_var = gensym(prefix)
537+
return [:(local $arg_var = $arg)], arg_var
538+
end
539+
end
540+
543541

544542
## FUNCTIONS TO DEAL WITH MASKS (@WITHIN) AND INDICES
545543

0 commit comments

Comments
 (0)