Skip to content

Commit 70b5dfe

Browse files
isurufinducer
andauthored
Add concatenate_arrays (#717)
* Add merge_temporary_arrays transformation * type annotate * move offset computation to a separate loop * check dependency_names before using map_instructions * rename * generate a new array name if not given * expect a list of args instead of comma separated string of args * fix flake8 errors * make mypy's job easier * Rename, clean up concatenate_arrays Co-authored-by: Andreas Kloeckner <inform@tiker.net>
1 parent 443585b commit 70b5dfe

4 files changed

Lines changed: 169 additions & 3 deletions

File tree

loopy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
from loopy.transform.precompute import precompute
106106
from loopy.transform.buffer import buffer_array
107107
from loopy.transform.fusion import fuse_kernels
108+
from loopy.transform.concatenate import concatenate_arrays
108109

109110
from loopy.transform.arithmetic import (
110111
fold_constants,
@@ -232,6 +233,7 @@
232233

233234
"precompute", "buffer_array",
234235
"fuse_kernels",
236+
"concatenate_arrays",
235237

236238
"fold_constants", "collect_common_factors_on_increment",
237239

loopy/transform/concatenate.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
__copyright__ = "Copyright (C) 2022 Isuru Fernando"
2+
3+
__license__ = """
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in
12+
all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
THE SOFTWARE.
21+
"""
22+
23+
__doc__ = """
24+
.. currentmodule:: loopy
25+
.. autofunction:: concatenate_arrays
26+
"""
27+
28+
from typing import Sequence, Optional, List
29+
30+
from loopy.kernel.data import ArrayArg, KernelArgument, TemporaryVariable, auto
31+
from loopy.symbolic import SubstitutionRuleMappingContext
32+
from loopy.kernel import LoopKernel
33+
from loopy.translation_unit import for_each_kernel
34+
35+
import pymbolic.primitives as prim
36+
from pytools import all_equal
37+
38+
39+
@for_each_kernel
40+
def concatenate_arrays(
41+
kernel: LoopKernel,
42+
array_names: Sequence[str],
43+
new_name: Optional[str] = None,
44+
axis_nr: int = 0) -> LoopKernel:
45+
"""Merges arrays (temporaries or arguments) into one array along the axis
46+
given by *axis_nr*.
47+
48+
:arg array_names: a list of names of temporary variables.
49+
50+
:arg axis_nr: the (zero-based) index of the axis of the arrays to be merged.
51+
52+
:arg new_name: new name for the merged temporary. If not given, a new name
53+
is generated.
54+
"""
55+
assert isinstance(kernel, LoopKernel)
56+
57+
var_name_gen = kernel.get_var_name_generator()
58+
new_name = new_name or var_name_gen("concatenated_array")
59+
new_aggregate = prim.Variable(new_name)
60+
61+
arrays = []
62+
for array_name in array_names:
63+
ary = kernel.get_var_descriptor(array_name)
64+
if ary.shape is None or ary.shape is auto:
65+
raise ValueError(f"Shape of temporary variable '{array_name}' is "
66+
"unknown. Cannot merge with unknown shapes")
67+
68+
assert isinstance(ary.shape, tuple)
69+
shape = list(ary.shape)
70+
# make the shape value at axis_nr a constant so that we can
71+
# check that the rest of the attributes (except name) are equal.
72+
shape[axis_nr] = 1
73+
arrays.append(ary.copy(shape=tuple(shape), name=new_name))
74+
75+
if not all_equal(arrays):
76+
raise ValueError("Arrays must be identical except for shape "
77+
"(except for shape) in order to concatenate.")
78+
79+
offsets = {}
80+
axis_length = 0
81+
for array_name in array_names:
82+
offsets[array_name] = axis_length
83+
ary = kernel.temporary_variables[array_name]
84+
assert isinstance(ary.shape, tuple)
85+
axis_length += ary.shape[axis_nr]
86+
87+
new_ary = arrays[0]
88+
new_shape = list(new_ary.shape)
89+
new_shape[axis_nr] = axis_length
90+
new_ary = new_ary.copy(shape=tuple(new_shape))
91+
92+
# {{{ rewrite subscripts
93+
94+
from loopy.transform.padding import SubscriptRewriter
95+
96+
def modify_array_access(expr):
97+
idx = expr.index
98+
if not isinstance(idx, tuple):
99+
idx = (idx,)
100+
idx = list(idx)
101+
idx[axis_nr] += offsets[expr.aggregate.name]
102+
103+
return new_aggregate.index(tuple(idx))
104+
105+
rule_mapping_context = SubstitutionRuleMappingContext(
106+
kernel.substitutions, var_name_gen)
107+
aash = SubscriptRewriter(rule_mapping_context,
108+
array_names, modify_array_access)
109+
kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))
110+
111+
# }}}
112+
113+
if isinstance(new_ary, TemporaryVariable):
114+
new_tvs = {name: tv for name, tv in kernel.temporary_variables.items()
115+
if name not in array_names}
116+
new_tvs[new_name] = new_ary
117+
return kernel.copy(temporary_variables=new_tvs)
118+
elif isinstance(new_ary, ArrayArg):
119+
new_args: List[KernelArgument] = []
120+
inserted = False
121+
for arg in kernel.args:
122+
if arg.name in array_names:
123+
if not inserted:
124+
new_args.append(new_ary)
125+
inserted = True
126+
else:
127+
new_args.append(arg)
128+
return kernel.copy(args=new_args)
129+
else:
130+
raise AssertionError()

loopy/transform/padding.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from loopy.diagnostic import LoopyError
3232

3333

34-
class ArrayAxisSplitHelper(RuleAwareIdentityMapper):
34+
class SubscriptRewriter(RuleAwareIdentityMapper):
3535
def __init__(self, rule_mapping_context, arg_names, handler):
3636
super().__init__(rule_mapping_context)
3737
self.arg_names = arg_names
@@ -43,6 +43,21 @@ def map_subscript(self, expr, expn_state):
4343
else:
4444
return super().map_subscript(expr, expn_state)
4545

46+
def map_kernel(self, kernel, within=lambda *args: True):
    """Return a copy of *kernel* whose instructions have been run through
    this mapper.

    An instruction is passed through unchanged only if all of the following
    hold: the kernel has no substitution rules, *within* rejects the
    instruction, and none of the instruction's dependency names is in
    :attr:`arg_names`. Every other instruction has its expressions
    transformed by this mapper and is then handed to ``map_instruction``.

    :arg within: a callable invoked as ``within(kernel, insn, ())``.
    """
    # NOTE(review): with the default *within* (always true), the skip
    # condition below is never satisfied, so every instruction is mapped.
    # Confirm that this is intended.

    def may_skip(insn):
        if kernel.substitutions:
            return False
        if within(kernel, insn, ()):
            return False
        return not any(
                name in self.arg_names
                for name in insn.dependency_names())

    mapped_insns = []
    for insn in kernel.instructions:
        if may_skip(insn):
            mapped_insns.append(insn)
            continue

        # While subst rules are not allowed in assignees, the mapper
        # may perform tasks entirely unrelated to subst rules, so
        # we must map assignees, too.
        transformed = insn.with_transformed_expressions(
                lambda expr: self(expr, kernel, insn))  # noqa: B023
        mapped_insns.append(self.map_instruction(kernel, transformed))

    return kernel.copy(instructions=mapped_insns)
60+
4661

4762
# {{{ split_array_dim (deprecated since June 2016)
4863

@@ -236,7 +251,7 @@ def split_access_axis(expr):
236251

237252
rule_mapping_context = SubstitutionRuleMappingContext(
238253
kernel.substitutions, var_name_gen)
239-
aash = ArrayAxisSplitHelper(rule_mapping_context,
254+
aash = SubscriptRewriter(rule_mapping_context,
240255
set(array_to_rest.keys()), split_access_axis)
241256
kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))
242257

@@ -367,7 +382,7 @@ def split_access_axis(expr):
367382

368383
rule_mapping_context = SubstitutionRuleMappingContext(
369384
kernel.substitutions, var_name_gen)
370-
aash = ArrayAxisSplitHelper(rule_mapping_context,
385+
aash = SubscriptRewriter(rule_mapping_context,
371386
{array_name}, split_access_axis)
372387
kernel = rule_mapping_context.finish_kernel(aash.map_kernel(kernel))
373388

test/test_transform.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,25 @@ def test_prefetch_to_same_temp_var(ctx_factory):
16081608
lp.auto_test_vs_ref(ref_tunit, ctx, t_unit)
16091609

16101610

1611+
def test_concatenate_arrays(ctx_factory):
    """Concatenate the temporaries ``a`` and ``b`` into ``c`` and check that
    the merged temporary has the combined length and that the transformed
    kernel still agrees with the untransformed reference.
    """
    ctx = ctx_factory()

    knl = lp.make_kernel(
            "{[i]: 0<=i<10}",
            """
            <> a[i] = x[i] {id=init_a}
            <> b[i] = y[i] {id=init_b}
            out[i] = a[i] + b[i] {id=insn,dep=init_a:init_b}
            """)

    # keep the untransformed (but typed) kernel around as the reference
    ref_t_unit = lp.add_dtypes(knl, {"x": "float64", "y": "float64"})

    t_unit = lp.concatenate_arrays(ref_t_unit, ["a", "b"], "c")

    # two length-10 temporaries merged along axis 0
    merged = t_unit.default_entrypoint.temporary_variables["c"]
    assert merged.shape == (20,)

    lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
1628+
1629+
16111630
if __name__ == "__main__":
16121631
if len(sys.argv) > 1:
16131632
exec(sys.argv[1])

0 commit comments

Comments
 (0)