Skip to content

Commit 9be2056

Browse files
committed
Implements Loechner Reindexing
1 parent e62f99b commit 9be2056

4 files changed

Lines changed: 340 additions & 0 deletions

File tree

doc/misc.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,13 @@ Here's a Bibtex entry for your convenience::
456456
doi = "{10.1145/2627373.2627387}",
457457
}
458458

459+
References
460+
==========
461+
462+
.. [Seghir_2006] Seghir and Loechner, Proceedings of the 2006 International
463+
Conference on Compilers, Architecture and Synthesis for Embedded systems,
464+
`(DOI) <https://doi.org/10.1145/1176760.1176771>`__
465+
459466
Getting help
460467
============
461468

doc/ref_transform.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ Influencing data access
5454

5555
.. autofunction:: allocate_temporaries_for_base_storage
5656

57+
.. automodule:: loopy.transform.reindex
58+
5759
Padding Data
5860
------------
5961

loopy/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
from loopy.transform.buffer import buffer_array
107107
from loopy.transform.fusion import fuse_kernels
108108

109+
from loopy.transform.reindex import reindex_temporary_using_seghir_loechner_scheme
110+
109111
from loopy.transform.arithmetic import (
110112
fold_constants,
111113
collect_common_factors_on_increment)
@@ -234,6 +236,8 @@
234236

235237
"fold_constants", "collect_common_factors_on_increment",
236238

239+
"reindex_temporary_using_seghir_loechner_scheme",
240+
237241
"split_array_axis", "split_array_dim", "split_arg_axis",
238242
"find_padding_multiple", "add_padding",
239243

loopy/transform/reindex.py

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
"""
2+
.. currentmodule:: loopy
3+
4+
.. autofunction:: reindex_temporary_using_seghir_loechner_scheme
5+
"""
6+
7+
__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
8+
9+
__license__ = """
10+
Permission is hereby granted, free of charge, to any person obtaining a copy
11+
of this software and associated documentation files (the "Software"), to deal
12+
in the Software without restriction, including without limitation the rights
13+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14+
copies of the Software, and to permit persons to whom the Software is
15+
furnished to do so, subject to the following conditions:
16+
17+
The above copyright notice and this permission notice shall be included in
18+
all copies or substantial portions of the Software.
19+
20+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26+
THE SOFTWARE.
27+
"""
28+
29+
30+
import islpy as isl
31+
from typing import Union, Iterable, Tuple
32+
from loopy.typing import ExpressionT
33+
from loopy.kernel import LoopKernel
34+
from loopy.diagnostic import LoopyError
35+
from loopy.symbolic import CombineMapper
36+
from loopy.kernel.instruction import (MultiAssignmentBase,
37+
CInstruction, BarrierInstruction)
38+
from loopy.symbolic import RuleAwareIdentityMapper
39+
40+
41+
# Type alias: either flavor of isl map accepted/returned by the helpers below.
ISLMapT = Union[isl.BasicMap, isl.Map]
42+
# Type alias: either flavor of isl set accepted by the helpers below.
ISLSetT = Union[isl.BasicSet, isl.Set]
43+
44+
45+
def _add_prime_to_dim_names(isl_map: ISLMapT,
                            dts: Iterable[isl.dim_type]) -> ISLMapT:
    """
    Appends an apostrophe (``'``) to the name of every dimension of *isl_map*
    whose dim type is listed in *dts* and returns the renamed map.

    :arg isl_map: the map whose dimensions are to be renamed.
    :arg dts: the dim types (e.g. ``isl.dim_type.in_``) to be processed, in
        the order they are listed.

    .. testsetup::

        >>> import islpy as isl
        >>> from loopy.transform.reindex import _add_prime_to_dim_names

    .. doctest::

        >>> amap = isl.Map("{[i]->[j=2i]}")
        >>> _add_prime_to_dim_names(amap, [isl.dim_type.in_, isl.dim_type.out])
        Map("{ [i'] -> [j' = 2i'] }")
    """
    primed_map = isl_map

    for dim_type in dts:
        for pos in range(primed_map.dim(dim_type)):
            unprimed = primed_map.get_dim_name(dim_type, pos)
            primed_map = primed_map.set_dim_name(dim_type, pos, f"{unprimed}'")

    return primed_map
69+
70+
71+
def _get_seghir_loechner_reindexing_from_range(access_range: ISLSetT
                                               ) -> Tuple[isl.PwQPolynomial,
                                                          isl.PwQPolynomial]:
    """
    Returns ``(reindex_map, new_shape)``, where,

    * ``reindex_map`` is a quasi-polynomial of the form ``[i1, .., in] -> {f(i1,
      .., in)}`` representing that an array indexed via the subscripts
      ``[i1, ..,in]`` should be re-indexed into a 1-dimensional array as
      ``f(i1, .., in)``.
    * ``new_shape`` is a quasi-polynomial corresponding to the shape of the
      re-indexed 1-dimensional array.

    :arg access_range: the set of index tuples with which the array is
        accessed.

    ``f(i1, .., in)`` is obtained as the cardinality of the set of points of
    *access_range* that are related to ``[i1, .., in]`` through isl's
    ``lex_lt_map`` (i.e. presumably the accessed points that lexicographically
    precede it -- confirm against the isl manual), and ``new_shape`` is the
    cardinality of *access_range* itself.

    In ``reindex_map`` the subscripts ``i1, .., in`` appear as the leading
    parameters, named ``_lpy_in_0, .., _lpy_in_{n-1}``.
    """

    # {{{ create amap: an ISL map which is an identity map from access_map's range

    amap = isl.BasicMap.identity(
        access_range
        .space
        .add_dims(isl.dim_type.in_, access_range.dim(isl.dim_type.out)))

    # set amap's dim names
    for idim in range(amap.dim(isl.dim_type.in_)):
        amap = amap.set_dim_name(isl.dim_type.in_, idim,
                                 f"_lpy_in_{idim}")
        amap = amap.set_dim_name(isl.dim_type.out, idim,
                                 f"_lpy_out_{idim}")

    amap = amap.intersect_domain(access_range)

    # }}}

    # amap is an identity map, so both counts agree; each one is nevertheless
    # read from the dim type it is named after.
    n_in = amap.dim(isl.dim_type.in_)
    n_out = amap.dim(isl.dim_type.out)

    amap_lexmin = amap.lexmin()
    primed_amap_lexmin = _add_prime_to_dim_names(amap_lexmin, [isl.dim_type.in_,
                                                               isl.dim_type.out])

    # relates the primed copy's domain to the unprimed copy's domain
    lex_lt_map = isl.Map.lex_lt_map(primed_amap_lexmin, amap_lexmin)

    # make the lexmin map parametric in terms of its previous access expressions.
    lex_lt_set = (lex_lt_map
                  .move_dims(isl.dim_type.param, 0, isl.dim_type.out, 0, n_in)
                  .domain())

    # {{{ initialize amap_to_count

    amap_to_count = _add_prime_to_dim_names(amap, [isl.dim_type.in_])
    amap_to_count = amap_to_count.insert_dims(isl.dim_type.param, 0, n_in)

    # the freshly inserted parameters take the (unprimed) input names of amap
    # i.e. the same names as the parameters of lex_lt_set.
    for idim in range(n_in):
        amap_to_count = amap_to_count.set_dim_name(
            isl.dim_type.param, idim,
            amap.get_dim_name(isl.dim_type.in_, idim))

    amap_to_count = amap_to_count.intersect_domain(lex_lt_set)

    # }}}

    result = amap_to_count.range().card()

    # {{{ simplify 'result' by gisting with 'access_range'

    aligned_access_range = access_range.move_dims(isl.dim_type.param, 0,
                                                  isl.dim_type.set, 0, n_out)

    # rename the parameters of the access range to match those of 'result'
    for idim in range(result.dim(isl.dim_type.param)):
        aligned_access_range = (
            aligned_access_range
            .set_dim_name(isl.dim_type.param, idim,
                          result.space.get_dim_name(isl.dim_type.param,
                                                    idim)))

    result = result.gist_params(aligned_access_range.params())

    # }}}

    return result, access_range.card()
150+
151+
152+
class _IndexCollector(CombineMapper):
    """
    Gathers, as a :class:`frozenset`, every
    :class:`pymbolic.primitives.Subscript` node of the traversed expression
    whose aggregate is named :attr:`var_name`.
    """
    def __init__(self, var_name):
        super().__init__()
        self.var_name = var_name

    def combine(self, values):
        # union of the children's results; no children => empty set
        return frozenset().union(*values)

    def map_subscript(self, expr):
        is_match = expr.aggregate.name == self.var_name

        # matches nested inside the subscript are collected either way
        nested_matches = super().map_subscript(expr)

        if is_match:
            return frozenset([expr]) | nested_matches
        return nested_matches

    def map_constant(self, expr):
        # leaves cannot contain a subscript
        return frozenset()

    map_variable = map_constant
    map_function_symbol = map_constant
    map_tagged_variable = map_constant
    map_type_cast = map_constant
    map_nan = map_constant
179+
180+
181+
class ReindexingApplier(RuleAwareIdentityMapper):
    """
    Rewrites every subscript of :attr:`var_to_reindex` into a subscript of
    :attr:`reindexed_var_name` whose index is :attr:`new_index_expr` with the
    variables named in :attr:`index_names` replaced, position by position, by
    the (recursively mapped) indices of the original subscript.
    """
    def __init__(self, rule_mapping_context,
                 var_to_reindex,
                 reindexed_var_name,
                 new_index_expr,
                 index_names):

        super().__init__(rule_mapping_context)

        self.var_to_reindex = var_to_reindex
        self.reindexed_var_name = reindexed_var_name
        self.new_index_expr = new_index_expr
        self.index_names = index_names

    def map_subscript(self, expr, expn_state):
        if expr.aggregate.name == self.var_to_reindex:
            from loopy.symbolic import SubstitutionMapper
            from pymbolic.mapper.substitutor import make_subst_func
            from pymbolic.primitives import Subscript, Variable

            old_indices = tuple(self.rec(idx, expn_state)
                                for idx in expr.index_tuple)

            assert len(self.index_names) == len(old_indices)

            # placeholder name -> index expression it stands for
            placeholder_to_index = dict(zip(self.index_names, old_indices))
            substitutor = SubstitutionMapper(
                make_subst_func(placeholder_to_index))

            reindexed_access = Subscript(Variable(self.reindexed_var_name),
                                         self.new_index_expr)

            return substitutor(reindexed_access)

        # accesses to any other variable are left to the base class
        return super().map_subscript(expr, expn_state)
214+
215+
216+
def reindex_temporary_using_seghir_loechner_scheme(kernel: LoopKernel,
                                                   var_name: str,
                                                   ) -> LoopKernel:
    """
    Returns a kernel with expressions of the form ``var_name[i1, .., in]``
    replaced with ``var_name_reindexed[f(i1, .., in)]`` where ``f`` is a
    quasi-polynomial as outlined in [Seghir_2006]_.

    The temporary *var_name* is removed from the returned kernel and replaced
    by a new temporary with a single axis. The new temporary's name is
    generated from ``var_name + "_reindexed"`` using the kernel's variable name
    generator, its shape is the number of points in the access range of
    *var_name*, and its ``strides``, ``dim_tags`` and ``dim_names`` are reset
    to *None*.

    :arg kernel: the kernel to be transformed.
    :arg var_name: name of the temporary variable in *kernel* that is to be
        re-indexed.

    :raises loopy.diagnostic.LoopyError: if *var_name* is not a temporary
        variable of *kernel*, or if one of the placeholder names used for the
        indices is already a variable name in *kernel*.
    :raises NotImplementedError: if *var_name* is referenced by an instruction
        of an unsupported type.
    """
    from loopy.transform.subst import expand_subst
    from loopy.symbolic import (BatchedAccessMapMapper, pw_qpolynomial_to_expr,
                                SubstitutionRuleMappingContext)

    if var_name not in kernel.temporary_variables:
        raise LoopyError(f"'{var_name}' is not a temporary variable in kernel"
                         f" '{kernel.name}'.")

    # {{{ compute the access_range of *var_name* in *kernel*

    # The access maps are recorded on a copy of the kernel with its
    # substitution rules expanded; the copy is discarded afterwards.
    subst_kernel = expand_subst(kernel)
    access_map_recorder = BatchedAccessMapMapper(
        subst_kernel,
        frozenset([var_name]))

    access_exprs: Tuple[ExpressionT, ...]

    for insn in subst_kernel.instructions:
        if var_name in insn.dependency_names():
            if isinstance(insn, MultiAssignmentBase):
                access_exprs = (insn.assignees,
                                insn.expression,
                                tuple(insn.predicates))
            elif isinstance(insn, (CInstruction, BarrierInstruction)):
                access_exprs = tuple(insn.predicates)
            else:
                raise NotImplementedError(type(insn))

            access_map_recorder(access_exprs, insn.within_inames)

    vng = kernel.get_var_name_generator()
    new_var_name = vng(var_name+"_reindexed")

    access_range = access_map_recorder.get_access_range(var_name)

    del subst_kernel
    del access_map_recorder

    # }}}

    subst, new_shape = _get_seghir_loechner_reindexing_from_range(
        access_range)

    # {{{ simplify new_shape with the assumptions from kernel

    new_shape = new_shape.gist_params(kernel.assumptions)

    # }}}

    # {{{ update kernel.temporary_variables

    new_shape = new_shape.drop_unused_params()

    new_temps = dict(kernel.temporary_variables)
    new_temps[new_var_name] = new_temps.pop(var_name).copy(
        name=new_var_name,
        shape=pw_qpolynomial_to_expr(new_shape),
        strides=None,
        dim_tags=None,
        dim_names=None,
        )

    kernel = kernel.copy(temporary_variables=new_temps)

    # }}}

    # {{{ perform the substitution i.e. reindex the accesses

    subst_expr = pw_qpolynomial_to_expr(subst)

    # the leading parameters of 'subst' are the placeholders that stand for
    # the subscripts of the original access.
    subst_dim_names = tuple(
        subst.space.get_dim_name(isl.dim_type.param, idim)
        for idim in range(access_range.dim(isl.dim_type.out)))

    # Raised explicitly (rather than asserted) so that the check is not
    # stripped when running under 'python -O'.
    colliding_names = set(subst_dim_names) & kernel.all_variable_names()
    if colliding_names:
        raise LoopyError("Placeholder index name(s)"
                         f" {sorted(colliding_names)} are already in use as"
                         f" variable names in kernel '{kernel.name}'.")

    rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions,
                                                          vng)
    reindexing_mapper = ReindexingApplier(rule_mapping_context,
                                          var_name, new_var_name,
                                          subst_expr, subst_dim_names)

    def _does_access_var_name(kernel, insn, *args):
        return var_name in insn.dependency_names()

    kernel = reindexing_mapper.map_kernel(kernel,
                                          within=_does_access_var_name,
                                          map_args=False,
                                          map_tvs=False)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # Note: Distributing a piece of code that depends on loopy and distributes
    # code that conditionally/unconditionally calls this routine does *NOT*
    # become a derivative of GPLv2. Since, as per point (0) of GPLv2 a
    # derivative is defined as: "a work containing the Program or a portion of
    # it, either verbatim or with modifications and/or translated into another
    # language."
    #
    # Loopy does *NOT* contain any portion of the barvinok library in its
    # source code.

    return kernel
326+
327+
# vim: fdm=marker

0 commit comments

Comments
 (0)