|
| 1 | +""" |
| 2 | +.. currentmodule:: loopy |
| 3 | +
|
| 4 | +.. autofunction:: reindex_temporary_using_seghir_loechner_scheme |
| 5 | +""" |
| 6 | + |
| 7 | +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" |
| 8 | + |
| 9 | +__license__ = """ |
| 10 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 11 | +of this software and associated documentation files (the "Software"), to deal |
| 12 | +in the Software without restriction, including without limitation the rights |
| 13 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 14 | +copies of the Software, and to permit persons to whom the Software is |
| 15 | +furnished to do so, subject to the following conditions: |
| 16 | +
|
| 17 | +The above copyright notice and this permission notice shall be included in |
| 18 | +all copies or substantial portions of the Software. |
| 19 | +
|
| 20 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 21 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 22 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 23 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 24 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 25 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 26 | +THE SOFTWARE. |
| 27 | +""" |
| 28 | + |
| 29 | + |
| 30 | +import islpy as isl |
| 31 | +from typing import Union, Iterable, Tuple |
| 32 | +from loopy.typing import ExpressionT |
| 33 | +from loopy.kernel import LoopKernel |
| 34 | +from loopy.diagnostic import LoopyError |
| 35 | +from loopy.symbolic import CombineMapper |
| 36 | +from loopy.kernel.instruction import (MultiAssignmentBase, |
| 37 | + CInstruction, BarrierInstruction) |
| 38 | +from loopy.symbolic import RuleAwareIdentityMapper |
| 39 | + |
| 40 | + |
| 41 | +ISLMapT = Union[isl.BasicMap, isl.Map] |
| 42 | +ISLSetT = Union[isl.BasicSet, isl.Set] |
| 43 | + |
| 44 | + |
| 45 | +def _add_prime_to_dim_names(isl_map: ISLMapT, |
| 46 | + dts: Iterable[isl.dim_type]) -> ISLMapT: |
| 47 | + """ |
| 48 | + Returns a copy of *isl_map* with dims of types *dts* having their names |
| 49 | + suffixed with an apostrophe (``'``). |
| 50 | +
|
| 51 | + .. testsetup:: |
| 52 | +
|
| 53 | + >>> import islpy as isl |
| 54 | + >>> from loopy.transform.reindex import _add_prime_to_dim_names |
| 55 | +
|
| 56 | + .. doctest:: |
| 57 | +
|
| 58 | + >>> amap = isl.Map("{[i]->[j=2i]}") |
| 59 | + >>> _add_prime_to_dim_names(amap, [isl.dim_type.in_, isl.dim_type.out]) |
| 60 | + Map("{ [i'] -> [j' = 2i'] }") |
| 61 | + """ |
| 62 | + for dt in dts: |
| 63 | + for idim in range(isl_map.dim(dt)): |
| 64 | + old_name = isl_map.get_dim_name(dt, idim) |
| 65 | + new_name = f"{old_name}'" |
| 66 | + isl_map = isl_map.set_dim_name(dt, idim, new_name) |
| 67 | + |
| 68 | + return isl_map |
| 69 | + |
| 70 | + |
def _get_seghir_loechner_reindexing_from_range(access_range: ISLSetT
                                               ) -> Tuple[isl.PwQPolynomial,
                                                          isl.PwQPolynomial]:
    """
    Returns ``(reindex_map, new_shape)``, where,

    * ``reindex_map`` is a quasi-polynomial of the form ``[i1, .., in] -> {f(i1,
      .., in)}`` representing that an array indexed via the subscripts
      ``[i1, ..,in]`` should be re-indexed into a 1-dimensional array as
      ``f(i1, .., in)``.
    * ``new_shape`` is a quasi-polynomial corresponding to the shape of the
      re-indexed 1-dimensional array.

    :arg access_range: the set of index tuples with which the array is
        accessed.

    ``reindex_map`` is obtained by counting, for a (parametric) index tuple,
    the points of *access_range* that are lexicographically smaller than it.
    The subscripts appear in ``reindex_map`` as the leading parameters, named
    ``_lpy_in_0, _lpy_in_1, ...``. ``new_shape`` is the cardinality of
    *access_range*.
    """

    # {{{ create amap: an ISL map which is an identity map from access_map's range

    amap = isl.BasicMap.identity(
        access_range
        .space
        .add_dims(isl.dim_type.in_, access_range.dim(isl.dim_type.out)))

    # set amap's dim names
    for idim in range(amap.dim(isl.dim_type.in_)):
        amap = amap.set_dim_name(isl.dim_type.in_, idim,
                                 f"_lpy_in_{idim}")
        amap = amap.set_dim_name(isl.dim_type.out, idim,
                                 f"_lpy_out_{idim}")

    amap = amap.intersect_domain(access_range)

    # }}}

    # amap is an identity map, so both counts are equal; each is nevertheless
    # read from the dimension type it is named after.
    n_in = amap.dim(isl.dim_type.in_)
    n_out = amap.dim(isl.dim_type.out)

    amap_lexmin = amap.lexmin()
    primed_amap_lexmin = _add_prime_to_dim_names(amap_lexmin, [isl.dim_type.in_,
                                                               isl.dim_type.out])

    # relates a primed index tuple to every unprimed index tuple whose image
    # is lexicographically larger
    lex_lt_map = isl.Map.lex_lt_map(primed_amap_lexmin, amap_lexmin)

    # make the lexmin map parametric in terms of its previous access expressions.
    lex_lt_set = (lex_lt_map
                  .move_dims(isl.dim_type.param, 0, isl.dim_type.out, 0, n_in)
                  .domain())

    # {{{ initialize amap_to_count

    amap_to_count = _add_prime_to_dim_names(amap, [isl.dim_type.in_])
    amap_to_count = amap_to_count.insert_dims(isl.dim_type.param, 0, n_in)

    # name the freshly inserted parameters after amap's (unprimed) input dims
    # so that they line up with the parameters of lex_lt_set
    for idim in range(n_in):
        amap_to_count = amap_to_count.set_dim_name(
            isl.dim_type.param, idim,
            amap.get_dim_name(isl.dim_type.in_, idim))

    amap_to_count = amap_to_count.intersect_domain(lex_lt_set)

    # }}}

    # number of accessed points that precede the parametric index tuple
    result = amap_to_count.range().card()

    # {{{ simplify 'result' by gisting with 'access_range'

    aligned_access_range = access_range.move_dims(isl.dim_type.param, 0,
                                                  isl.dim_type.set, 0, n_out)

    # give aligned_access_range the same parameter names as 'result'
    for idim in range(result.dim(isl.dim_type.param)):
        aligned_access_range = (
            aligned_access_range
            .set_dim_name(isl.dim_type.param, idim,
                          result.space.get_dim_name(isl.dim_type.param,
                                                    idim)))

    result = result.gist_params(aligned_access_range.params())

    # }}}

    return result, access_range.card()
| 150 | + |
| 151 | + |
class _IndexCollector(CombineMapper):
    """
    A mapper returning the :class:`frozenset` of all
    :class:`pymbolic.primitives.Subscript` nodes in an expression whose
    aggregate is named :attr:`var_name`.

    .. attribute:: var_name

        Name of the subscripted variable whose accesses are collected.
    """
    def __init__(self, var_name):
        super().__init__()
        self.var_name = var_name

    def combine(self, values):
        # union of all the child results; empty if there are no children
        return frozenset().union(*values)

    def map_subscript(self, expr):
        is_match = expr.aggregate.name == self.var_name
        # always descend so that matching subscripts nested in the
        # aggregate/index are collected as well
        nested = super().map_subscript(expr)

        if not is_match:
            return nested

        return frozenset([expr]) | nested

    def map_constant(self, expr):
        # leaves cannot contain a subscript
        return frozenset()

    map_variable = map_constant
    map_function_symbol = map_constant
    map_tagged_variable = map_constant
    map_type_cast = map_constant
    map_nan = map_constant
| 179 | + |
| 180 | + |
class ReindexingApplier(RuleAwareIdentityMapper):
    """
    A mapper that rewrites every subscript of *var_to_reindex* into a
    subscript of *reindexed_var_name*.

    The new index is *new_index_expr* with each name in *index_names*
    substituted by the corresponding (recursively mapped) index of the
    original subscript. All other subscripts are handled by the parent
    class.
    """
    def __init__(self, rule_mapping_context,
                 var_to_reindex,
                 reindexed_var_name,
                 new_index_expr,
                 index_names):

        super().__init__(rule_mapping_context)

        self.index_names = index_names
        self.new_index_expr = new_index_expr
        self.reindexed_var_name = reindexed_var_name
        self.var_to_reindex = var_to_reindex

    def map_subscript(self, expr, expn_state):
        if expr.aggregate.name == self.var_to_reindex:
            from loopy.symbolic import SubstitutionMapper
            from pymbolic.mapper.substitutor import make_subst_func
            from pymbolic.primitives import Subscript, Variable

            # map the original indices first: they may contain further
            # accesses to the variable being re-indexed
            mapped_indices = []
            for index in expr.index_tuple:
                mapped_indices.append(self.rec(index, expn_state))
            mapped_indices = tuple(mapped_indices)

            assert len(self.index_names) == len(mapped_indices)

            # template access, still written in terms of 'index_names'
            template = Subscript(Variable(self.reindexed_var_name),
                                 self.new_index_expr)
            name_to_index = dict(zip(self.index_names, mapped_indices))
            substitutor = SubstitutionMapper(make_subst_func(name_to_index))

            return substitutor(template)

        return super().map_subscript(expr, expn_state)
| 212 | + |
| 213 | + |
def reindex_temporary_using_seghir_loechner_scheme(kernel: LoopKernel,
                                                   var_name: str,
                                                   ) -> LoopKernel:
    """
    Returns a kernel with expressions of the form ``var_name[i1, .., in]``
    replaced with ``var_name_reindexed[f(i1, .., in)]`` where ``f`` is a
    quasi-polynomial as outlined in [Seghir_2006]_.

    :arg kernel: the kernel to transform.
    :arg var_name: name of the temporary variable in *kernel* to be
        re-indexed.
    :returns: a copy of *kernel* in which the temporary *var_name* is replaced
        by a 1-dimensional temporary whose name is generated from
        ``var_name + "_reindexed"`` and whose shape is the number of distinct
        index tuples with which *var_name* is accessed.
    :raises loopy.diagnostic.LoopyError: if *var_name* is not a temporary
        variable of *kernel*.
    :raises NotImplementedError: if *var_name* is referenced by an instruction
        of a type for which collecting the accesses is not implemented.
    """
    from loopy.transform.subst import expand_subst
    from loopy.symbolic import (BatchedAccessMapMapper, pw_qpolynomial_to_expr,
                                SubstitutionRuleMappingContext)

    if var_name not in kernel.temporary_variables:
        raise LoopyError(f"'{var_name}' is not a temporary variable in kernel"
                         f" '{kernel.name}'.")

    # {{{ compute the access_range of *var_name* in *kernel*

    # the accesses are collected on a copy with the substitution rules
    # expanded; the rewrite further below operates on the unexpanded kernel
    subst_kernel = expand_subst(kernel)
    access_map_recorder = BatchedAccessMapMapper(
        subst_kernel,
        frozenset([var_name]))

    access_exprs: Tuple[ExpressionT, ...]

    for insn in subst_kernel.instructions:
        if var_name in insn.dependency_names():
            if isinstance(insn, MultiAssignmentBase):
                access_exprs = (insn.assignees,
                                insn.expression,
                                tuple(insn.predicates))
            elif isinstance(insn, (CInstruction, BarrierInstruction)):
                # only the predicates are scanned for these instructions
                access_exprs = tuple(insn.predicates)
            else:
                raise NotImplementedError(type(insn))

            access_map_recorder(access_exprs, insn.within_inames)

    vng = kernel.get_var_name_generator()
    new_var_name = vng(var_name+"_reindexed")

    access_range = access_map_recorder.get_access_range(var_name)

    del subst_kernel
    del access_map_recorder

    # }}}

    subst, new_shape = _get_seghir_loechner_reindexing_from_range(
        access_range)

    # {{{ simplify new_shape with the assumptions from kernel

    new_shape = new_shape.gist_params(kernel.assumptions)

    # }}}

    # {{{ update kernel.temporary_variables

    new_shape = new_shape.drop_unused_params()

    # dict() already creates a fresh mapping that is safe to mutate
    new_temps = dict(kernel.temporary_variables)
    # layout information of the multi-dimensional temporary is reset on the
    # re-indexed one
    new_temps[new_var_name] = new_temps.pop(var_name).copy(
        name=new_var_name,
        shape=pw_qpolynomial_to_expr(new_shape),
        strides=None,
        dim_tags=None,
        dim_names=None,
        )

    kernel = kernel.copy(temporary_variables=new_temps)

    # }}}

    # {{{ perform the substitution i.e. reindex the accesses

    subst_expr = pw_qpolynomial_to_expr(subst)
    # the leading parameters of 'subst' stand for the original subscripts
    subst_dim_names = tuple(
        subst.space.get_dim_name(isl.dim_type.param, idim)
        for idim in range(access_range.dim(isl.dim_type.out)))
    # the placeholder names must not clash with any name used in the kernel
    assert not (set(subst_dim_names) & kernel.all_variable_names())

    rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions,
                                                          vng)
    reindexing_mapper = ReindexingApplier(rule_mapping_context,
                                          var_name, new_var_name,
                                          subst_expr, subst_dim_names)

    def _does_access_var_name(kernel, insn, *args):
        return var_name in insn.dependency_names()

    kernel = reindexing_mapper.map_kernel(kernel,
                                          within=_does_access_var_name,
                                          map_args=False,
                                          map_tvs=False)
    kernel = rule_mapping_context.finish_kernel(kernel)

    # }}}

    # Note: Distributing a piece of code that depends on loopy and distributes
    # code that conditionally/unconditionally calls this routine does *NOT*
    # become a derivative of GPLv2. Since, as per point (0) of GPLv2 a
    # derivative is defined as: "a work containing the Program or a portion of
    # it, either verbatim or with modifications and/or translated into another
    # language."
    #
    # Loopy does *NOT* contain any portion of the barvinok library in its
    # source code.

    return kernel
| 324 | + |
| 325 | +# vim: fdm=marker |
0 commit comments