Skip to content

Commit 73b622d

Browse files
committed
Tune compressed matmul fast paths and Spark execution decisions
Mixes two related performance changes: refined compressed multiply heuristics, and a Spark-vs-CP decision refresh on the Hop layer.

CLALib matmul changes:
- CLALibMMChain: for XtXv with few col groups and a wide-enough matrix, compute X' * X via leftMultByTransposeSelf and finish with a regular matrix multiply against v. This is cheaper than chaining when the X' * X path can stay compressed.
- CLALibTSMM: refactor leftMultByTransposeSelf to return its result and add a public overload without an output argument so MMChain can call it; widen the ColGroupUncompressed handling.
- CLALibRightMultBy: stop forcing decompression for ASDC / ASDCZero inputs; they have working preAggregate paths that beat the dense fallback.
- CLALibCompAgg: fix blklen rounding so the last partition is not short by k rows on parallel aggregates.

Spark/CP exec-decision refresh (Hop, UnaryOp, BinaryOp):
- Hop: new helpers hasSparkOutput() and isScalarOrVectorBellowBlockSize(), shared between the unary and binary decision points.
- UnaryOp.optFindExecType: replace the inline chain of negations with isDisallowedSparkOps(), allow Frame outputs, and pull unary ops into Spark whenever the input already has a Spark output.
- BinaryOp.optFindExecType: same kind of restructuring; allow matrix-or-frame outputs to be pulled into Spark when exactly one operand is a scalar or small vector.

Instruction-side adjustments:
- VariableCPInstruction (CAST_AS_MATRIX from frame): use the parallel MatrixBlockFromFrame.convertToMatrixBlock(fin, k) path instead of the single-threaded DataConverter helper.
- ParameterizedBuiltinCPInstruction (transformdecode): call the parallel decoder.decode(data, out, k) overload using InfrastructureAnalyzer.getLocalParallelism().
1 parent 65e734e commit 73b622d

9 files changed

Lines changed: 109 additions & 44 deletions

File tree

src/main/java/org/apache/sysds/hops/BinaryOp.java

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,8 @@ protected ExecType optFindExecType(boolean transitive) {
763763

764764
checkAndSetForcedPlatform();
765765

766-
DataType dt1 = getInput().get(0).getDataType();
767-
DataType dt2 = getInput().get(1).getDataType();
766+
final DataType dt1 = getInput(0).getDataType();
767+
final DataType dt2 = getInput(1).getDataType();
768768

769769
if( _etypeForced != null ) {
770770
setExecType(_etypeForced);
@@ -812,18 +812,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) {
812812
checkAndSetInvalidCPDimsAndSize();
813813
}
814814

815-
//spark-specific decision refinement (execute unary scalar w/ spark input and
815+
// spark-specific decision refinement (execute unary scalar w/ spark input and
816816
// single parent also in spark because it's likely cheap and reduces intermediates)
817-
if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED &&
818-
getDataType().isMatrix() // output should be a matrix
819-
&& (dt1.isScalar() || dt2.isScalar()) // one side should be scalar
820-
&& supportsMatrixScalarOperations() // scalar operations
821-
&& !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint
822-
&& getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent
823-
&& !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec
824-
&& getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) {
825-
// pull unary scalar operation into spark
826-
_etype = ExecType.SPARK;
817+
if(transitive // we allow transitive Spark operations. continue sequences of spark operations
818+
&& _etype == ExecType.CP // The instruction is currently in CP
819+
&& _etypeForced != ExecType.CP // not forced CP
820+
&& _etypeForced != ExecType.FED // not federated
821+
&& (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame
822+
) {
823+
final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize();
824+
final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize();
825+
final boolean left = v1 == true; // left side is the vector or scalar
826+
final Hop sparkIn = getInput(left ? 1 : 0);
827+
if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar.
828+
&& (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation
829+
&& sparkIn.getParent().size() == 1 // only one parent
830+
&& !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec
831+
&& sparkIn.optFindExecType() == ExecType.SPARK // input was spark op.
832+
&& !(sparkIn instanceof DataOp) // input is not checkpoint
833+
) {
834+
// pull operation into spark
835+
_etype = ExecType.SPARK;
836+
}
827837
}
828838

829839
if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE &&
@@ -853,7 +863,7 @@ else if( (op == OpOp2.CBIND && getDataType().isList())
853863
|| (op == OpOp2.RBIND && getDataType().isList())) {
854864
_etype = ExecType.CP;
855865
}
856-
866+
857867
//mark for recompile (forever)
858868
setRequiresRecompileIfNecessary();
859869

src/main/java/org/apache/sysds/hops/Hop.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,12 @@ public final String toString() {
10451045
// ========================================================================================
10461046

10471047

1048+
protected boolean isScalarOrVectorBellowBlockSize(){
1049+
return getDataType().isScalar() || (dimsKnown() &&
1050+
(( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize())
1051+
|| _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize()));
1052+
}
1053+
10481054
protected boolean isVector() {
10491055
return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) );
10501056
}
@@ -1629,6 +1635,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) {
16291635
lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this));
16301636
}
16311637

1638+
protected boolean hasSparkOutput(){
1639+
return (this.optFindExecType() == ExecType.SPARK
1640+
|| (this instanceof DataOp && ((DataOp)this).hasOnlyRDD()));
1641+
}
1642+
16321643
/**
16331644
* Set parse information.
16341645
*

src/main/java/org/apache/sysds/hops/UnaryOp.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
366366
} else {
367367
sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
368368
}
369-
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType());
369+
370+
if(getDataType() == DataType.FRAME)
371+
return OptimizerUtils.estimateSizeExactFrame(dim1, dim2);
372+
else
373+
return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
370374
}
371375

372376
@Override
@@ -463,6 +467,13 @@ public boolean isMetadataOperation() {
463467
|| _op == OpOp1.CAST_AS_LIST;
464468
}
465469

470+
private boolean isDisallowedSparkOps(){
471+
return isCumulativeUnaryOperation()
472+
|| isCastUnaryOperation()
473+
|| _op==OpOp1.MEDIAN
474+
|| _op==OpOp1.IQM;
475+
}
476+
466477
@Override
467478
protected ExecType optFindExecType(boolean transitive)
468479
{
@@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto
493504
checkAndSetInvalidCPDimsAndSize();
494505
}
495506

507+
496508
//spark-specific decision refinement (execute unary w/ spark input and
497509
//single parent also in spark because it's likely cheap and reduces intermediates)
498-
if( _etype == ExecType.CP && _etypeForced != ExecType.CP
499-
&& getInput().get(0).optFindExecType() == ExecType.SPARK
500-
&& getDataType().isMatrix()
501-
&& !isCumulativeUnaryOperation() && !isCastUnaryOperation()
502-
&& _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM
503-
&& !(getInput().get(0) instanceof DataOp) //input is not checkpoint
504-
&& getInput().get(0).getParent().size()==1 ) //unary is only parent
505-
{
510+
if(_etype == ExecType.CP // currently CP instruction
511+
&& _etype != ExecType.SPARK /// currently not SP.
512+
&& _etypeForced != ExecType.CP // not forced as CP instruction
513+
&& getInput(0).hasSparkOutput() // input is a spark instruction
514+
&& (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame
515+
&& !isDisallowedSparkOps() // is invalid spark instruction
516+
// && !(getInput().get(0) instanceof DataOp) // input is not checkpoint
517+
// && getInput(0).getParent().size() <= 1// unary is only parent
518+
) {
506519
//pull unary operation into spark
507520
_etype = ExecType.SPARK;
508521
}
522+
509523

510524
//mark for recompile (forever)
511525
setRequiresRecompileIfNecessary();
@@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent
520534
} else {
521535
setRequiresRecompileIfNecessary();
522536
}
523-
537+
524538
return _etype;
525539
}
526540

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ private static List<Future<MatrixBlock>> generateUnaryAggregateOverlappingFuture
486486
final ArrayList<UAOverlappingTask> tasks = new ArrayList<>();
487487
final int nCol = m1.getNumColumns();
488488
final int nRow = m1.getNumRows();
489-
final int blklen = Math.max(64, nRow / k);
489+
final int blklen = Math.max(64, (nRow + k) / k);
490490
final List<AColGroup> groups = m1.getColGroups();
491491
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
492492
if(shouldFilter) {

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
3131
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
3232
import org.apache.sysds.runtime.functionobjects.Multiply;
33+
import org.apache.sysds.runtime.instructions.InstructionUtils;
3334
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
3435
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
3536
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -95,6 +96,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix
9596
if(x.isEmpty())
9697
return returnEmpty(x, out);
9798

99+
if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){
100+
MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k);
101+
return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k));
102+
}
103+
98104
// Morph the columns to efficient types for the operation.
99105
x = filterColGroups(x);
100106
double preFilterTime = t.stop();

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import org.apache.sysds.conf.DMLConfig;
3232
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
3333
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
34+
import org.apache.sysds.runtime.compress.colgroup.ASDC;
35+
import org.apache.sysds.runtime.compress.colgroup.ASDCZero;
3436
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
3537
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
3638
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
@@ -71,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
7173
if(m2 instanceof CompressedMatrixBlock)
7274
m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);
7375

74-
if(betterIfDecompressed(m1)) {
75-
// perform uncompressed multiplication.
76-
return decompressingMatrixMult(m1, m2, k);
77-
}
76+
// if(betterIfDecompressed(m1)) {
77+
// // perform uncompressed multiplication.
78+
// return decompressingMatrixMult(m1, m2, k);
79+
// }
7880

7981
if(!allowOverlap) {
8082
LOG.trace("Overlapping output not allowed in call to Right MM");
@@ -143,7 +145,9 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat
143145

144146
private static boolean betterIfDecompressed(CompressedMatrixBlock m) {
145147
for(AColGroup g : m.getColGroups()) {
146-
if(!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) {
148+
// TODO add support for decompressing RMM to ASDC and ASDCZero
149+
if(!(g instanceof ColGroupUncompressed || g instanceof ASDC || g instanceof ASDCZero) &&
150+
g.getNumValues() * 2 >= m.getNumRows()) {
147151
return true;
148152
}
149153
}

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.apache.sysds.runtime.DMLRuntimeException;
3232
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
3333
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
34+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
3435
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
3536
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3637
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -42,6 +43,10 @@ private CLALibTSMM() {
4243
// private constructor
4344
}
4445

46+
public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) {
47+
return leftMultByTransposeSelf(cmb, new MatrixBlock(), k);
48+
}
49+
4550
/**
4651
* Self left Matrix multiplication (tsmm)
4752
*
@@ -51,24 +56,32 @@ private CLALibTSMM() {
5156
* @param ret The output matrix to put the result into
5257
* @param k The parallelization degree allowed
5358
*/
54-
public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
59+
public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
5560

61+
final int numColumns = cmb.getNumColumns();
62+
final int numRows = cmb.getNumRows();
63+
if(cmb.isEmpty())
64+
return new MatrixBlock(numColumns, numColumns, true);
65+
// create output matrix block
66+
if(ret == null)
67+
ret = new MatrixBlock(numColumns, numColumns, false);
68+
else
69+
ret.reset(numColumns, numColumns, false);
70+
ret.allocateDenseBlock();
5671
final List<AColGroup> groups = cmb.getColGroups();
5772

58-
final int numColumns = cmb.getNumColumns();
59-
if(groups.size() >= numColumns) {
73+
if(groups.size() >= numColumns || containsUncompressedColGroup(groups)) {
6074
MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k);
6175
LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k);
62-
return;
76+
return ret;
6377
}
64-
final int numRows = cmb.getNumRows();
6578
final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
6679
final boolean overlapping = cmb.isOverlapping();
6780
if(shouldFilter) {
6881
final double[] constV = new double[numColumns];
6982
final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
7083
tsmmColGroups(filteredGroups, ret, numRows, overlapping, k);
71-
addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV);
84+
addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV, k);
7285
}
7386
else {
7487

@@ -77,17 +90,23 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc
7790

7891
ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret));
7992
ret.examSparsity();
93+
return ret;
94+
}
95+
96+
private static boolean containsUncompressedColGroup(List<AColGroup> groups) {
97+
for(AColGroup g : groups)
98+
if(g instanceof ColGroupUncompressed)
99+
return true;
100+
return false;
80101
}
81102

82103
private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlock result, int nRows, int nCols,
83-
double[] constV) {
104+
double[] constV, int k) {
84105
final double[] retV = result.getDenseBlockValues();
85106
final double[] filteredColSum = CLALibUtils.getColSum(filteredGroups, nCols, nRows);
86107
addCorrectionLayer(constV, filteredColSum, nRows, retV);
87108
}
88109

89-
90-
91110
private static void tsmmColGroups(List<AColGroup> groups, MatrixBlock ret, int nRows, boolean overlapping, int k) {
92111
if(k <= 1)
93112
tsmmColGroupsSingleThread(groups, ret, nRows);
@@ -136,12 +155,12 @@ private static void tsmmColGroupsMultiThread(List<AColGroup> groups, MatrixBlock
136155

137156
public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) {
138157
final int nColRow = constV.length;
139-
for(int row = 0; row < nColRow; row++){
158+
for(int row = 0; row < nColRow; row++) {
140159
int offOut = nColRow * row;
141160
final double v1l = constV[row];
142161
final double v2l = filteredColSum[row] + constV[row] * nRow;
143-
for(int col = row; col < nColRow; col++){
144-
ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
162+
for(int col = row; col < nColRow; col++) {
163+
ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
145164
}
146165
}
147166
}

src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) {
352352
// compute transformdecode
353353
Decoder decoder = DecoderFactory
354354
.createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns());
355-
FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()));
355+
FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()), InfrastructureAnalyzer.getLocalParallelism());
356356
fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns()));
357357

358358
// release locks

src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
4545
import org.apache.sysds.runtime.data.TensorBlock;
4646
import org.apache.sysds.runtime.frame.data.FrameBlock;
47+
import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame;
4748
import org.apache.sysds.runtime.instructions.Instruction;
4849
import org.apache.sysds.runtime.instructions.InstructionUtils;
4950
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
@@ -923,7 +924,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) {
923924
switch( getInput1().getDataType() ) {
924925
case FRAME: {
925926
FrameBlock fin = ec.getFrameInput(getInput1().getName());
926-
MatrixBlock out = DataConverter.convertToMatrixBlock(fin);
927+
MatrixBlock out = MatrixBlockFromFrame.convertToMatrixBlock(fin, k);
927928
ec.releaseFrameInput(getInput1().getName());
928929
ec.setMatrixOutput(output.getName(), out);
929930
break;

0 commit comments

Comments
 (0)