Skip to content

Commit dbe1ded

Browse files
committed
test reindex_using_seghir_loechner_scheme
1 parent cb84277 commit dbe1ded

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

test/test_transform.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,62 @@ class FooTag(Tag):
14781478
assert t_unit.default_entrypoint.inames["i_0"].tags_of_type(FooTag) # fails
14791479

14801480

1481+
def test_reindexing_strided_access(ctx_factory):
1482+
import islpy as isl
1483+
1484+
if not hasattr(isl.Set, "card"):
1485+
pytest.skip("No barvinok support")
1486+
1487+
ctx = ctx_factory()
1488+
1489+
tunit = lp.make_kernel(
1490+
"{[i, j]: 0<=j,i<10}",
1491+
"""
1492+
<> tmp[2*i, 2*j] = a[i, j]
1493+
out[i, j] = tmp[2*i, 2*j]**2
1494+
""")
1495+
1496+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1497+
ref_tunit = tunit
1498+
1499+
knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint,
1500+
"tmp")
1501+
tunit = tunit.with_kernel(knl)
1502+
1503+
tv, = tunit.default_entrypoint.temporary_variables.values()
1504+
assert tv.shape == (100,)
1505+
1506+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
1507+
1508+
1509+
def test_reindexing_figurate(ctx_factory):
1510+
import islpy as isl
1511+
1512+
if not hasattr(isl.Set, "card"):
1513+
pytest.skip("No barvinok support")
1514+
1515+
ctx = ctx_factory()
1516+
1517+
tunit = lp.make_kernel(
1518+
"{[i, j]: 0<=j<=i<10}",
1519+
"""
1520+
<> tmp[2*i, 2*j] = a[i, j]
1521+
out[i, j] = tmp[2*i, 2*j]**2
1522+
""")
1523+
1524+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1525+
ref_tunit = tunit
1526+
1527+
knl = lp.reindex_using_seghir_loechner_scheme(tunit.default_entrypoint,
1528+
"tmp")
1529+
tunit = tunit.with_kernel(knl)
1530+
1531+
tv, = tunit.default_entrypoint.temporary_variables.values()
1532+
assert tv.shape == (55,)
1533+
1534+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
1535+
1536+
14811537
if __name__ == "__main__":
14821538
if len(sys.argv) > 1:
14831539
exec(sys.argv[1])

0 commit comments

Comments
 (0)