@@ -286,29 +286,35 @@ public static DataCharacteristics deriveOutputCharacteristics(MMNode node, doubl
286286
287287 MMNode nodeLeft = node .getLeft ();
288288 MMNode nodeRight = node .getRight ();
289+ int leftNRow = nodeLeft .getRows ();
290+ int leftNCol = nodeLeft .getCols ();
291+ int rightNRow = nodeRight .getRows ();
292+ int rightNCol = nodeRight .getCols ();
289293 switch (node .getOp ()) {
290294 case MM :
291- return new MatrixCharacteristics (nodeLeft . getRows (), nodeRight . getCols () ,
292- OptimizerUtils .getNnz (nodeLeft . getRows (), nodeRight . getCols () , spOut ));
295+ return new MatrixCharacteristics (leftNRow , rightNCol ,
296+ OptimizerUtils .getNnz (leftNRow , rightNCol , spOut ));
293297 case MULT :
294298 case PLUS :
295299 case NEQZERO :
296300 case EQZERO :
297- return new MatrixCharacteristics (nodeLeft . getRows (), nodeLeft . getCols () ,
298- OptimizerUtils .getNnz (nodeLeft . getRows (), nodeLeft . getCols () , spOut ));
301+ return new MatrixCharacteristics (leftNRow , leftNCol ,
302+ OptimizerUtils .getNnz (leftNRow , leftNCol , spOut ));
299303 case RBIND :
300- return new MatrixCharacteristics (nodeLeft . getRows ()+ nodeLeft . getRows (), nodeLeft . getCols () ,
301- OptimizerUtils .getNnz (nodeLeft . getRows ()+ nodeRight . getRows (), nodeLeft . getCols () , spOut ));
304+ return new MatrixCharacteristics (leftNRow + rightNRow , leftNCol ,
305+ OptimizerUtils .getNnz (leftNRow + rightNRow , leftNCol , spOut ));
302306 case CBIND :
303- return new MatrixCharacteristics (nodeLeft . getRows (), nodeLeft . getCols ()+ nodeRight . getCols () ,
304- OptimizerUtils .getNnz (nodeLeft . getRows (), nodeLeft . getCols ()+ nodeRight . getCols () , spOut ));
307+ return new MatrixCharacteristics (leftNRow , leftNCol + rightNCol ,
308+ OptimizerUtils .getNnz (leftNRow , leftNCol + rightNCol , spOut ));
305309 case DIAG :
306- int ncol = nodeLeft . getCols ()== 1 ? nodeLeft . getRows () : 1 ;
307- return new MatrixCharacteristics (nodeLeft . getRows () , ncol ,
308- OptimizerUtils .getNnz (nodeLeft . getRows () , ncol , spOut ));
310+ int ncol = ( leftNCol == 1 ) ? leftNRow : 1 ;
311+ return new MatrixCharacteristics (leftNRow , ncol ,
312+ OptimizerUtils .getNnz (leftNRow , ncol , spOut ));
309313 case TRANS :
314+ return new MatrixCharacteristics (leftNCol , leftNRow ,
315+ OptimizerUtils .getNnz (leftNCol , leftNRow , spOut ));
310316 case RESHAPE :
311- throw new NotImplementedException ("Characteristics derivation for trans and reshape has not been " +
317+ throw new NotImplementedException ("Characteristics derivation for " + node . getOp () + " has not been " +
312318 "implemented yet, but could be implemented similar to EstimatorMatrixHistogram.java" );
313319 default :
314320 throw new NotImplementedException ();
0 commit comments