Skip to content

Commit 9c9c41d

Browse files
kaushikcfd authored and inducer committed
Implements Loechner Reindexing
1 parent ea6bb44 commit 9c9c41d

4 files changed

Lines changed: 338 additions & 0 deletions

File tree

doc/misc.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,13 @@ Here's a Bibtex entry for your convenience::
456456
doi = "{10.1145/2627373.2627387}",
457457
}
458458

459+
References
460+
==========
461+
462+
.. [Seghir_2006] Seghir and Loechner, Proceedings of the 2006 International
463+
Conference on Compilers, Architecture and Synthesis for Embedded systems,
464+
`(DOI) <https://doi.org/10.1145/1176760.1176771>`__
465+
459466
Getting help
460467
============
461468

doc/ref_transform.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ Influencing data access
5454

5555
.. autofunction:: allocate_temporaries_for_base_storage
5656

57+
.. automodule:: loopy.transform.reindex
58+
5759
Padding Data
5860
------------
5961

loopy/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@
108108
from loopy.transform.fusion import fuse_kernels
109109
from loopy.transform.concatenate import concatenate_arrays
110110

111+
from loopy.transform.reindex import reindex_temporary_using_seghir_loechner_scheme
112+
111113
from loopy.transform.arithmetic import (
112114
fold_constants,
113115
collect_common_factors_on_increment)
@@ -239,6 +241,8 @@
239241

240242
"fold_constants", "collect_common_factors_on_increment",
241243

244+
"reindex_temporary_using_seghir_loechner_scheme",
245+
242246
"split_array_axis", "split_array_dim", "split_arg_axis",
243247
"find_padding_multiple", "add_padding",
244248

loopy/transform/reindex.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""
2+
.. currentmodule:: loopy
3+
4+
.. autofunction:: reindex_temporary_using_seghir_loechner_scheme
5+
"""
6+
7+
__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
8+
9+
__license__ = """
10+
Permission is hereby granted, free of charge, to any person obtaining a copy
11+
of this software and associated documentation files (the "Software"), to deal
12+
in the Software without restriction, including without limitation the rights
13+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14+
copies of the Software, and to permit persons to whom the Software is
15+
furnished to do so, subject to the following conditions:
16+
17+
The above copyright notice and this permission notice shall be included in
18+
all copies or substantial portions of the Software.
19+
20+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
26+
THE SOFTWARE.
27+
"""
28+
29+
30+
import islpy as isl
31+
from typing import Union, Iterable, Tuple
32+
from loopy.typing import ExpressionT
33+
from loopy.kernel import LoopKernel
34+
from loopy.diagnostic import LoopyError
35+
from loopy.symbolic import CombineMapper
36+
from loopy.kernel.instruction import (MultiAssignmentBase,
37+
CInstruction, BarrierInstruction)
38+
from loopy.symbolic import RuleAwareIdentityMapper
39+
40+
41+
# Either flavor of isl map accepted by the helpers in this module.
ISLMapT = Union[isl.BasicMap, isl.Map]
42+
# Either flavor of isl set accepted by the helpers in this module.
ISLSetT = Union[isl.BasicSet, isl.Set]
43+
44+
45+
def _add_prime_to_dim_names(isl_map: ISLMapT,
46+
dts: Iterable[isl.dim_type]) -> ISLMapT:
47+
"""
48+
Returns a copy of *isl_map* with dims of types *dts* having their names
49+
suffixed with an apostrophe (``'``).
50+
51+
.. testsetup::
52+
53+
>>> import islpy as isl
54+
>>> from loopy.transform.reindex import _add_prime_to_dim_names
55+
56+
.. doctest::
57+
58+
>>> amap = isl.Map("{[i]->[j=2i]}")
59+
>>> _add_prime_to_dim_names(amap, [isl.dim_type.in_, isl.dim_type.out])
60+
Map("{ [i'] -> [j' = 2i'] }")
61+
"""
62+
for dt in dts:
63+
for idim in range(isl_map.dim(dt)):
64+
old_name = isl_map.get_dim_name(dt, idim)
65+
new_name = f"{old_name}'"
66+
isl_map = isl_map.set_dim_name(dt, idim, new_name)
67+
68+
return isl_map
69+
70+
71+
def _get_seghir_loechner_reindexing_from_range(access_range: ISLSetT
                                               ) -> Tuple[isl.PwQPolynomial,
                                                          isl.PwQPolynomial]:
    """
    Returns ``(reindex_map, new_shape)``, where,

    * ``reindex_map`` is a quasi-polynomial of the form ``[i1, .., in] -> {f(i1,
      .., in)}`` representing that an array indexed via the subscripts
      ``[i1, ..,in]`` should be re-indexed into a 1-dimensional array as
      ``f(i1, .., in)``. The subscripts are the leading parameters of the
      quasi-polynomial and are named ``_lpy_in_0``, ``_lpy_in_1``, ...
    * ``new_shape`` is a quasi-polynomial corresponding to the shape of the
      re-indexed 1-dimensional array, i.e. the cardinality of *access_range*.

    :arg access_range: the set of multi-dimensional indices being accessed.
    """

    # {{{ create amap: an identity map on access_range

    amap = isl.BasicMap.identity(
        access_range
        .space
        .add_dims(isl.dim_type.in_, access_range.dim(isl.dim_type.out)))

    # set amap's dim names
    for idim in range(amap.dim(isl.dim_type.in_)):
        amap = amap.set_dim_name(isl.dim_type.in_, idim,
                                 f"_lpy_in_{idim}")
        amap = amap.set_dim_name(isl.dim_type.out, idim,
                                 f"_lpy_out_{idim}")

    amap = amap.intersect_domain(access_range)

    # }}}

    # amap is an identity map, so both counts agree; they are nevertheless
    # read from their own dim types so that every use below refers to the
    # dimension it actually means.
    n_in = amap.dim(isl.dim_type.in_)
    n_out = amap.dim(isl.dim_type.out)

    amap_lexmin = amap.lexmin()
    primed_amap_lexmin = _add_prime_to_dim_names(amap_lexmin, [isl.dim_type.in_,
                                                               isl.dim_type.out])

    # relates a primed index to an unprimed one whenever the image of the
    # former is lexicographically smaller than the image of the latter.
    lex_lt_map = isl.Map.lex_lt_map(primed_amap_lexmin, amap_lexmin)

    # make the lexmin map parametric in terms of its previous access
    # expressions: the unprimed indices become the leading parameters and
    # what remains is the set of primed indices preceding them.
    lex_lt_set = (lex_lt_map
                  .move_dims(isl.dim_type.param, 0, isl.dim_type.out, 0, n_in)
                  .domain())

    # {{{ initialize amap_to_count

    amap_to_count = _add_prime_to_dim_names(amap, [isl.dim_type.in_])
    amap_to_count = amap_to_count.insert_dims(isl.dim_type.param, 0, n_in)

    # name the freshly inserted parameters after amap's (unprimed) inputs so
    # that they line up with the parameters of lex_lt_set.
    for idim in range(n_in):
        amap_to_count = amap_to_count.set_dim_name(
            isl.dim_type.param, idim,
            amap.get_dim_name(isl.dim_type.in_, idim))

    amap_to_count = amap_to_count.intersect_domain(lex_lt_set)

    # }}}

    # number of accessed entries that precede a given index: this count is
    # the position of that index in the flattened array.
    result = amap_to_count.range().card()

    # {{{ simplify 'result' by gisting with 'access_range'

    aligned_access_range = access_range.move_dims(isl.dim_type.param, 0,
                                                  isl.dim_type.set, 0, n_out)

    for idim in range(result.dim(isl.dim_type.param)):
        aligned_access_range = (
            aligned_access_range
            .set_dim_name(isl.dim_type.param, idim,
                          result.space.get_dim_name(isl.dim_type.param,
                                                    idim)))

    result = result.gist_params(aligned_access_range.params())

    # }}}

    return result, access_range.card()
150+
151+
152+
class _IndexCollector(CombineMapper):
    """
    A mapper that gathers, as a :class:`frozenset`, every
    :class:`pymbolic.primitives.Subscript` whose aggregate is named
    :attr:`var_name`.
    """
    def __init__(self, var_name):
        super().__init__()
        self.var_name = var_name

    def combine(self, values):
        # union of all the per-child results
        collected = frozenset()
        for child_result in values:
            collected = collected.union(child_result)
        return collected

    def map_subscript(self, expr):
        is_hit = expr.aggregate.name == self.var_name
        from_children = super().map_subscript(expr)

        if is_hit:
            return frozenset([expr]) | from_children
        return from_children

    def map_constant(self, expr):
        # leaves contribute no subscripts
        return frozenset()

    map_variable = map_constant
    map_function_symbol = map_constant
    map_tagged_variable = map_constant
    map_type_cast = map_constant
    map_nan = map_constant
179+
180+
181+
class ReindexingApplier(RuleAwareIdentityMapper):
    """
    Rewrites every subscript of *var_to_reindex* into a subscript of
    *reindexed_var_name* whose index is *new_index_expr*, with the variables
    named in *index_names* replaced by the (recursively mapped) indices of
    the original subscript.
    """
    def __init__(self, rule_mapping_context,
                 var_to_reindex,
                 reindexed_var_name,
                 new_index_expr,
                 index_names):

        super().__init__(rule_mapping_context)

        self.var_to_reindex = var_to_reindex
        self.reindexed_var_name = reindexed_var_name
        self.new_index_expr = new_index_expr
        self.index_names = index_names

    def map_subscript(self, expr, expn_state):
        if expr.aggregate.name != self.var_to_reindex:
            # not the variable being re-indexed: defer to the base mapper
            return super().map_subscript(expr, expn_state)

        from loopy.symbolic import SubstitutionMapper
        from pymbolic.mapper.substitutor import make_subst_func
        from pymbolic.primitives import Subscript, Variable

        mapped_indices = tuple(self.rec(index, expn_state)
                               for index in expr.index_tuple)

        assert len(self.index_names) == len(mapped_indices)

        # placeholder name -> actual index expression of this access
        name_to_index = {
            name: index
            for name, index in zip(self.index_names, mapped_indices)}

        template = Subscript(Variable(self.reindexed_var_name),
                             self.new_index_expr)
        substitutor = SubstitutionMapper(make_subst_func(name_to_index))

        return substitutor(template)
212+
213+
214+
def reindex_temporary_using_seghir_loechner_scheme(kernel: LoopKernel,
                                                   var_name: str,
                                                   ) -> LoopKernel:
    """
    Returns a copy of *kernel* in which the temporary *var_name* is replaced
    by a 1-dimensional temporary (named after ``var_name_reindexed``) and
    every access ``var_name[i1, .., in]`` is rewritten to
    ``var_name_reindexed[f(i1, .., in)]``, where ``f`` is a quasi-polynomial
    as outlined in [Seghir_2006]_.

    :arg kernel: the kernel to transform.
    :arg var_name: name of the temporary variable to be re-indexed.
    :raises loopy.diagnostic.LoopyError: if *var_name* is not a temporary
        variable of *kernel*.
    """
    from loopy.transform.subst import expand_subst
    from loopy.symbolic import (BatchedAccessMapMapper, pw_qpolynomial_to_expr,
                                SubstitutionRuleMappingContext)

    if var_name not in kernel.temporary_variables:
        raise LoopyError(f"'{var_name}' not in temporary variable in kernel"
                         f" '{kernel.name}'.")

    # {{{ compute the access_range of *var_name* in *kernel*

    # accesses are recorded on a copy of the kernel with substitution rules
    # expanded.
    expanded_kernel = expand_subst(kernel)
    recorder = BatchedAccessMapMapper(expanded_kernel,
                                      frozenset([var_name]))

    exprs_to_scan: Tuple[ExpressionT, ...]

    for insn in expanded_kernel.instructions:
        if var_name not in insn.dependency_names():
            continue

        if isinstance(insn, MultiAssignmentBase):
            exprs_to_scan = (insn.assignees,
                             insn.expression,
                             tuple(insn.predicates))
        elif isinstance(insn, (CInstruction, BarrierInstruction)):
            exprs_to_scan = tuple(insn.predicates)
        else:
            raise NotImplementedError(type(insn))

        recorder(exprs_to_scan, insn.within_inames)

    vng = kernel.get_var_name_generator()
    new_var_name = vng(var_name+"_reindexed")

    access_range = recorder.get_access_range(var_name)

    del expanded_kernel
    del recorder

    # }}}

    subst, new_shape = _get_seghir_loechner_reindexing_from_range(
        access_range)

    # {{{ simplify new_shape with the assumptions from kernel

    new_shape = new_shape.gist_params(kernel.assumptions)

    # }}}

    # {{{ update kernel.temporary_variables

    new_shape = new_shape.drop_unused_params()

    new_temps = dict(kernel.temporary_variables)
    old_temporary = new_temps.pop(var_name)

    # the re-indexed temporary is 1-dimensional: per-axis metadata of the
    # old temporary is reset.
    new_temps[new_var_name] = old_temporary.copy(
        name=new_var_name,
        shape=pw_qpolynomial_to_expr(new_shape),
        strides=None,
        dim_tags=None,
        dim_names=None,
    )

    kernel = kernel.copy(temporary_variables=new_temps)

    # }}}

    # {{{ perform the substitution i.e. reindex the accesses

    subst_expr = pw_qpolynomial_to_expr(subst)

    # the leading parameters of 'subst' stand for the old subscripts
    n_old_axes = access_range.dim(isl.dim_type.out)
    subst_dim_names = tuple(
        subst.space.get_dim_name(isl.dim_type.param, iaxis)
        for iaxis in range(n_old_axes))
    assert not (set(subst_dim_names) & kernel.all_variable_names())

    rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions,
                                                          vng)
    reindexing_mapper = ReindexingApplier(rule_mapping_context,
                                          var_name, new_var_name,
                                          subst_expr, subst_dim_names)

    def _does_access_var_name(kernel, insn, *args):
        # restrict the rewrite to instructions that mention *var_name*
        return var_name in insn.dependency_names()

    kernel = reindexing_mapper.map_kernel(kernel,
                                          within=_does_access_var_name,
                                          map_args=False,
                                          map_tvs=False)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # Note: Distributing a piece of code that depends on loopy and distributes
    # code that conditionally/unconditionally calls this routine does *NOT*
    # become a derivative of GPLv2. Since, as per point (0) of GPLv2 a
    # derivative is defined as: "a work containing the Program or a portion of
    # it, either verbatim or with modifications and/or translated into another
    # language."
    #
    # Loopy does *NOT* contain any portion of the barvinok library in its
    # source code.

    return kernel
324+
325+
# vim: fdm=marker

0 commit comments

Comments
 (0)