Skip to content

Commit cb84277

Browse files
committed
Implements Loechner Reindexing
1 parent 2d13a25 commit cb84277

4 files changed

Lines changed: 333 additions & 0 deletions

File tree

doc/misc.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,13 @@ Here's a Bibtex entry for your convenience::
456456
doi = "{10.1145/2627373.2627387}",
457457
}
458458

459+
References
460+
==========
461+
462+
.. [Seghir_2006] Seghir and Loechner, Proceedings of the 2006 International
463+
Conference on Compilers, Architecture and Synthesis for Embedded systems,
464+
`(DOI) <https://doi.org/10.1145/1176760.1176771>`__
465+
459466
Getting help
460467
============
461468

doc/ref_transform.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ Influencing data access
5252

5353
.. automodule:: loopy.transform.privatize
5454

55+
.. automodule:: loopy.transform.reindexing
56+
5557
Padding Data
5658
------------
5759

loopy/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@
105105
from loopy.transform.buffer import buffer_array
106106
from loopy.transform.fusion import fuse_kernels
107107

108+
from loopy.transform.reindexing import reindex_using_seghir_loechner_scheme
109+
108110
from loopy.transform.arithmetic import (
109111
fold_constants,
110112
collect_common_factors_on_increment)
@@ -233,6 +235,8 @@
233235

234236
"fold_constants", "collect_common_factors_on_increment",
235237

238+
"reindex_using_seghir_loechner_scheme",
239+
236240
"split_array_axis", "split_array_dim", "split_arg_axis",
237241
"find_padding_multiple", "add_padding",
238242

loopy/transform/reindexing.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
"""
2+
.. currentmodule:: loopy
3+
4+
.. autofunction:: reindex_using_seghir_loechner_scheme
5+
"""
6+
7+
__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
8+
9+
__license__ = """
10+
Permission is hereby granted, free of charge, to any person obtaining a copy
11+
of this software and associated documentation files (the "Software"), to deal
12+
in the Software without restriction, including without limitation the rights
13+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14+
copies of the Software, and to permit persons to whom the Software is
15+
furnished to do so, subject to the following conditions:
16+
17+
The above copyright notice and this permission notice shall be included in
18+
all copies or substantial portions of the Software.
19+
20+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26+
THE SOFTWARE.
27+
"""
28+
29+
30+
import islpy as isl
31+
from typing import Union, Sequence, Iterable, Tuple, List
32+
from loopy.kernel import LoopKernel
33+
from loopy.diagnostic import LoopyError
34+
from loopy.symbolic import CombineMapper
35+
from loopy.kernel.instruction import (MultiAssignmentBase,
36+
CInstruction, BarrierInstruction)
37+
from loopy.symbolic import RuleAwareIdentityMapper
38+
39+
40+
ISLMapT = Union[isl.BasicMap, isl.Map]
41+
ISLSetT = Union[isl.BasicSet, isl.Set]
42+
43+
44+
def make_dim_names_primed(isl_map: ISLMapT, dts: Iterable[isl.dim_type]) -> ISLMapT:
45+
"""
46+
Returns a copy of *isl_map* with dims of types *dts* having their names
47+
suffixed with an apostrophe (``'``).
48+
49+
.. testsetup::
50+
51+
>>> import islpy as isl
52+
>>> from loopy.transform.reindexing import make_dim_names_primed
53+
54+
.. doctest::
55+
56+
>>> amap = isl.Map("{[i]->[j=2i]}")
57+
>>> make_dim_names_primed(amap, [isl.dim_type.in_, isl.dim_type.out])
58+
Map("{ [i'] -> [j' = 2i'] }")
59+
"""
60+
for dt in dts:
61+
for idim in range(isl_map.dim(dt)):
62+
old_name = isl_map.get_dim_name(dt, idim)
63+
new_name = f"{old_name}'"
64+
isl_map = isl_map.set_dim_name(dt, idim, new_name)
65+
66+
return isl_map
67+
68+
69+
def _get_seghir_loechner_reindexing_from_range(access_range: ISLSetT
70+
) -> Tuple[isl.PwQPolynomial,
71+
isl.PwQPolynomial]:
72+
"""
73+
Returns ``(reindex_map, new_shape)``, where,
74+
75+
* ``reindex_map`` is a quasi-polynomial of the form ``[i1, .., in] -> {f(i1,
76+
.., in)}`` representing that an array indexed via the subscripts
77+
``[i1, ..,in]`` should be re-indexed into a 1-dimensional array as
78+
``f(i1, .., in)``.
79+
* ``new_shape`` is a quasi-polynomial corresponding to the shape of the
80+
re-indexed 1-dimensional array.
81+
"""
82+
83+
# {{{ create amap: an ISL map which is an identity map from access_map's range
84+
85+
amap = isl.BasicMap.identity(
86+
access_range
87+
.space
88+
.add_dims(isl.dim_type.in_, access_range.dim(isl.dim_type.out)))
89+
90+
# set amap's dim names
91+
for idim in range(amap.dim(isl.dim_type.in_)):
92+
amap = amap.set_dim_name(isl.dim_type.in_, idim,
93+
f"_lpy_in_{idim}")
94+
amap = amap.set_dim_name(isl.dim_type.out, idim,
95+
f"_lpy_out_{idim}")
96+
97+
amap = amap.intersect_domain(access_range)
98+
99+
# }}}
100+
101+
n_in = amap.dim(isl.dim_type.out)
102+
n_out = amap.dim(isl.dim_type.out)
103+
104+
amap_lexmin = amap.lexmin()
105+
primed_amap_lexmin = make_dim_names_primed(amap_lexmin, [isl.dim_type.in_,
106+
isl.dim_type.out])
107+
108+
lex_lt_map = isl.Map.lex_lt_map(primed_amap_lexmin, amap_lexmin)
109+
110+
# make the lexmin map parametric in terms of it's previous access expressions.
111+
lex_lt_set = (lex_lt_map
112+
.move_dims(isl.dim_type.param, 0, isl.dim_type.out, 0, n_in)
113+
.domain())
114+
115+
# {{{ initialize amap_to_count
116+
117+
amap_to_count = make_dim_names_primed(amap, [isl.dim_type.in_])
118+
amap_to_count = amap_to_count.insert_dims(isl.dim_type.param, 0, n_in)
119+
120+
for idim in range(n_in):
121+
amap_to_count = amap_to_count.set_dim_name(
122+
isl.dim_type.param, idim,
123+
amap.get_dim_name(isl.dim_type.in_, idim))
124+
125+
amap_to_count = amap_to_count.intersect_domain(lex_lt_set)
126+
127+
# }}}
128+
129+
result = amap_to_count.range().card()
130+
131+
# {{{ simplify 'result' by gisting with 'access_range'
132+
133+
aligned_access_range = access_range.move_dims(isl.dim_type.param, 0,
134+
isl.dim_type.set, 0, n_out)
135+
136+
for idim in range(result.dim(isl.dim_type.param)):
137+
aligned_access_range = (
138+
aligned_access_range
139+
.set_dim_name(isl.dim_type.param, idim,
140+
result.space.get_dim_name(isl.dim_type.param,
141+
idim)))
142+
143+
result = result.gist_params(aligned_access_range.params())
144+
145+
# }}}
146+
147+
return result, access_range.card()
148+
149+
150+
class _IndexCollector(CombineMapper):
151+
"""
152+
A mapper that collects all instances of
153+
:class:`pymbolic.primitives.Subscript` accessing :attr:`var_name`.
154+
"""
155+
def __init__(self, var_name):
156+
super().__init__()
157+
self.var_name = var_name
158+
159+
def combine(self, values):
160+
from functools import reduce
161+
return reduce(frozenset.union, values, frozenset())
162+
163+
def map_subscript(self, expr):
164+
if expr.aggregate.name == self.var_name:
165+
return frozenset([expr]) | super().map_subscript(expr)
166+
else:
167+
return super().map_subscript(expr)
168+
169+
def map_constant(self, expr):
170+
return frozenset()
171+
172+
map_variable = map_constant
173+
map_function_symbol = map_constant
174+
map_tagged_variable = map_constant
175+
map_type_cast = map_constant
176+
map_nan = map_constant
177+
178+
179+
def _union_access_ranges(access_ranges: Sequence[ISLSetT]):
180+
result = access_ranges[0]
181+
182+
for access_range in access_ranges[1:]:
183+
assert result.dim(isl.dim_type.out) == access_range.dim(isl.dim_type.out)
184+
result, accesss_range = isl.align_two(result, access_range)
185+
result = result | access_range
186+
187+
return result
188+
189+
190+
class ReindexingApplier(RuleAwareIdentityMapper):
191+
def __init__(self, rule_mapping_context,
192+
var_to_reindex,
193+
reindexed_var_name,
194+
new_index_expr,
195+
index_names):
196+
197+
super().__init__(rule_mapping_context)
198+
199+
self.var_to_reindex = var_to_reindex
200+
self.reindexed_var_name = reindexed_var_name
201+
self.new_index_expr = new_index_expr
202+
self.index_names = index_names
203+
204+
def map_subscript(self, expr, expn_state):
205+
if expr.aggregate.name != self.var_to_reindex:
206+
return super().map_subscript(expr, expn_state)
207+
208+
from loopy.symbolic import SubstitutionMapper
209+
from pymbolic.mapper.substitutor import make_subst_func
210+
from pymbolic.primitives import Subscript, Variable
211+
212+
rec_indices = tuple(self.rec(idx, expn_state) for idx in expr.index_tuple)
213+
214+
assert len(self.index_names) == len(rec_indices)
215+
subst_func = make_subst_func({idx_name: rec_idx
216+
for idx_name, rec_idx in zip(self.index_names,
217+
rec_indices)})
218+
219+
return SubstitutionMapper(subst_func)(
220+
Subscript(Variable(self.reindexed_var_name),
221+
self.new_index_expr)
222+
)
223+
224+
225+
def reindex_using_seghir_loechner_scheme(kernel: LoopKernel, var_name: str):
226+
"""
227+
Returns a kernel with expressions of the form ``var_name[i1, .., in]``
228+
replaced with ``var_name_reindexed[f(i1, .., in)]`` where ``f`` is a
229+
quasi-polynomial as outlined in [Seghir_2006]_.
230+
"""
231+
from loopy.transform.subst import expand_subst
232+
from loopy.symbolic import (get_access_map, pw_qpolynomial_to_expr,
233+
qpolynomial_to_expr, SubstitutionRuleMappingContext)
234+
235+
if var_name not in kernel.temporary_variables:
236+
raise LoopyError(f"'{var_name}' not in temporary variable in kernel"
237+
f" '{kernel.name}'.")
238+
239+
# {{{ compute the access_range of *var_name* in *kernel*
240+
241+
subst_kernel = expand_subst(kernel)
242+
vng = kernel.get_var_name_generator()
243+
new_var_name = vng(var_name+"_reindexed")
244+
245+
access_ranges: List[ISLSetT] = []
246+
247+
for insn in subst_kernel.instructions:
248+
domain = subst_kernel.get_inames_domain(insn.within_inames)
249+
if isinstance(insn, MultiAssignmentBase):
250+
access_exprs = _IndexCollector(var_name)((insn.assignees,
251+
insn.expression,
252+
tuple(insn.predicates)))
253+
elif isinstance(insn, (CInstruction, BarrierInstruction)):
254+
access_exprs = _IndexCollector(var_name)(insn.predicates)
255+
else:
256+
raise NotImplementedError(type(insn))
257+
258+
for access_expr in access_exprs:
259+
access_ranges.append(get_access_map(domain,
260+
access_expr.index_tuple,
261+
assumptions=kernel.assumptions)
262+
.range())
263+
264+
del subst_kernel
265+
266+
# }}}
267+
268+
subst, new_shape = _get_seghir_loechner_reindexing_from_range(
269+
_union_access_ranges(access_ranges))
270+
271+
# {{{ update kernel.temporary_variables
272+
273+
new_shape = new_shape.drop_unused_params()
274+
275+
if new_shape.dim(isl.dim_type.param) != 0:
276+
raise NotImplementedError("Temporaries with parametric shapes not yet"
277+
" supported.")
278+
279+
(_, shape_qpoly), = new_shape.get_pieces()
280+
281+
new_temps = kernel.temporary_variables.copy()
282+
new_temps[new_var_name] = new_temps.pop(var_name).copy(
283+
name=new_var_name,
284+
shape=qpolynomial_to_expr(shape_qpoly),
285+
strides=None,
286+
dim_tags=None,
287+
dim_names=None,
288+
)
289+
290+
# }}}
291+
292+
# {{{ perform the substitution i.e. reindex the accesses
293+
294+
subst_expr = pw_qpolynomial_to_expr(subst)
295+
subst_dim_names = tuple(
296+
subst.space.get_dim_name(isl.dim_type.param, idim)
297+
for idim in range(len(kernel.temporary_variables[var_name].shape)))
298+
assert not (set(subst_dim_names) & kernel.all_variable_names())
299+
300+
kernel = kernel.copy(temporary_variables=new_temps)
301+
rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions,
302+
vng)
303+
reindexing_mapper = ReindexingApplier(rule_mapping_context,
304+
var_name, new_var_name,
305+
subst_expr, subst_dim_names)
306+
307+
def _does_access_var_name(kernel, insn, *args):
308+
return var_name in insn.dependency_names()
309+
310+
kernel = reindexing_mapper.map_kernel(kernel,
311+
within=_does_access_var_name,
312+
map_args=False,
313+
map_tvs=False)
314+
kernel = rule_mapping_context.finish_kernel(kernel)
315+
316+
# }}}
317+
318+
return kernel
319+
320+
# vim: fdm=marker

0 commit comments

Comments
 (0)