Skip to content

Commit a0b76d2

Browse files
authored
Merge pull request #212 from omlins/memopt2d1
2D optimizations
2 parents 8d86015 + 76cf269 commit a0b76d2

5 files changed

Lines changed: 572 additions & 71 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Automatic advanced fast memory usage optimization (of shared memory and register
6060
#(...)
6161
end
6262
#(...)
63-
@parallel memopt=true diffusion3D_step!(...)
63+
@parallel diffusion3D_step!(...)
6464
```
6565
Note that arrays are automatically allocated on the hardware chosen for the computations (GPU or CPU) when using the provided allocation macros:
6666
- `@zeros`

src/ParallelKernel/parallel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ function create_gpu_or_xpu_call(package::Symbol, nblocks::Union{Symbol,Expr}, nt
891891
elseif (package == PKG_KERNELABSTRACTIONS) shmem_expr = nothing # KernelAbstractions does not accept dynamic shared-memory sizes here.
892892
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
893893
end
894-
if package != PKG_METAL
894+
if !isnothing(shmem_expr)
895895
backend_kwargs_expr = (backend_kwargs_expr..., shmem_expr)
896896
end
897897
end

src/memopt.jl

Lines changed: 178 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
6060
use_shmemhalos = isnothing(use_shmemhalos) ? use_shmemhalos : eval_arg(caller, use_shmemhalos)
6161
optranges = isnothing(optranges) ? optranges : eval_arg(caller, optranges)
6262
readonlyvars = find_vars(body, indices; readonly=true)
63-
if length(indices) != 3 @IncoherentArgumentError("incoherent arguments memopt in @parallel[_indices] <kernel>: optimization can only be applied in 3-D @parallel kernels and @parallel_indices kernels with three indices.") end
63+
if length(indices) (2, 3) @IncoherentArgumentError("incoherent arguments memopt in @parallel[_indices] <kernel>: optimization can only be applied in 2-D and 3-D @parallel kernels and @parallel_indices kernels.") end
64+
if loopdim != length(indices) @IncoherentArgumentError("incoherent arguments memopt in @parallel[_indices] <kernel>: two-index kernels require `loopdim=2` and three-index kernels require `loopdim=3`.") end
65+
if loopdim == 2 && !isnothing(use_shmemhalos) @IncoherentArgumentError("incoherent arguments memopt in @parallel[_indices] <kernel>: shared-memory-related keywords are not supported for two-index memory-optimized kernels.") end
6466
if optvars == (Symbol(""),)
6567
optvars = Tuple(keys(readonlyvars))
6668
else
@@ -83,7 +85,109 @@ function memopt(metadata_module::Module, is_parallel_kernel::Bool, caller::Modul
8385
optranges = define_optranges(optranges, optvars, offsets, int_type, package)
8486
regqueue_heads, regqueue_tails, offset_mins, offset_maxs, nb_regs_heads, nb_regs_tails = define_regqueues(offsets, optranges, optvars, indices, int_type, loopdim)
8587

86-
if loopdim == 3
88+
if loopdim == 2
89+
loopentrys = Dict(A => 1 - (offset_maxs[A][2] - offset_mins[A][2]) for A in optvars)
90+
oy_maxs = Dict(A => offset_maxs[A][2] for A in optvars)
91+
loopstart = minimum(values(loopentrys))
92+
loopend = loopsize
93+
use_any_shmem = false
94+
shmem_optvars = ()
95+
use_shmemhalos = Dict(A => false for A in optvars)
96+
ix, iy = indices
97+
ranges = RANGES_VARNAME
98+
range_y_start = :(($ranges[2])[1])
99+
range_y_end = :(($ranges[2])[end])
100+
i = gensym_world("i", @__MODULE__)
101+
loopoffset = gensym_world("loopoffset", @__MODULE__)
102+
103+
for A in optvars
104+
regqueue_tail = regqueue_tails[A]
105+
regqueue_head = regqueue_heads[A]
106+
for ox in keys(regqueue_tail)
107+
for oy in keys(regqueue_tail[ox])
108+
body = substitute(body, regtarget(A, (ox..., oy), indices), regqueue_tail[ox][oy])
109+
end
110+
end
111+
for ox in keys(regqueue_head)
112+
for oy in keys(regqueue_head[ox])
113+
body = substitute(body, regtarget(A, (ox..., oy), indices), regqueue_head[ox][oy])
114+
end
115+
end
116+
end
117+
118+
body = quote
119+
$loopoffset = (@blockIdx().y-1)*$loopsize + $range_y_start-1
120+
$((:( $reg = 0.0
121+
)
122+
for A in optvars for regs in values(regqueue_tails[A]) for reg in values(regs)
123+
)...
124+
)
125+
$((:( $reg = 0.0
126+
)
127+
for A in optvars for regs in values(regqueue_heads[A]) for reg in values(regs)
128+
)...
129+
)
130+
$((wrap_loop(i, loopstart:0,
131+
quote
132+
$iy = $i + $loopoffset
133+
if ($iy > $range_y_end) ParallelStencil.@return_nothing; end
134+
$((wrap_if(:($i > $(loopentry-1)),
135+
:( $reg = (0<$ix+$(ox[1])<=size($A,1) && 0<$iy+$oy<=size($A,2)) ? $(regtarget(A, (ox...,oy), indices)) : $reg
136+
)
137+
;unless=(loopentry==loopstart)
138+
)
139+
for A in optvars for (ox, regs) in regqueue_heads[A] for (oy, reg) in regs for loopentry = (loopentrys[A],)
140+
)...
141+
)
142+
$((
143+
:(
144+
$(regs[oy]) = $(regs[oy+1])
145+
)
146+
for A in optvars for regs in values(regqueue_tails[A]) for oy in sort(keys(regs)) for (loopentry, oy_max) = ((loopentrys[A], oy_maxs[A]),) if oy<=oy_max-2
147+
)...
148+
)
149+
$((
150+
:(
151+
$reg = $(regqueue_heads[A][ox][oy_max])
152+
)
153+
for A in optvars for (ox, regs) in regqueue_tails[A] for (oy, reg) in regs for (loopentry, oy_max) = ((loopentrys[A], oy_maxs[A]),) if oy==oy_max-1 && haskey(regqueue_heads[A], ox) && haskey(regqueue_heads[A][ox], oy_max)
154+
)...
155+
)
156+
end
157+
))
158+
)
159+
$((wrap_loop(i, 1:loopend,
160+
quote
161+
$iy = $i + $loopoffset
162+
if ($iy > $range_y_end) ParallelStencil.@return_nothing; end
163+
$((wrap_if(:($i > $(loopentry-1)),
164+
:( $reg = (0<$ix+$(ox[1])<=size($A,1) && 0<$iy+$oy<=size($A,2)) ? $(regtarget(A, (ox...,oy), indices)) : $reg
165+
)
166+
;unless=(loopentry<=1)
167+
)
168+
for A in optvars for (ox, regs) in regqueue_heads[A] for (oy, reg) in regs for loopentry = (loopentrys[A],)
169+
)...
170+
)
171+
$body
172+
$((
173+
:(
174+
$(regs[oy]) = $(regs[oy+1])
175+
)
176+
for A in optvars for regs in values(regqueue_tails[A]) for oy in sort(keys(regs)) for (loopentry, oy_max) = ((loopentrys[A], oy_maxs[A]),) if oy<=oy_max-2
177+
)...
178+
)
179+
$((
180+
:(
181+
$reg = $(regqueue_heads[A][ox][oy_max])
182+
)
183+
for A in optvars for (ox, regs) in regqueue_tails[A] for (oy, reg) in regs for (loopentry, oy_max) = ((loopentrys[A], oy_maxs[A]),) if oy==oy_max-1 && haskey(regqueue_heads[A], ox) && haskey(regqueue_heads[A][ox], oy_max)
184+
)...
185+
)
186+
end
187+
))
188+
)
189+
end
190+
elseif loopdim == 3
87191
oz_maxs, hx1s, hy1s, hx2s, hy2s, use_shmems, use_shmem_xs, use_shmem_ys, use_shmemhalos, use_shmemindices, offset_spans, oz_spans, loopentrys = define_helper_variables(offset_mins, offset_maxs, optvars, use_shmemhalos, loopdim)
88192
oz_span_max = maximum(values(oz_spans))
89193
# TODO: this only leads to correct result after row two executions in a row, probably due to the same compiler bug has below. # loopsize = (oz_span_max<=0) ? 1 : loopsize # NOTE: if the stencilrange in z is only one point, no loop is needed.
@@ -541,6 +645,19 @@ function extract_offsets(caller::Module, body::Expr, indices::NTuple{N,<:Union{S
541645
elseif haskey(offsets_by_z[A], k1) offsets_by_z[A][k1][k2] = 1
542646
else offsets_by_z[A][k1] = Dict(k2 => 1)
543647
end
648+
elseif loopdim == 2
649+
k1 = (offsets[1],)
650+
k2 = offsets[2]
651+
if haskey(offsets_by_xy[A], k1) && haskey(offsets_by_xy[A][k1], k2) offsets_by_xy[A][k1][k2] += 1
652+
elseif haskey(offsets_by_xy[A], k1) offsets_by_xy[A][k1][k2] = 1
653+
else offsets_by_xy[A][k1] = Dict(k2 => 1)
654+
end
655+
k1 = offsets[2]
656+
k2 = (offsets[1],)
657+
if haskey(offsets_by_z[A], k1) && haskey(offsets_by_z[A][k1], k2) offsets_by_z[A][k1][k2] += 1
658+
elseif haskey(offsets_by_z[A], k1) offsets_by_z[A][k1][k2] = 1
659+
else offsets_by_z[A][k1] = Dict(k2 => 1)
660+
end
544661
else
545662
@ArgumentError("memopt: only loopdim=3 is currently supported.")
546663
end
@@ -559,18 +676,18 @@ function define_optranges(optranges_arg, optvars, offsets, int_type, package)
559676
compute_capability = get_compute_capability(package)
560677
optranges = Dict()
561678
for A in optvars
562-
zspan_max = 0
563-
oxy_zspan_max = ()
564-
for oxy in keys(offsets[A])
565-
zspan = length(keys(offsets[A][oxy]))
566-
if zspan > zspan_max
567-
zspan_max = zspan
568-
oxy_zspan_max = oxy
679+
loopspan_max = 0
680+
offsets_with_max_span = ()
681+
for nonloop_offsets in keys(offsets[A])
682+
loopspan = length(keys(offsets[A][nonloop_offsets]))
683+
if loopspan > loopspan_max
684+
loopspan_max = loopspan
685+
offsets_with_max_span = nonloop_offsets
569686
end
570687
end
571688
fullrange = typemin(int_type):typemax(int_type)
572-
pointrange_x = oxy_zspan_max[1]: oxy_zspan_max[1]
573-
pointrange_y = oxy_zspan_max[2]: oxy_zspan_max[2]
689+
pointrange_x = offsets_with_max_span[1]:offsets_with_max_span[1]
690+
pointrange_y = (length(offsets_with_max_span) > 1) ? (offsets_with_max_span[2]:offsets_with_max_span[2]) : fullrange
574691
if (!isnothing(optranges_arg) && A keys(optranges_arg)) optranges[A] = getproperty(optranges_arg, A)
575692
elseif (compute_capability < v"8" && (length(optvars) <= FULLRANGE_THRESHOLD)) optranges[A] = (fullrange, fullrange, fullrange)
576693
elseif (USE_FULLRANGE_DEFAULT == (true, true, true)) optranges[A] = (fullrange, fullrange, fullrange)
@@ -639,6 +756,45 @@ function define_regqueue(offsets::Dict{Any, Any}, optranges::NTuple{3,UnitRange}
639756
end
640757
end
641758
end
759+
elseif loopdim == 2
760+
optranges_x = optranges[1]
761+
optranges_y = optranges[2]
762+
offsets_x = filter(ox -> ox[1] optranges_x, keys(offsets))
763+
if isempty(offsets_x) @IncoherentArgumentError("incoherent argument in memopt: optranges in x dimension do not include any array access.") end
764+
offset_min = (typemax(int_type), typemax(int_type), 0)
765+
offset_max = (typemin(int_type), typemin(int_type), 0)
766+
for ox in offsets_x
767+
offsets_y = sort(filter(y -> y optranges_y, keys(offsets[ox])))
768+
if isempty(offsets_y) @IncoherentArgumentError("incoherent argument in memopt: optranges in y dimension do not include any array access.") end
769+
offset_min = (min(offset_min[1], ox[1]),
770+
min(offset_min[2], minimum(offsets_y)),
771+
0)
772+
offset_max = (max(offset_max[1], ox[1]),
773+
max(offset_max[2], maximum(offsets_y)),
774+
0)
775+
end
776+
oy_max = offset_max[2]
777+
for ox in offsets_x
778+
offsets_y = sort(filter(y -> y optranges_y, keys(offsets[ox])))
779+
k1 = ox
780+
for oy = offsets_y[1]:oy_max-1
781+
k2 = oy
782+
if haskey(regqueue_tail, k1) && haskey(regqueue_tail[k1], k2) @ModuleInternalError("regqueue_tail entry exists already.") end
783+
reg = gensym_world(varname(A, (ox..., oy)), @__MODULE__); nb_regs_tail += 1
784+
if haskey(regqueue_tail, k1) regqueue_tail[k1][k2] = reg
785+
else regqueue_tail[k1] = Dict(k2 => reg)
786+
end
787+
end
788+
oy = offsets_y[end]
789+
if oy == oy_max
790+
k2 = oy
791+
if haskey(regqueue_head, k1) && haskey(regqueue_head[k1], k2) @ModuleInternalError("regqueue_head entry exists already.") end
792+
reg = gensym_world(varname(A, (ox..., oy)), @__MODULE__); nb_regs_head += 1
793+
if haskey(regqueue_head, k1) regqueue_head[k1][k2] = reg
794+
else regqueue_head[k1] = Dict(k2 => reg)
795+
end
796+
end
797+
end
642798
else
643799
@ArgumentError("memopt: only loopdim=3 is currently supported.")
644800
end
@@ -1020,11 +1176,19 @@ function wrap_loop(index::Symbol, range::UnitRange, block::Expr; unroll=false)
10201176
end
10211177
end
10221178

1023-
function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:NTuple{3,Integer}}, offset_maxs::Dict{Symbol, <:NTuple{3,Integer}}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::NTuple{M,Symbol} where M, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos)
1179+
function store_metadata(metadata_module::Module, is_parallel_kernel::Bool, caller::Module, offset_mins::Dict{Symbol, <:Tuple}, offset_maxs::Dict{Symbol, <:Tuple}, offsets::Dict{Symbol, Dict{Any, Any}}, optvars::NTuple{N,Symbol} where N, shmem_optvars::NTuple{M,Symbol} where M, use_any_shmem::Bool, loopdim::Integer, loopsize::Integer, optranges::Dict{Any, Any}, use_shmemhalos)
10241180
memopt = true
10251181
nonconst_metadata = get_nonconst_metadata(caller)
1026-
stencilranges = NamedTuple(A => (offset_mins[A][1]:offset_maxs[A][1], offset_mins[A][2]:offset_maxs[A][2], offset_mins[A][3]:offset_maxs[A][3]) for A in optvars)
1027-
use_shmemhalos = NamedTuple(A => use_shmemhalos[A] for A in optvars)
1182+
stencilranges = NamedTuple(A => begin
1183+
offset_min = offset_mins[A]
1184+
offset_max = offset_maxs[A]
1185+
ndims = length(offset_min)
1186+
x = offset_min[1]:offset_max[1]
1187+
y = (ndims > 1 ? offset_min[2] : 0):(ndims > 1 ? offset_max[2] : 0)
1188+
z = (ndims > 2 ? offset_min[3] : 0):(ndims > 2 ? offset_max[3] : 0)
1189+
(x, y, z)
1190+
end for A in optvars)
1191+
use_shmemhalos = NamedTuple(A => get(use_shmemhalos, A, false) for A in optvars)
10281192
loopsizes = (loopdim==3) ? (1, 1, loopsize) : (loopdim==2) ? (1, loopsize, 1) : (loopsize, 1, 1)
10291193
shmem_dim1 = (loopdim==3) ? 1 : (loopdim==2) ? 1 : 2
10301194
shmem_dim2 = (loopdim==3) ? 2 : (loopdim==2) ? 3 : 3

0 commit comments

Comments
 (0)