@@ -1478,6 +1478,95 @@ class FooTag(Tag):
14781478 assert t_unit .default_entrypoint .inames ["i_0" ].tags_of_type (FooTag ) # fails
14791479
14801480
1481+ def test_reindexing_strided_access (ctx_factory ):
1482+ import islpy as isl
1483+
1484+ if not hasattr (isl .Set , "card" ):
1485+ pytest .skip ("No barvinok support" )
1486+
1487+ ctx = ctx_factory ()
1488+
1489+ tunit = lp .make_kernel (
1490+ "{[i, j]: 0<=j,i<10}" ,
1491+ """
1492+ <> tmp[2*i, 2*j] = a[i, j]
1493+ out[i, j] = tmp[2*i, 2*j]**2
1494+ """ )
1495+
1496+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1497+ ref_tunit = tunit
1498+
1499+ knl = lp .reindex_temporary_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1500+ "tmp" )
1501+ tunit = tunit .with_kernel (knl )
1502+
1503+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1504+ assert tv .shape == (100 ,)
1505+
1506+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit )
1507+
1508+
1509+ def test_reindexing_figurate (ctx_factory ):
1510+ import islpy as isl
1511+
1512+ if not hasattr (isl .Set , "card" ):
1513+ pytest .skip ("No barvinok support" )
1514+
1515+ ctx = ctx_factory ()
1516+
1517+ tunit = lp .make_kernel (
1518+ "{[i, j]: 0<=j<=i<10}" ,
1519+ """
1520+ <> tmp[2*i, 2*j] = a[i, j]
1521+ out[i, j] = tmp[2*i, 2*j]**2
1522+ """ )
1523+
1524+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1525+ ref_tunit = tunit
1526+
1527+ knl = lp .reindex_temporary_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1528+ "tmp" )
1529+ tunit = tunit .with_kernel (knl )
1530+
1531+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1532+ assert tv .shape == (55 ,)
1533+
1534+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit )
1535+
1536+
1537+ def test_reindexing_figurate_parametric_shape (ctx_factory ):
1538+ import islpy as isl
1539+ from loopy .symbolic import parse
1540+
1541+ if not hasattr (isl .Set , "card" ):
1542+ pytest .skip ("No barvinok support" )
1543+
1544+ ctx = ctx_factory ()
1545+
1546+ tunit = lp .make_kernel (
1547+ "{[i, j]: 0<=j<=i<n}" ,
1548+ """
1549+ <> tmp[i, j] = a[i, j]
1550+ out[i, j] = tmp[i, j]**2
1551+ """ ,
1552+ assumptions = "n > 0" ,
1553+ )
1554+
1555+ tunit = lp .add_dtypes (tunit , {"a" : "float64" })
1556+ tunit = lp .set_temporary_address_space (tunit , "tmp" ,
1557+ lp .AddressSpace .GLOBAL )
1558+ ref_tunit = tunit
1559+
1560+ knl = lp .reindex_temporary_using_seghir_loechner_scheme (tunit .default_entrypoint ,
1561+ "tmp" )
1562+ tunit = tunit .with_kernel (knl )
1563+
1564+ tv , = tunit .default_entrypoint .temporary_variables .values ()
1565+ assert tv .shape == (parse ("(n + n**2) // 2" ),)
1566+
1567+ lp .auto_test_vs_ref (ref_tunit , ctx , tunit , parameters = {"n" : 20 })
1568+
1569+
14811570if __name__ == "__main__" :
14821571 if len (sys .argv ) > 1 :
14831572 exec (sys .argv [1 ])
0 commit comments