Skip to content

Commit c2d0e5a

Browse files
smelihportakaljanniklinde
authored andcommitted
[SYSTEMDS-3941] New Algebraic Rewrites
Closes #2460.
1 parent d45ff6d commit c2d0e5a

17 files changed

Lines changed: 615 additions & 2 deletions

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
184184
hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector
185185
hi = simplifyLowerTriExtraction(hop, hi, i); //e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
186186
hi = simplifyConstantCumsum(hop, hi, i); //e.g., cumsum(matrix(1/n,n,1)) -> seq(1/n, 1, 1/n)
187+
hi = simplifySumConstantMatrix(hop, hi, i); //e.g., sum(matrix(a,rows=b,cols=c)) -> a*b*c
187188
hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector
188189
hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
189190
if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1273,6 +1274,31 @@ private static Hop simplifyConstantCumsum(Hop parent, Hop hi, int pos) {
12731274
}
12741275
return hi;
12751276
}
1277+
1278+
private static Hop simplifySumConstantMatrix(Hop parent, Hop hi, int pos) {
1279+
//pattern: sum(matrix(a, rows=b, cols=c)) -> a*b*c
1280+
if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)
1281+
&& HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(0))
1282+
&& hi.getInput(0).dimsKnown()
1283+
&& hi.getInput(0).getDim1() >= 1
1284+
&& hi.getInput(0).getDim2() >= 1
1285+
&& hi.getInput(0).getParent().size() == 1 )
1286+
{
1287+
DataGenOp datagen = (DataGenOp) hi.getInput(0);
1288+
Hop constVal = datagen.getConstantValue();
1289+
Hop rows = new LiteralOp(datagen.getDim1());
1290+
Hop cols = new LiteralOp(datagen.getDim2());
1291+
1292+
Hop hnew = HopRewriteUtils.createBinary(
1293+
HopRewriteUtils.createBinary(constVal, rows, OpOp2.MULT), cols, OpOp2.MULT);
1294+
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
1295+
HopRewriteUtils.cleanupUnreferenced(hi, datagen);
1296+
1297+
hi = hnew;
1298+
LOG.debug("Applied simplifySumConstantMatrix (line "+hi.getBeginLine()+").");
1299+
}
1300+
return hi;
1301+
}
12761302

12771303
private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, int pos)
12781304
{

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
170170
hi = pushdownDetMultOperation(hop, hi, i); //e.g., det(X%*%Y) -> det(X)*det(Y)
171171
hi = pushdownDetScalarMatrixMultOperation(hop, hi, i); //e.g., det(lambda*X) -> lambda^nrow(X)*det(X)
172172
hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lambda*X) -> lambda*sum(X)
173+
hi = pushdownRowSumBinaryMult(hop, hi, i); //e.g., rowSums(lambda*X) -> lambda*rowSums(X)
174+
hi = pushdownColSumBinaryMult(hop, hi, i); //e.g., colSums(lambda*X) -> lambda*colSums(X)
173175
hi = pullupAbs(hop, hi, i); //e.g., abs(X)*abs(Y) --> abs(X*Y)
174176
hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
175177
hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
@@ -1447,6 +1449,58 @@ private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) {
14471449
return hi;
14481450
}
14491451

1452+
private static Hop pushdownRowSumBinaryMult(Hop parent, Hop hi, int pos ) {
1453+
//pattern: rowSums(lamda*X) -> lamda*rowSums(X)
1454+
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.Row
1455+
&& ((AggUnaryOp)hi).getOp()==AggOp.SUM // only one parent which is the rowSums
1456+
&& HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT, 1)
1457+
&& ((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR && hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX)
1458+
||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX && hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR)))
1459+
{
1460+
Hop operand1 = hi.getInput(0).getInput(0);
1461+
Hop operand2 = hi.getInput(0).getInput(1);
1462+
1463+
//check which operand is the Scalar and which is the matrix
1464+
Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2;
1465+
Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2;
1466+
1467+
AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.Row);
1468+
Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT);
1469+
1470+
HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
1471+
1472+
LOG.debug("Applied pushdownRowSumBinaryMult (line "+hi.getBeginLine()+").");
1473+
return bop;
1474+
}
1475+
return hi;
1476+
}
1477+
1478+
private static Hop pushdownColSumBinaryMult(Hop parent, Hop hi, int pos ) {
1479+
//pattern: colSums(lamda*X) -> lamda*colSums(X)
1480+
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.Col
1481+
&& ((AggUnaryOp)hi).getOp()==AggOp.SUM // only one parent which is the colSums
1482+
&& HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT, 1)
1483+
&& ((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR && hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX)
1484+
||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX && hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR)))
1485+
{
1486+
Hop operand1 = hi.getInput(0).getInput(0);
1487+
Hop operand2 = hi.getInput(0).getInput(1);
1488+
1489+
//check which operand is the Scalar and which is the matrix
1490+
Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? operand1 : operand2;
1491+
Hop matrix = (operand1.getDataType()==DataType.MATRIX) ? operand1 : operand2;
1492+
1493+
AggUnaryOp aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.Col);
1494+
Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, OpOp2.MULT);
1495+
1496+
HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
1497+
1498+
LOG.debug("Applied pushdownColSumBinaryMult (line "+hi.getBeginLine()+").");
1499+
return bop;
1500+
}
1501+
return hi;
1502+
}
1503+
14501504
private static Hop pullupAbs(Hop parent, Hop hi, int pos ) {
14511505
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
14521506
&& HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS)

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ private void testRewriteFusedRand( String testname, String pdf, boolean rewrites
121121
//compare matrices
122122
Double ret = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
123123
if( testname.equals(TEST_NAME1) )
124-
Assert.assertEquals("Wrong result", Double.valueOf(rows), ret);
124+
Assert.assertEquals("Wrong result", Double.valueOf(rows*cols), ret);
125125
else if( testname.equals(TEST_NAME2) )
126126
Assert.assertEquals("Wrong result", Double.valueOf(Math.pow(rows*cols, 2)), ret);
127127
else if( testname.equals(TEST_NAME3) )
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import java.util.HashMap;
23+
24+
import org.junit.Test;
25+
import org.apache.sysds.hops.OptimizerUtils;
26+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
27+
import org.apache.sysds.test.AutomatedTestBase;
28+
import org.apache.sysds.test.TestConfiguration;
29+
import org.apache.sysds.test.TestUtils;
30+
import org.apache.sysds.utils.Statistics;
31+
import org.junit.Assert;
32+
33+
public class RewritePushdownColSumBinaryMultTest extends AutomatedTestBase
34+
{
35+
private static final String TEST_NAME1 = "RewritePushdownColSumBinaryMult";
36+
private static final String TEST_NAME2 = "RewritePushdownColSumBinaryMult2";
37+
38+
private static final String TEST_DIR = "functions/rewrite/";
39+
private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownColSumBinaryMultTest.class.getSimpleName() + "/";
40+
41+
@Override
42+
public void setUp() {
43+
TestUtils.clearAssertionInformation();
44+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }));
45+
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }));
46+
}
47+
48+
@Test
49+
public void testPushdownColSumBinaryMultNoRewrite() {
50+
testRewritePushdownColSumBinaryMult(TEST_NAME1, false);
51+
}
52+
53+
@Test
54+
public void testPushdownColSumBinaryMultRewrite() {
55+
testRewritePushdownColSumBinaryMult(TEST_NAME1, true);
56+
}
57+
58+
@Test
59+
public void testPushdownColSumBinaryMultNoRewrite2() {
60+
testRewritePushdownColSumBinaryMult(TEST_NAME2, false);
61+
}
62+
63+
@Test
64+
public void testPushdownColSumBinaryMultRewrite2() {
65+
testRewritePushdownColSumBinaryMult(TEST_NAME2, true);
66+
}
67+
68+
private void testRewritePushdownColSumBinaryMult(String testname, boolean rewrites) {
69+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
70+
71+
try {
72+
TestConfiguration config = getTestConfiguration(testname);
73+
loadTestConfiguration(config);
74+
75+
String HOME = SCRIPT_DIR + TEST_DIR;
76+
fullDMLScriptName = HOME + testname + ".dml";
77+
programArgs = new String[] { "-stats", "-args", output("R") };
78+
79+
fullRScriptName = HOME + testname + ".R";
80+
rCmd = getRCmd(inputDir(), expectedDir());
81+
82+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
83+
84+
runTest(true, false, null, -1);
85+
runRScript(true);
86+
87+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
88+
HashMap<CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
89+
TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "DML", "R");
90+
91+
if(rewrites)
92+
Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("n*"));
93+
else
94+
Assert.assertEquals(2, Statistics.getCPHeavyHitterCount("*"));
95+
}
96+
finally {
97+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
98+
}
99+
}
100+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import java.util.HashMap;
23+
24+
import org.junit.Test;
25+
import org.apache.sysds.hops.OptimizerUtils;
26+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
27+
import org.apache.sysds.test.AutomatedTestBase;
28+
import org.apache.sysds.test.TestConfiguration;
29+
import org.apache.sysds.test.TestUtils;
30+
import org.apache.sysds.utils.Statistics;
31+
import org.junit.Assert;
32+
33+
public class RewritePushdownRowSumBinaryMultTest extends AutomatedTestBase
34+
{
35+
private static final String TEST_NAME1 = "RewritePushdownRowSumBinaryMult";
36+
private static final String TEST_NAME2 = "RewritePushdownRowSumBinaryMult2";
37+
38+
private static final String TEST_DIR = "functions/rewrite/";
39+
private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownRowSumBinaryMultTest.class.getSimpleName() + "/";
40+
41+
@Override
42+
public void setUp() {
43+
TestUtils.clearAssertionInformation();
44+
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }));
45+
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }));
46+
}
47+
48+
@Test
49+
public void testPushdownRowSumBinaryMultNoRewrite() {
50+
testRewritePushdownRowSumBinaryMult(TEST_NAME1, false);
51+
}
52+
53+
@Test
54+
public void testPushdownRowSumBinaryMultRewrite() {
55+
testRewritePushdownRowSumBinaryMult(TEST_NAME1, true);
56+
}
57+
58+
@Test
59+
public void testPushdownRowSumBinaryMultNoRewrite2() {
60+
testRewritePushdownRowSumBinaryMult(TEST_NAME2, false);
61+
}
62+
63+
@Test
64+
public void testPushdownRowSumBinaryMultRewrite2() {
65+
testRewritePushdownRowSumBinaryMult(TEST_NAME2, true);
66+
}
67+
68+
private void testRewritePushdownRowSumBinaryMult(String testname, boolean rewrites) {
69+
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
70+
71+
try {
72+
TestConfiguration config = getTestConfiguration(testname);
73+
loadTestConfiguration(config);
74+
75+
String HOME = SCRIPT_DIR + TEST_DIR;
76+
fullDMLScriptName = HOME + testname + ".dml";
77+
programArgs = new String[] { "-stats", "-args", output("R") };
78+
79+
fullRScriptName = HOME + testname + ".R";
80+
rCmd = getRCmd(inputDir(), expectedDir());
81+
82+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
83+
84+
runTest(true, false, null, -1);
85+
runRScript(true);
86+
87+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
88+
HashMap<CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
89+
TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "DML", "R");
90+
91+
if(rewrites)
92+
Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("n*"));
93+
else
94+
Assert.assertEquals(2, Statistics.getCPHeavyHitterCount("*"));
95+
}
96+
finally {
97+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)