@@ -1668,6 +1668,95 @@ def test_remove_predicates_from_insn():
16681668 assert t_unit == ref_t_unit
16691669
16701670
1671+ def test_reindexing_strided_access (ctx_factory ):
1672+ import islpy as isl
1673+
1674+ if not hasattr (isl .Set , "card" ):
1675+ pytest .skip ("No barvinok support" )
1676+
1677+ ctx = ctx_factory ()
1678+
1679+ tunit = lp .make_kernel (
1680+ "{[i, j]: 0<=j,i<10}" ,
1681+ """
1682+ <> tmp[2*i, 2*j] = a[i, j]
1683+ out[i, j] = tmp[2*i, 2*j]**2
1684+ """ )
1685+
1686+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1687+ ref_tunit = tunit
1688+
1689+ knl = lp .reindex_temporary_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1690+ "tmp" )
1691+ tunit = tunit .with_kernel (knl )
1692+
1693+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1694+ assert tv .shape == (100 ,)
1695+
1696+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit )
1697+
1698+
1699+ def test_reindexing_figurate (ctx_factory ):
1700+ import islpy as isl
1701+
1702+ if not hasattr (isl .Set , "card" ):
1703+ pytest .skip ("No barvinok support" )
1704+
1705+ ctx = ctx_factory ()
1706+
1707+ tunit = lp .make_kernel (
1708+ "{[i, j]: 0<=j<=i<10}" ,
1709+ """
1710+ <> tmp[2*i, 2*j] = a[i, j]
1711+ out[i, j] = tmp[2*i, 2*j]**2
1712+ """ )
1713+
1714+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1715+ ref_tunit = tunit
1716+
1717+ knl = lp .reindex_temporary_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1718+ "tmp" )
1719+ tunit = tunit .with_kernel (knl )
1720+
1721+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1722+ assert tv .shape == (55 ,)
1723+
1724+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit )
1725+
1726+
1727+ def test_reindexing_figurate_parametric_shape (ctx_factory ):
1728+ import islpy as isl
1729+ from loopy .symbolic import parse
1730+
1731+ if not hasattr (isl .Set , "card" ):
1732+ pytest .skip ("No barvinok support" )
1733+
1734+ ctx = ctx_factory ()
1735+
1736+ tunit = lp .make_kernel (
1737+ "{[i, j]: 0<=j<=i<n}" ,
1738+ """
1739+ <> tmp[i, j] = a[i, j]
1740+ out[i, j] = tmp[i, j]**2
1741+ """ ,
1742+ assumptions = "n > 0" ,
1743+ )
1744+
1745+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1746+ tunit = lp .set_temporary_address_space (tunit , "tmp" ,
1747+ lp .AddressSpace .GLOBAL )
1748+ ref_tunit = tunit
1749+
1750+ knl = lp .reindex_temporary_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1751+ "tmp" )
1752+ tunit = tunit .with_kernel (knl )
1753+
1754+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1755+ assert tv .shape == (parse ("(n + n**2) // 2" ),)
1756+
1757+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit , parameters = {"n" : 20 })
1758+
1759+
16711760if __name__ == "__main__" :
16721761 if len (sys .argv ) > 1 :
16731762 exec (sys .argv [1 ])
0 commit comments