Skip to content

Commit 9aa7d52

Browse files
committed
feat(main/hops/estim/EstimatorRowWise.java): add support for element-wise and single operations
NOTE: using average case estimation per row
1 parent da4c2c5 commit 9aa7d52

1 file changed

Lines changed: 78 additions & 19 deletions

File tree

src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java

Lines changed: 78 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public DataCharacteristics estim(MMNode root) {
4949
return root.setDataCharacteristics(outputCharacteristics);
5050
}
5151

52-
@Override
52+
@Override
5353
public double estim(MatrixBlock m1, MatrixBlock m2) {
5454
return estim(m1, m2, OpCode.MM);
5555
}
@@ -99,8 +99,12 @@ private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRi
9999
estimInternChain(node.getLeft());
100100
estimInternChain(node.getRight());
101101
RSVector rsCBind = estimInternCBind((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
102-
if(rsRightNeighbor != null)
103-
rsOut = (RSVector)estimIntern(rsCBind, rsRightNeighbor, opRightNeighbor);
102+
if(rsRightNeighbor != null) {
103+
rsOut = (RSVector)estimInternMMFallback(rsCBind, rsRightNeighbor);
104+
if(opRightNeighbor != OpCode.MM)
105+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
106+
"considered for MM operation w/ right neighbor, yet");
107+
}
104108
else
105109
rsOut = (RSVector)rsCBind;
106110
break;
@@ -111,11 +115,47 @@ private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRi
111115
estimInternChain(node.getLeft());
112116
estimInternChain(node.getRight());
113117
RSVector rsRBind = estimInternRBind((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
114-
if(rsRightNeighbor != null)
115-
rsOut = (RSVector)estimIntern(rsRBind, rsRightNeighbor, opRightNeighbor);
118+
if(rsRightNeighbor != null) {
119+
rsOut = (RSVector)estimInternMMFallback(rsRBind, rsRightNeighbor);
120+
if(opRightNeighbor != OpCode.MM)
121+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
122+
"considered for MM operation w/ right neighbor, yet");
123+
}
116124
else
117125
rsOut = (RSVector)rsRBind;
118126
break;
127+
case PLUS:
128+
/** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
129+
* the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
130+
*/
131+
estimInternChain(node.getLeft());
132+
estimInternChain(node.getRight());
133+
RSVector rsPlus = estimInternPlus((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
134+
if(rsRightNeighbor != null) {
135+
rsOut = (RSVector)estimInternMMFallback(rsPlus, rsRightNeighbor);
136+
if(opRightNeighbor != OpCode.MM)
137+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
138+
"considered for MM operation w/ right neighbor, yet");
139+
}
140+
else
141+
rsOut = (RSVector)rsPlus;
142+
break;
143+
case MULT:
144+
/** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
145+
* the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
146+
*/
147+
estimInternChain(node.getLeft());
148+
estimInternChain(node.getRight());
149+
RSVector rsMult = estimInternMult((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
150+
if(rsRightNeighbor != null) {
151+
rsOut = (RSVector)estimInternMMFallback(rsMult, rsRightNeighbor);
152+
if(opRightNeighbor != OpCode.MM)
153+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
154+
"considered for MM operation w/ right neighbor, yet");
155+
}
156+
else
157+
rsOut = (RSVector)rsMult;
158+
break;
119159
default:
120160
throw new NotImplementedException("Chain estimation for operator " + node.getOp().toString() +
121161
" is not supported yet.");
@@ -139,19 +179,10 @@ private RSVector estimIntern(MatrixBlock m1, RSVector rsM2, OpCode op) {
139179
return estimInternCBind(getRowWiseSparsityVector(m1), rsM2);
140180
case RBIND:
141181
return estimInternRBind(getRowWiseSparsityVector(m1), rsM2);
142-
default:
143-
throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
144-
}
145-
}
146-
147-
private RSVector estimIntern(RSVector rsM1, RSVector rsM2, OpCode op) {
148-
switch(op) {
149-
case MM:
150-
return estimInternMM(rsM1, rsM2);
151-
// case CBIND:
152-
// return estimInternCBind(rsM1, rsM2);
153-
// case RBIND:
154-
// return estimInternRBind(rsM1, rsM2);
182+
case PLUS:
183+
return estimInternPlus(getRowWiseSparsityVector(m1), rsM2);
184+
case MULT:
185+
return estimInternMult(getRowWiseSparsityVector(m1), rsM2);
155186
default:
156187
throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
157188
}
@@ -168,7 +199,8 @@ private RSVector estimInternMM(MatrixBlock m1, RSVector rsM2) {
168199
}
169200

170201
// NOTE: this is the best estimation possible when we only have the two row sparsity vectors
171-
private RSVector estimInternMM(RSVector rsM1, RSVector rsM2) {
202+
private RSVector estimInternMMFallback(RSVector rsM1, RSVector rsM2) {
203+
// NOTE: Considering the average would probably not be far off while saving computing time
172204
// double avgRsM2 = DoubleStream.of(rsM2).average().orElse(0);
173205
// RSVector rsOut = DoubleStream.of(rsM1).map(
174206
// rsM1I -> (double) 1 - Math.pow((double) 1 - (rsM1I * avgRsM2), rsM2.length)).toArray();
@@ -187,6 +219,18 @@ private RSVector estimInternRBind(RSVector rsM1, RSVector rsM2) {
187219
return rsM1.append(rsM2);
188220
}
189221

222+
private RSVector estimInternPlus(RSVector rsM1, RSVector rsM2) {
223+
// row-wise average case estimates
224+
// rsM1 + rsM2 - (rsM1 * rsM2)
225+
return rsM1.add(rsM2).subtract(rsM1.multiply(rsM2));
226+
}
227+
228+
private RSVector estimInternMult(RSVector rsM1, RSVector rsM2) {
229+
// row-wise average case estimates
230+
// rsM1 * rsM2
231+
return rsM1.multiply(rsM2);
232+
}
233+
190234
private RSVector getRowWiseSparsityVector(MatrixBlock mb) {
191235
int numRows = mb.getNumRows();
192236
if(mb.isInSparseFormat()) {
@@ -287,5 +331,20 @@ public RSVector map(DoubleUnaryOperator mapper) {
287331
public double reduce(double identity, DoubleBinaryOperator op) {
288332
return DoubleStream.of(this.rs).reduce(identity, op);
289333
}
334+
335+
public RSVector add(RSVector that) {
336+
return new RSVector(IntStream.range(0, this.size()).mapToDouble(
337+
idx -> this.get(idx) + that.get(idx)).toArray());
338+
}
339+
340+
public RSVector subtract(RSVector that) {
341+
return new RSVector(IntStream.range(0, this.size()).mapToDouble(
342+
idx -> this.get(idx) - that.get(idx)).toArray());
343+
}
344+
345+
public RSVector multiply(RSVector that) {
346+
return new RSVector(IntStream.range(0, this.size()).mapToDouble(
347+
idx -> this.get(idx) * that.get(idx)).toArray());
348+
}
290349
};
291350
};

0 commit comments

Comments
 (0)