Skip to content

Commit 4855b08

Browse files
committed
refactor(main/hops/estim/EstimatorRowWise.java): refactor switch case to consolidate all calls to getters before the switch
1 parent dfcde4b commit 4855b08

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -286,29 +286,35 @@ public static DataCharacteristics deriveOutputCharacteristics(MMNode node, doubl
286286

287287
MMNode nodeLeft = node.getLeft();
288288
MMNode nodeRight = node.getRight();
289+
int leftNRow = nodeLeft.getRows();
290+
int leftNCol = nodeLeft.getCols();
291+
int rightNRow = nodeRight.getRows();
292+
int rightNCol = nodeRight.getCols();
289293
switch(node.getOp()) {
290294
case MM:
291-
return new MatrixCharacteristics(nodeLeft.getRows(), nodeRight.getCols(),
292-
OptimizerUtils.getNnz(nodeLeft.getRows(), nodeRight.getCols(), spOut));
295+
return new MatrixCharacteristics(leftNRow, rightNCol,
296+
OptimizerUtils.getNnz(leftNRow, rightNCol, spOut));
293297
case MULT:
294298
case PLUS:
295299
case NEQZERO:
296300
case EQZERO:
297-
return new MatrixCharacteristics(nodeLeft.getRows(), nodeLeft.getCols(),
298-
OptimizerUtils.getNnz(nodeLeft.getRows(), nodeLeft.getCols(), spOut));
301+
return new MatrixCharacteristics(leftNRow, leftNCol,
302+
OptimizerUtils.getNnz(leftNRow, leftNCol, spOut));
299303
case RBIND:
300-
return new MatrixCharacteristics(nodeLeft.getRows()+nodeLeft.getRows(), nodeLeft.getCols(),
301-
OptimizerUtils.getNnz(nodeLeft.getRows()+nodeRight.getRows(), nodeLeft.getCols(), spOut));
304+
return new MatrixCharacteristics(leftNRow+rightNRow, leftNCol,
305+
OptimizerUtils.getNnz(leftNRow+rightNRow, leftNCol, spOut));
302306
case CBIND:
303-
return new MatrixCharacteristics(nodeLeft.getRows(), nodeLeft.getCols()+nodeRight.getCols(),
304-
OptimizerUtils.getNnz(nodeLeft.getRows(), nodeLeft.getCols()+nodeRight.getCols(), spOut));
307+
return new MatrixCharacteristics(leftNRow, leftNCol+rightNCol,
308+
OptimizerUtils.getNnz(leftNRow, leftNCol+rightNCol, spOut));
305309
case DIAG:
306-
int ncol = nodeLeft.getCols()==1 ? nodeLeft.getRows() : 1;
307-
return new MatrixCharacteristics(nodeLeft.getRows(), ncol,
308-
OptimizerUtils.getNnz(nodeLeft.getRows(), ncol, spOut));
310+
int ncol = (leftNCol == 1) ? leftNRow : 1;
311+
return new MatrixCharacteristics(leftNRow, ncol,
312+
OptimizerUtils.getNnz(leftNRow, ncol, spOut));
309313
case TRANS:
314+
return new MatrixCharacteristics(leftNCol, leftNRow,
315+
OptimizerUtils.getNnz(leftNCol, leftNRow, spOut));
310316
case RESHAPE:
311-
throw new NotImplementedException("Characteristics derivation for trans and reshape has not been " +
317+
throw new NotImplementedException("Characteristics derivation for " + node.getOp() +" has not been " +
312318
"implemented yet, but could be implemented similar to EstimatorMatrixHistogram.java");
313319
default:
314320
throw new NotImplementedException();

0 commit comments

Comments
 (0)