@@ -1478,6 +1478,62 @@ class FooTag(Tag):
14781478 assert t_unit .default_entrypoint .inames ["i_0" ].tags_of_type (FooTag ) # fails
14791479
14801480
1481+ def test_reindexing_strided_access (ctx_factory ):
1482+ import islpy as isl
1483+
1484+ if not hasattr (isl .Set , "card" ):
1485+ pytest .skip ("No barvinok support" )
1486+
1487+ ctx = ctx_factory ()
1488+
1489+ tunit = lp .make_kernel (
1490+ "{[i, j]: 0<=j,i<10}" ,
1491+ """
1492+ <> tmp[2*i, 2*j] = a[i, j]
1493+ out[i, j] = tmp[2*i, 2*j]**2
1494+ """ )
1495+
1496+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1497+ ref_tunit = tunit
1498+
1499+ knl = lp .reindex_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1500+ "tmp" )
1501+ tunit = tunit .with_kernel (knl )
1502+
1503+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1504+ assert tv .shape == (100 ,)
1505+
1506+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit )
1507+
1508+
1509+ def test_reindexing_figurate (ctx_factory ):
1510+ import islpy as isl
1511+
1512+ if not hasattr (isl .Set , "card" ):
1513+ pytest .skip ("No barvinok support" )
1514+
1515+ ctx = ctx_factory ()
1516+
1517+ tunit = lp .make_kernel (
1518+ "{[i, j]: 0<=j<=i<10}" ,
1519+ """
1520+ <> tmp[2*i, 2*j] = a[i, j]
1521+ out[i, j] = tmp[2*i, 2*j]**2
1522+ """ )
1523+
1524+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1525+ ref_tunit = tunit
1526+
1527+ knl = lp .reindex_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1528+ "tmp" )
1529+ tunit = tunit .with_kernel (knl )
1530+
1531+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1532+ assert tv .shape == (55 ,)
1533+
1534+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit )
1535+
1536+
14811537if __name__ == "__main__" :
14821538 if len (sys .argv ) > 1 :
14831539 exec (sys .argv [1 ])
0 commit comments