Skip to content

Commit d8c7509

Browse files
committed
Add test using permuted maps
1 parent b6710b5 commit d8c7509

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

test/unit/test_indirect_loop.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,45 @@ def test_mixed_non_mixed_dat_itspace(self, mdat, mmap, iterset):
278278
assert all(mdat[0].data == 1.0) and mdat[1].data == 4096.0
279279

280280

281+
def test_permuted_map():
282+
fromset = op2.Set(1)
283+
toset = op2.Set(4)
284+
d1 = op2.Dat(op2.DataSet(toset, 1), dtype=np.int32)
285+
d2 = op2.Dat(op2.DataSet(toset, 1), dtype=np.int32)
286+
d1.data[:] = np.arange(4, dtype=np.int32)
287+
k = op2.Kernel("""
288+
void copy(int *to, const int * restrict from) {
289+
for (int i = 0; i < 4; i++) { to[i] = from[i]; }
290+
}""", "copy")
291+
m1 = op2.Map(fromset, toset, 4, values=[1, 2, 3, 0])
292+
m2 = op2.PermutedMap(m1, [3, 2, 0, 1])
293+
op2.par_loop(k, fromset, d2(op2.WRITE, m2), d1(op2.READ, m1))
294+
expect = np.empty_like(d1.data)
295+
expect[m1.values[..., m2.permutation]] = d1.data[m1.values]
296+
assert (d1.data == np.arange(4, dtype=np.int32)).all()
297+
assert (d2.data == expect).all()
298+
299+
300+
def test_permuted_map_both():
301+
fromset = op2.Set(1)
302+
toset = op2.Set(4)
303+
d1 = op2.Dat(op2.DataSet(toset, 1), dtype=np.int32)
304+
d2 = op2.Dat(op2.DataSet(toset, 1), dtype=np.int32)
305+
d1.data[:] = np.arange(4, dtype=np.int32)
306+
k = op2.Kernel("""
307+
void copy(int *to, const int * restrict from) {
308+
for (int i = 0; i < 4; i++) { to[i] = from[i]; }
309+
}""", "copy")
310+
m1 = op2.Map(fromset, toset, 4, values=[0, 2, 1, 3])
311+
m2 = op2.PermutedMap(m1, [3, 2, 1, 0])
312+
m3 = op2.PermutedMap(m1, [0, 2, 3, 1])
313+
op2.par_loop(k, fromset, d2(op2.WRITE, m2), d1(op2.READ, m3))
314+
expect = np.empty_like(d1.data)
315+
expect[m1.values[..., m2.permutation]] = d1.data[m1.values[..., m3.permutation]]
316+
assert (d1.data == np.arange(4, dtype=np.int32)).all()
317+
assert (d2.data == expect).all()
318+
319+
281320
if __name__ == '__main__':
282321
import os
283322
pytest.main(os.path.abspath(__file__))

0 commit comments

Comments
 (0)