Skip to content

Commit b6710b5

Browse files
committed
codegen: Fix issue with permuted maps
If we have a parallel loop with map1 = Map(...) map2 = PermutedMap(map1, ...) par_loop(..., d1(..., map1), d2(..., map2)) The codegen should only require one global argument (for the data from map1) since map2 is a local permutation of the global argument. This was being provided by runtime argument passing, but not by the code generator, which treated map1 and map2 as distinct things. To fix this, create a new PMap wrapper for Map objects in the codegen builder that just know how to index themselves through their permutation. Now we can share the underlying global map data between the two local map accesses.
1 parent be8adea commit b6710b5

1 file changed

Lines changed: 55 additions & 31 deletions

File tree

pyop2/codegen/builder.py

Lines changed: 55 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,18 @@ class Map(object):
3030

3131
__slots__ = ("values", "offset", "interior_horizontal",
3232
"variable", "unroll", "layer_bounds",
33-
"prefetch", "permutation")
33+
"prefetch", "_pmap_count")
3434

3535
def __init__(self, map_, interior_horizontal, layer_bounds,
36-
values=None, offset=None, unroll=False):
36+
offset=None, unroll=False):
3737
self.variable = map_.iterset._extruded and not map_.iterset.constant_layers
3838
self.unroll = unroll
3939
self.layer_bounds = layer_bounds
4040
self.interior_horizontal = interior_horizontal
4141
self.prefetch = {}
42-
if values is not None:
43-
raise RuntimeError
44-
self.values = values
45-
if map_.offset is not None:
46-
assert offset is not None
47-
self.offset = offset
48-
return
49-
5042
offset = map_.offset
5143
shape = (None, ) + map_.shape[1:]
5244
values = Argument(shape, dtype=map_.dtype, pfx="map")
53-
if isinstance(map_, PermutedMap):
54-
self.permutation = NamedLiteral(map_.permutation, parent=values, suffix="permutation")
55-
if offset is not None:
56-
offset = offset[map_.permutation]
57-
else:
58-
self.permutation = None
5945
if offset is not None:
6046
if len(set(map_.offset)) == 1:
6147
offset = Literal(offset[0], casting=True)
@@ -64,6 +50,7 @@ def __init__(self, map_, interior_horizontal, layer_bounds,
6450

6551
self.values = values
6652
self.offset = offset
53+
self._pmap_count = itertools.count()
6754

6855
@property
6956
def shape(self):
@@ -73,7 +60,7 @@ def shape(self):
7360
def dtype(self):
7461
return self.values.dtype
7562

76-
def indexed(self, multiindex, layer=None):
63+
def indexed(self, multiindex, layer=None, permute=lambda x: x):
7764
n, i, f = multiindex
7865
if layer is not None and self.offset is not None:
7966
# For extruded mesh, prefetch the indirections for each map, so that they don't
@@ -82,10 +69,7 @@ def indexed(self, multiindex, layer=None):
8269
base_key = None
8370
if base_key not in self.prefetch:
8471
j = Index()
85-
if self.permutation is None:
86-
base = Indexed(self.values, (n, j))
87-
else:
88-
base = Indexed(self.values, (n, Indexed(self.permutation, (j,))))
72+
base = Indexed(self.values, (n, permute(j)))
8973
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))
9074

9175
base = self.prefetch[base_key]
@@ -112,26 +96,58 @@ def indexed(self, multiindex, layer=None):
11296
return Indexed(self.prefetch[key], (f, i)), (f, i)
11397
else:
11498
assert f.extent == 1 or f.extent is None
115-
if self.permutation is None:
116-
base = Indexed(self.values, (n, i))
117-
else:
118-
base = Indexed(self.values, (n, Indexed(self.permutation, (i,))))
99+
base = Indexed(self.values, (n, permute(i)))
119100
return base, (f, i)
120101

121-
def indexed_vector(self, n, shape, layer=None):
102+
def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
122103
shape = self.shape[1:] + shape
123104
if self.interior_horizontal:
124105
shape = (2, ) + shape
125106
else:
126107
shape = (1, ) + shape
127108
f, i, j = (Index(e) for e in shape)
128-
base, (f, i) = self.indexed((n, i, f), layer=layer)
109+
base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute)
129110
init = Sum(Product(base, Literal(numpy.int32(j.extent))), j)
130111
pack = Materialise(PackInst(), init, MultiIndex(f, i, j))
131112
multiindex = tuple(Index(e) for e in pack.shape)
132113
return Indexed(pack, multiindex), multiindex
133114

134115

116+
class PMap(Map):
117+
__slots__ = ("permutation",)
118+
119+
def __init__(self, map_, permutation):
120+
# Copy over properties
121+
self.variable = map_.variable
122+
self.unroll = map_.unroll
123+
self.layer_bounds = map_.layer_bounds
124+
self.interior_horizontal = map_.interior_horizontal
125+
self.prefetch = {}
126+
self.values = map_.values
127+
self.offset = map_.offset
128+
offset = map_.offset
129+
# TODO: this is a hack, rep2loopy should be in charge of
130+
# generating all names!
131+
count = next(map_._pmap_count)
132+
if offset is not None:
133+
if offset.shape:
134+
# Have a named literal
135+
offset = offset.value[permutation]
136+
offset = NamedLiteral(offset, parent=self.values, suffix=f"permutation{count}_offset")
137+
else:
138+
offset = map_.offset
139+
self.offset = offset
140+
self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")
141+
142+
def indexed(self, multiindex, layer=None):
143+
permute = lambda x: Indexed(self.permutation, (x,))
144+
return super().indexed(multiindex, layer=layer, permute=permute)
145+
146+
def indexed_vector(self, n, shape, layer=None):
147+
permute = lambda x: Indexed(self.permutation, (x,))
148+
return super().indexed_vector(n, shape, layer=layer, permute=permute)
149+
150+
135151
class Pack(metaclass=ABCMeta):
136152

137153
def pick_loop_indices(self, loop_index, layer_index=None, entity_index=None):
@@ -818,9 +834,13 @@ def map_(self, map_, unroll=False):
818834
try:
819835
return self.maps[key]
820836
except KeyError:
821-
map_ = Map(map_, interior_horizontal,
822-
(self.bottom_layer, self.top_layer),
823-
unroll=unroll)
837+
if isinstance(map_, PermutedMap):
838+
imap = self.map_(map_.map_, unroll=unroll)
839+
map_ = PMap(imap, map_.permutation)
840+
else:
841+
map_ = Map(map_, interior_horizontal,
842+
(self.bottom_layer, self.top_layer),
843+
unroll=unroll)
824844
self.maps[key] = map_
825845
return map_
826846

@@ -854,7 +874,11 @@ def wrapper_args(self):
854874
args.extend(self.arguments)
855875
# maps are refcounted
856876
for map_ in self.maps.values():
857-
args.append(map_.values)
877+
# But we don't need to emit stuff for PMaps because they
878+
# are a Map (already seen + a permutation [encoded in the
879+
# indexing]).
880+
if not isinstance(map_, PMap):
881+
args.append(map_.values)
858882
return tuple(args)
859883

860884
def kernel_call(self):

0 commit comments

Comments
 (0)