Skip to content

Commit d94668d

Browse files
kaushikcfdinducer
authored andcommitted
test reindex_using_seghir_loechner_scheme
1 parent 9c9c41d commit d94668d

1 file changed

Lines changed: 89 additions & 0 deletions

File tree

test/test_transform.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,95 @@ def test_remove_predicates_from_insn():
16681668
assert t_unit == ref_t_unit
16691669

16701670

1671+
def test_reindexing_strided_access(ctx_factory):
1672+
import islpy as isl
1673+
1674+
if not hasattr(isl.Set, "card"):
1675+
pytest.skip("No barvinok support")
1676+
1677+
ctx = ctx_factory()
1678+
1679+
tunit = lp.make_kernel(
1680+
"{[i, j]: 0<=j,i<10}",
1681+
"""
1682+
<> tmp[2*i, 2*j] = a[i, j]
1683+
out[i, j] = tmp[2*i, 2*j]**2
1684+
""")
1685+
1686+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1687+
ref_tunit = tunit
1688+
1689+
knl = lp.reindex_temporary_using_seghir_loechner_scheme(tunit.default_entrypoint,
1690+
"tmp")
1691+
tunit = tunit.with_kernel(knl)
1692+
1693+
tv, = tunit.default_entrypoint.temporary_variables.values()
1694+
assert tv.shape == (100,)
1695+
1696+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
1697+
1698+
1699+
def test_reindexing_figurate(ctx_factory):
1700+
import islpy as isl
1701+
1702+
if not hasattr(isl.Set, "card"):
1703+
pytest.skip("No barvinok support")
1704+
1705+
ctx = ctx_factory()
1706+
1707+
tunit = lp.make_kernel(
1708+
"{[i, j]: 0<=j<=i<10}",
1709+
"""
1710+
<> tmp[2*i, 2*j] = a[i, j]
1711+
out[i, j] = tmp[2*i, 2*j]**2
1712+
""")
1713+
1714+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1715+
ref_tunit = tunit
1716+
1717+
knl = lp.reindex_temporary_using_seghir_loechner_scheme(tunit.default_entrypoint,
1718+
"tmp")
1719+
tunit = tunit.with_kernel(knl)
1720+
1721+
tv, = tunit.default_entrypoint.temporary_variables.values()
1722+
assert tv.shape == (55,)
1723+
1724+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
1725+
1726+
1727+
def test_reindexing_figurate_parametric_shape(ctx_factory):
1728+
import islpy as isl
1729+
from loopy.symbolic import parse
1730+
1731+
if not hasattr(isl.Set, "card"):
1732+
pytest.skip("No barvinok support")
1733+
1734+
ctx = ctx_factory()
1735+
1736+
tunit = lp.make_kernel(
1737+
"{[i, j]: 0<=j<=i<n}",
1738+
"""
1739+
<> tmp[i, j] = a[i, j]
1740+
out[i, j] = tmp[i, j]**2
1741+
""",
1742+
assumptions="n > 0",
1743+
)
1744+
1745+
tunit = lp.add_dtypes(tunit, {"a": "float64"})
1746+
tunit = lp.set_temporary_address_space(tunit, "tmp",
1747+
lp.AddressSpace.GLOBAL)
1748+
ref_tunit = tunit
1749+
1750+
knl = lp.reindex_temporary_using_seghir_loechner_scheme(tunit.default_entrypoint,
1751+
"tmp")
1752+
tunit = tunit.with_kernel(knl)
1753+
1754+
tv, = tunit.default_entrypoint.temporary_variables.values()
1755+
assert tv.shape == (parse("(n + n**2) // 2"),)
1756+
1757+
lp.auto_test_vs_ref(ref_tunit, ctx, tunit, parameters={"n": 20})
1758+
1759+
16711760
if __name__ == "__main__":
16721761
if len(sys.argv) > 1:
16731762
exec(sys.argv[1])

0 commit comments

Comments
 (0)