Skip to content

Commit 5305f47

Browse files
committed
test reindex_using_seghir_loechner_scheme
1 parent 9be2056 commit 5305f47

1 file changed

Lines changed: 89 additions & 0 deletions

File tree

test/test_transform.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,95 @@ class FooTag(Tag):
14781478
assert t_unit.default_entrypoint.inames["i_0"].tags_of_type(FooTag) # fails
14791479

14801480

1481+
def test_reindexing_strided_access(ctx_factory):
1482+
import islpy as isl
1483+
1484+
if not hasattr(isl.Set, "card"):
1485+
pytest.skip("No barvinok support")
1486+
1487+
ctx = ctx_factory()
1488+
1489+
tunit = lp.make_kernel(
1490+
"{[i, j]: 0<=j,i<10}",
1491+
"""
1492+
<> tmp[2*i, 2*j] = a[i, j]
1493+
out[i, j] = tmp[2*i, 2*j]**2
1494+
""")
1495+
1496+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1497+
ref_tunit = tunit
1498+
1499+
knl = lp.reindex_temporary_using_seghir_loechner_scheme(tunit.default_entrypoint,
1500+
"tmp")
1501+
tunit = tunit.with_kernel(knl)
1502+
1503+
tv, = tunit.default_entrypoint.temporary_variables.values()
1504+
assert tv.shape == (100,)
1505+
1506+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
1507+
1508+
1509+
def test_reindexing_figurate(ctx_factory):
1510+
import islpy as isl
1511+
1512+
if not hasattr(isl.Set, "card"):
1513+
pytest.skip("No barvinok support")
1514+
1515+
ctx = ctx_factory()
1516+
1517+
tunit = lp.make_kernel(
1518+
"{[i, j]: 0<=j<=i<10}",
1519+
"""
1520+
<> tmp[2*i, 2*j] = a[i, j]
1521+
out[i, j] = tmp[2*i, 2*j]**2
1522+
""")
1523+
1524+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1525+
ref_tunit = tunit
1526+
1527+
knl = lp.reindex_temporary_using_seghir_loechner_scheme(tunit.default_entrypoint,
1528+
"tmp")
1529+
tunit = tunit.with_kernel(knl)
1530+
1531+
tv, = tunit.default_entrypoint.temporary_variables.values()
1532+
assert tv.shape == (55,)
1533+
1534+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
1535+
1536+
1537+
def test_reindexing_figurate_parametric_shape(ctx_factory):
1538+
import islpy as isl
1539+
from loopy.symbolic import parse
1540+
1541+
if not hasattr(isl.Set, "card"):
1542+
pytest.skip("No barvinok support")
1543+
1544+
ctx = ctx_factory()
1545+
1546+
tunit = lp.make_kernel(
1547+
"{[i, j]: 0<=j<=i<n}",
1548+
"""
1549+
<> tmp[i, j] = a[i, j]
1550+
out[i, j] = tmp[i, j]**2
1551+
""",
1552+
assumptions="n > 0",
1553+
)
1554+
1555+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1556+
tunit = lp.set_temporary_address_space(tunit, "tmp",
1557+
lp.AddressSpace.GLOBAL)
1558+
ref_tunit = tunit
1559+
1560+
knl = lp.reindex_temporary_using_seghir_loechner_scheme(tunit.default_entrypoint,
1561+
"tmp")
1562+
tunit = tunit.with_kernel(knl)
1563+
1564+
tv, = tunit.default_entrypoint.temporary_variables.values()
1565+
assert tv.shape == (parse("(n + n**2) // 2"),)
1566+
1567+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit, parameters={"n": 20})
1568+
1569+
14811570
if __name__ == "__main__":
14821571
if len(sys.argv) > 1:
14831572
exec(sys.argv[1])

0 commit comments

Comments
 (0)