Skip to content

Commit b748091

Browse files
Elmanjhgmboehm7
authored andcommitted
[SYSTEMDS-3333] Add a DP optimization for matrix chains with transposes
This adds a new HOP rewrite rule, RewriteMatrixMultChainWithTransOptimization.java, to find the optimal execution plan for matrix multiplication chains containing transposes. Previously, these chains were optimized using a simple heuristic that just pushes transposes down from t(A %*% B) -> t(B) %*% t(A), which fails to be the optimal plan in some instances especially with large matrices. An example would be R = t(A %*% B) %*% C with dimensions A = [16, 23], B = [23, 22], C = [16, 34] which would be according to the old rewrite class solved with (t(B) %*% t(A)) %*% C -> costs: t(B) -> 23*22 + t(A) -> 16 * 23 + t(B) %*% t(A) -> 22*23*16 + [...] %*% C -> 22*16*34 = 20938 FLOPs Optimal would be simply: t(A %*% B) %*% C - costs: A %*% B -> 16*23*22 + t(A %*% B) -> 16*22 + [...] %*% C -> 22*16*34 = 20416 FLOPs - difference gets larger with higher matrix dimensions. To solve this, we applied a DP Algorithm with a Memo Table containing Plans without transposing and Plans containing Transposing subchains calculating wether an algebraic transpose pushdown or direct transpose operation is cheaper. This also includes 24 automated DML test cases asserting intermediate HOP dimensions to validate optimal parenthesization and transpose placement. = 20938 FLOPs Optimal would be simply: t(A %*% B) %*% C - costs: A %*% B -> 16*23*22 + t(A %*% B) -> 16*22 + [...] %*% C -> 22*16*34 = 20416 FLOPs - difference gets larger with higher matrix dimensions. To solve this, we applied a DP Algorithm with a Memo Table containing Plans without transposing and Plans containing Transposing subchains calculating wether an algebraic transpose pushdown or direct transpose operation is cheaper. This also includes 24 automated DML test cases asserting intermediate HOP dimensions to validate optimal parenthesization and transpose placement. Closes #2465.
1 parent 780d790 commit b748091

29 files changed

Lines changed: 1644 additions & 1 deletion

pom.xml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,5 @@
15771577
<artifactId>fastdoubleparser</artifactId>
15781578
<version>0.9.0</version>
15791579
</dependency>
1580-
15811580
</dependencies>
15821581
</project>

src/main/java/org/apache/sysds/hops/OptimizerUtils.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ public enum MemoryManager {
204204
* ALLOW_SUM_PRODUCT_REWRITES.
205205
*/
206206
public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false;
207+
208+
/**
209+
* Enables a DPSize inspired algorithm rewrite for MMChain with transposes
210+
*/
211+
public static boolean ALLOW_NEW_MMCHAIN_REWRITE = false;
207212

208213
/**
209214
* Enables a specific hop dag rewrite that splits hop dags after csv persistent reads with

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
136136
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
137137
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()); //dependency: cse
138138
}
139+
if( OptimizerUtils.ALLOW_NEW_MMCHAIN_REWRITE ) {
140+
_dagRuleSet.add( new RewriteMatrixMultChainWithTransOptimization() );
141+
}
139142
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
140143
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
141144
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationSparse() ); //dependency: cse

0 commit comments

Comments
 (0)