Skip to content

Commit a737a31

Browse files
authored
Merge pull request #643 from OP2/wence/fix/permuted-map
Wence/fix/permuted map
2 parents be8adea + d8c7509 commit a737a31

File tree

2 files changed

+94
-31
lines changed

2 files changed

+94
-31
lines changed

pyop2/codegen/builder.py

Lines changed: 55 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,18 @@ class Map(object):
3030

3131
__slots__ = ("values", "offset", "interior_horizontal",
3232
"variable", "unroll", "layer_bounds",
33-
"prefetch", "permutation")
33+
"prefetch", "_pmap_count")
3434

3535
def __init__(self, map_, interior_horizontal, layer_bounds,
36-
values=None, offset=None, unroll=False):
36+
offset=None, unroll=False):
3737
self.variable = map_.iterset._extruded and not map_.iterset.constant_layers
3838
self.unroll = unroll
3939
self.layer_bounds = layer_bounds
4040
self.interior_horizontal = interior_horizontal
4141
self.prefetch = {}
42-
if values is not None:
43-
raise RuntimeError
44-
self.values = values
45-
if map_.offset is not None:
46-
assert offset is not None
47-
self.offset = offset
48-
return
49-
5042
offset = map_.offset
5143
shape = (None, ) + map_.shape[1:]
5244
values = Argument(shape, dtype=map_.dtype, pfx="map")
53-
if isinstance(map_, PermutedMap):
54-
self.permutation = NamedLiteral(map_.permutation, parent=values, suffix="permutation")
55-
if offset is not None:
56-
offset = offset[map_.permutation]
57-
else:
58-
self.permutation = None
5945
if offset is not None:
6046
if len(set(map_.offset)) == 1:
6147
offset = Literal(offset[0], casting=True)
@@ -64,6 +50,7 @@ def __init__(self, map_, interior_horizontal, layer_bounds,
6450

6551
self.values = values
6652
self.offset = offset
53+
self._pmap_count = itertools.count()
6754

6855
@property
6956
def shape(self):
@@ -73,7 +60,7 @@ def shape(self):
7360
def dtype(self):
7461
return self.values.dtype
7562

76-
def indexed(self, multiindex, layer=None):
63+
def indexed(self, multiindex, layer=None, permute=lambda x: x):
7764
n, i, f = multiindex
7865
if layer is not None and self.offset is not None:
7966
# For extruded mesh, prefetch the indirections for each map, so that they don't
@@ -82,10 +69,7 @@ def indexed(self, multiindex, layer=None):
8269
base_key = None
8370
if base_key not in self.prefetch:
8471
j = Index()
85-
if self.permutation is None:
86-
base = Indexed(self.values, (n, j))
87-
else:
88-
base = Indexed(self.values, (n, Indexed(self.permutation, (j,))))
72+
base = Indexed(self.values, (n, permute(j)))
8973
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))
9074

9175
base = self.prefetch[base_key]
@@ -112,26 +96,58 @@ def indexed(self, multiindex, layer=None):
11296
return Indexed(self.prefetch[key], (f, i)), (f, i)
11397
else:
11498
assert f.extent == 1 or f.extent is None
115-
if self.permutation is None:
116-
base = Indexed(self.values, (n, i))
117-
else:
118-
base = Indexed(self.values, (n, Indexed(self.permutation, (i,))))
99+
base = Indexed(self.values, (n, permute(i)))
119100
return base, (f, i)
120101

121-
def indexed_vector(self, n, shape, layer=None):
102+
def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
122103
shape = self.shape[1:] + shape
123104
if self.interior_horizontal:
124105
shape = (2, ) + shape
125106
else:
126107
shape = (1, ) + shape
127108
f, i, j = (Index(e) for e in shape)
128-
base, (f, i) = self.indexed((n, i, f), layer=layer)
109+
base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute)
129110
init = Sum(Product(base, Literal(numpy.int32(j.extent))), j)
130111
pack = Materialise(PackInst(), init, MultiIndex(f, i, j))
131112
multiindex = tuple(Index(e) for e in pack.shape)
132113
return Indexed(pack, multiindex), multiindex
133114

134115

116+
class PMap(Map):
    """A :class:`Map` whose entries are accessed through a permutation.

    A ``PMap`` wraps an existing builder :class:`Map`: it reuses that
    map's ``values`` (no new map data is created) and routes every
    access to the map's second index through a named literal holding
    ``permutation``.
    """

    __slots__ = ("permutation",)

    def __init__(self, map_, permutation):
        """
        :arg map_: the builder :class:`Map` to wrap.  Its ``values`` are
            shared with this object and its ``_pmap_count`` counter is
            advanced to obtain a unique suffix for the generated names.
        :arg permutation: the permutation used to index ``map_``'s
            entries (and its offset, if that offset has a shape).
        """
        # Copy over properties
        self.variable = map_.variable
        self.unroll = map_.unroll
        self.layer_bounds = map_.layer_bounds
        self.interior_horizontal = map_.interior_horizontal
        # The prefetch cache is per-object and not shared with the
        # wrapped map: cached entries bake in the permutation used.
        self.prefetch = {}
        self.values = map_.values
        # TODO: this is a hack, rep2loopy should be in charge of
        # generating all names!
        count = next(map_._pmap_count)
        offset = map_.offset
        if offset is not None and offset.shape:
            # Have a named literal: permute its entries so that they
            # line up with the permuted map entries.  An offset with an
            # empty shape is reused from the wrapped map unchanged.
            offset = NamedLiteral(offset.value[permutation],
                                  parent=self.values,
                                  suffix=f"permutation{count}_offset")
        self.offset = offset
        self.permutation = NamedLiteral(permutation, parent=self.values,
                                        suffix=f"permutation{count}")

    def _permuted(self, index):
        """Return ``index`` looked up through this map's permutation."""
        return Indexed(self.permutation, (index,))

    def indexed(self, multiindex, layer=None, permute=lambda x: x):
        """Index the wrapped map's values through the permutation.

        ``permute`` must be accepted here because
        :meth:`Map.indexed_vector` calls ``self.indexed(...,
        permute=permute)``; without it that call raises ``TypeError``
        for a ``PMap``.  Any ``permute`` supplied by the caller is
        applied first, then this map's own permutation.

        :arg multiindex: forwarded unchanged to :meth:`Map.indexed`.
        :arg layer: forwarded unchanged to :meth:`Map.indexed`.
        :arg permute: optional callable applied to an index before this
            map's permutation (identity by default).
        :returns: whatever :meth:`Map.indexed` returns.
        """
        return super().indexed(
            multiindex, layer=layer,
            permute=lambda x: self._permuted(permute(x)))

    def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
        """Vector-indexed access through the permutation.

        :meth:`Map.indexed_vector` obtains its base expression from
        ``self.indexed``, which for a ``PMap`` already applies the
        permutation, so ``permute`` is only forwarded here; adding the
        permutation a second time would apply it twice.

        :arg n: forwarded unchanged to :meth:`Map.indexed_vector`.
        :arg shape: forwarded unchanged to :meth:`Map.indexed_vector`.
        :arg layer: forwarded unchanged to :meth:`Map.indexed_vector`.
        :arg permute: optional callable applied to an index before this
            map's permutation (identity by default).
        :returns: whatever :meth:`Map.indexed_vector` returns.
        """
        return super().indexed_vector(n, shape, layer=layer, permute=permute)
150+
135151
class Pack(metaclass=ABCMeta):
136152

137153
def pick_loop_indices(self, loop_index, layer_index=None, entity_index=None):
@@ -818,9 +834,13 @@ def map_(self, map_, unroll=False):
818834
try:
819835
return self.maps[key]
820836
except KeyError:
821-
map_ = Map(map_, interior_horizontal,
822-
(self.bottom_layer, self.top_layer),
823-
unroll=unroll)
837+
if isinstance(map_, PermutedMap):
838+
imap = self.map_(map_.map_, unroll=unroll)
839+
map_ = PMap(imap, map_.permutation)
840+
else:
841+
map_ = Map(map_, interior_horizontal,
842+
(self.bottom_layer, self.top_layer),
843+
unroll=unroll)
824844
self.maps[key] = map_
825845
return map_
826846

@@ -854,7 +874,11 @@ def wrapper_args(self):
854874
args.extend(self.arguments)
855875
# maps are refcounted
856876
for map_ in self.maps.values():
857-
args.append(map_.values)
877+
# But we don't need to emit stuff for PMaps because they
878+
# are a Map (already seen) + a permutation (encoded in the
879+
# indexing).
880+
if not isinstance(map_, PMap):
881+
args.append(map_.values)
858882
return tuple(args)
859883

860884
def kernel_call(self):

test/unit/test_indirect_loop.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,45 @@ def test_mixed_non_mixed_dat_itspace(self, mdat, mmap, iterset):
278278
assert all(mdat[0].data == 1.0) and mdat[1].data == 4096.0
279279

280280

281+
def test_permuted_map():
    """Copy through a plain map (read) into a permuted view of it (write).

    The expected result is built with NumPy by scattering the gathered
    source values to the positions given by the permuted map entries.
    The source Dat must be left untouched by the loop.
    """
    source_set = op2.Set(1)
    target_set = op2.Set(4)
    src = op2.Dat(op2.DataSet(target_set, 1), dtype=np.int32)
    dst = op2.Dat(op2.DataSet(target_set, 1), dtype=np.int32)
    src.data[:] = np.arange(4, dtype=np.int32)
    kernel = op2.Kernel("""
void copy(int *to, const int * restrict from) {
for (int i = 0; i < 4; i++) { to[i] = from[i]; }
}""", "copy")
    plain = op2.Map(source_set, target_set, 4, values=[1, 2, 3, 0])
    permuted = op2.PermutedMap(plain, [3, 2, 0, 1])
    op2.par_loop(kernel, source_set,
                 dst(op2.WRITE, permuted),
                 src(op2.READ, plain))
    # Reference result: gather via the plain map, scatter via the
    # permuted one.
    expected = np.empty_like(src.data)
    scatter_to = plain.values[..., permuted.permutation]
    expected[scatter_to] = src.data[plain.values]
    assert (src.data == np.arange(4, dtype=np.int32)).all()
    assert (dst.data == expected).all()
298+
299+
300+
def test_permuted_map_both():
    """Copy between two different permuted views of the same base map.

    Both the read and the write argument use a ``PermutedMap`` built on
    one underlying map, so two distinct permutations of the same map
    values are live in a single parallel loop.
    """
    source_set = op2.Set(1)
    target_set = op2.Set(4)
    src = op2.Dat(op2.DataSet(target_set, 1), dtype=np.int32)
    dst = op2.Dat(op2.DataSet(target_set, 1), dtype=np.int32)
    src.data[:] = np.arange(4, dtype=np.int32)
    kernel = op2.Kernel("""
void copy(int *to, const int * restrict from) {
for (int i = 0; i < 4; i++) { to[i] = from[i]; }
}""", "copy")
    base = op2.Map(source_set, target_set, 4, values=[0, 2, 1, 3])
    write_map = op2.PermutedMap(base, [3, 2, 1, 0])
    read_map = op2.PermutedMap(base, [0, 2, 3, 1])
    op2.par_loop(kernel, source_set,
                 dst(op2.WRITE, write_map),
                 src(op2.READ, read_map))
    # Reference result: gather via the read permutation, scatter via
    # the write permutation.
    expected = np.empty_like(src.data)
    scatter_to = base.values[..., write_map.permutation]
    gather_from = base.values[..., read_map.permutation]
    expected[scatter_to] = src.data[gather_from]
    assert (src.data == np.arange(4, dtype=np.int32)).all()
    assert (dst.data == expected).all()
318+
319+
281320
if __name__ == '__main__':
282321
import os
283322
pytest.main(os.path.abspath(__file__))

0 commit comments

Comments
 (0)