Commit b748091
[SYSTEMDS-3333] Add a DP optimization for matrix chains with transposes
This adds a new HOP rewrite rule,
RewriteMatrixMultChainWithTransOptimization.java, to find the optimal
execution plan for matrix multiplication chains containing transposes.
Previously, these chains were optimized using a simple heuristic that
just pushes transposes down from t(A %*% B) -> t(B) %*% t(A), which
fails to be the optimal plan in some instances especially with large
matrices.
An example would be R = t(A %*% B) %*% C with dimensions A = [16, 23], B
= [23, 22], C = [16, 34]
which would be according to the old rewrite class solved with (t(B) %*%
t(A)) %*% C -> costs: t(B) -> 23*22 + t(A) -> 16 * 23 + t(B) %*% t(A) ->
22*23*16 + [...] %*% C -> 22*16*34 = 20938 FLOPs
Optimal would be simply: t(A %*% B) %*% C - costs: A %*% B -> 16*23*22 +
t(A %*% B) -> 16*22 + [...] %*% C -> 22*16*34 = 20416 FLOPs - difference
gets larger with higher matrix dimensions.
To solve this, we applied a DP Algorithm with a Memo Table containing
Plans without transposing and Plans containing Transposing subchains
calculating wether an algebraic transpose pushdown or direct transpose
operation is cheaper.
This also includes 24 automated DML test cases asserting intermediate
HOP dimensions to validate optimal parenthesization and transpose
placement. = 20938 FLOPs
Optimal would be simply: t(A %*% B) %*% C - costs: A %*% B -> 16*23*22 +
t(A %*% B) -> 16*22 + [...] %*% C -> 22*16*34 = 20416 FLOPs - difference
gets larger with higher matrix dimensions.
To solve this, we applied a DP Algorithm with a Memo Table containing
Plans without transposing and Plans containing Transposing subchains
calculating wether an algebraic transpose pushdown or direct transpose
operation is cheaper.
This also includes 24 automated DML test cases asserting intermediate
HOP dimensions to validate optimal parenthesization and transpose
placement.
Closes #2465.1 parent 780d790 commit b748091
29 files changed
Lines changed: 1644 additions & 1 deletion
File tree
- src
- main/java/org/apache/sysds/hops
- rewrite
- test
- java/org/apache/sysds/test/functions/rewrite
- scripts/functions/rewrite/mmchain
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1577 | 1577 | | |
1578 | 1578 | | |
1579 | 1579 | | |
1580 | | - | |
1581 | 1580 | | |
1582 | 1581 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
204 | 204 | | |
205 | 205 | | |
206 | 206 | | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
207 | 212 | | |
208 | 213 | | |
209 | 214 | | |
| |||
Lines changed: 3 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
136 | 136 | | |
137 | 137 | | |
138 | 138 | | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
139 | 142 | | |
140 | 143 | | |
141 | 144 | | |
| |||
0 commit comments