@@ -49,7 +49,7 @@ public DataCharacteristics estim(MMNode root) {
4949 return root .setDataCharacteristics (outputCharacteristics );
5050 }
5151
52- @ Override
52+ @ Override
5353 public double estim (MatrixBlock m1 , MatrixBlock m2 ) {
5454 return estim (m1 , m2 , OpCode .MM );
5555 }
@@ -99,8 +99,12 @@ private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRi
9999 estimInternChain (node .getLeft ());
100100 estimInternChain (node .getRight ());
101101 RSVector rsCBind = estimInternCBind ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
102- if (rsRightNeighbor != null )
103- rsOut = (RSVector )estimIntern (rsCBind , rsRightNeighbor , opRightNeighbor );
102+ if (rsRightNeighbor != null ) {
103+ rsOut = (RSVector )estimInternMMFallback (rsCBind , rsRightNeighbor );
104+ if (opRightNeighbor != OpCode .MM )
105+ throw new NotImplementedException ("Fallback sparsity estimation has only been " +
106+ "considered for MM operation w/ right neighbor, yet" );
107+ }
104108 else
105109 rsOut = (RSVector )rsCBind ;
106110 break ;
@@ -111,11 +115,47 @@ private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRi
111115 estimInternChain (node .getLeft ());
112116 estimInternChain (node .getRight ());
113117 RSVector rsRBind = estimInternRBind ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
114- if (rsRightNeighbor != null )
115- rsOut = (RSVector )estimIntern (rsRBind , rsRightNeighbor , opRightNeighbor );
118+ if (rsRightNeighbor != null ) {
119+ rsOut = (RSVector )estimInternMMFallback (rsRBind , rsRightNeighbor );
120+ if (opRightNeighbor != OpCode .MM )
121+ throw new NotImplementedException ("Fallback sparsity estimation has only been " +
122+ "considered for MM operation w/ right neighbor, yet" );
123+ }
116124 else
117125 rsOut = (RSVector )rsRBind ;
118126 break ;
127+ case PLUS :
128+ /** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
129+ * the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
130+ */
131+ estimInternChain (node .getLeft ());
132+ estimInternChain (node .getRight ());
133+ RSVector rsPlus = estimInternPlus ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
134+ if (rsRightNeighbor != null ) {
135+ rsOut = (RSVector )estimInternMMFallback (rsPlus , rsRightNeighbor );
136+ if (opRightNeighbor != OpCode .MM )
137+ throw new NotImplementedException ("Fallback sparsity estimation has only been " +
138+ "considered for MM operation w/ right neighbor, yet" );
139+ }
140+ else
141+ rsOut = (RSVector )rsPlus ;
142+ break ;
143+ case MULT :
144+ /** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
145+ * the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
146+ */
147+ estimInternChain (node .getLeft ());
148+ estimInternChain (node .getRight ());
149+ RSVector rsMult = estimInternMult ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
150+ if (rsRightNeighbor != null ) {
151+ rsOut = (RSVector )estimInternMMFallback (rsMult , rsRightNeighbor );
152+ if (opRightNeighbor != OpCode .MM )
153+ throw new NotImplementedException ("Fallback sparsity estimation has only been " +
154+ "considered for MM operation w/ right neighbor, yet" );
155+ }
156+ else
157+ rsOut = (RSVector )rsMult ;
158+ break ;
119159 default :
120160 throw new NotImplementedException ("Chain estimation for operator " + node .getOp ().toString () +
121161 " is not supported yet." );
@@ -139,19 +179,10 @@ private RSVector estimIntern(MatrixBlock m1, RSVector rsM2, OpCode op) {
139179 return estimInternCBind (getRowWiseSparsityVector (m1 ), rsM2 );
140180 case RBIND :
141181 return estimInternRBind (getRowWiseSparsityVector (m1 ), rsM2 );
142- default :
143- throw new NotImplementedException ("Sparsity estimation for operation " + op .toString () + " not supported yet." );
144- }
145- }
146-
147- private RSVector estimIntern (RSVector rsM1 , RSVector rsM2 , OpCode op ) {
148- switch (op ) {
149- case MM :
150- return estimInternMM (rsM1 , rsM2 );
151- // case CBIND:
152- // return estimInternCBind(rsM1, rsM2);
153- // case RBIND:
154- // return estimInternRBind(rsM1, rsM2);
182+ case PLUS :
183+ return estimInternPlus (getRowWiseSparsityVector (m1 ), rsM2 );
184+ case MULT :
185+ return estimInternMult (getRowWiseSparsityVector (m1 ), rsM2 );
155186 default :
156187 throw new NotImplementedException ("Sparsity estimation for operation " + op .toString () + " not supported yet." );
157188 }
@@ -168,7 +199,8 @@ private RSVector estimInternMM(MatrixBlock m1, RSVector rsM2) {
168199 }
169200
170201 // NOTE: this is the best estimation possible when we only have the two row sparsity vectors
171- private RSVector estimInternMM (RSVector rsM1 , RSVector rsM2 ) {
202+ private RSVector estimInternMMFallback (RSVector rsM1 , RSVector rsM2 ) {
203+ // NOTE: Considering the average would probably not be far off while saving computing time
172204 // double avgRsM2 = DoubleStream.of(rsM2).average().orElse(0);
173205 // RSVector rsOut = DoubleStream.of(rsM1).map(
174206 // rsM1I -> (double) 1 - Math.pow((double) 1 - (rsM1I * avgRsM2), rsM2.length)).toArray();
@@ -187,6 +219,18 @@ private RSVector estimInternRBind(RSVector rsM1, RSVector rsM2) {
187219 return rsM1 .append (rsM2 );
188220 }
189221
222+ private RSVector estimInternPlus (RSVector rsM1 , RSVector rsM2 ) {
223+ // row-wise average case estimates
224+ // rsM1 + rsM2 - (rsM1 * rsM2)
225+ return rsM1 .add (rsM2 ).subtract (rsM1 .multiply (rsM2 ));
226+ }
227+
228+ private RSVector estimInternMult (RSVector rsM1 , RSVector rsM2 ) {
229+ // row-wise average case estimates
230+ // rsM1 * rsM2
231+ return rsM1 .multiply (rsM2 );
232+ }
233+
190234 private RSVector getRowWiseSparsityVector (MatrixBlock mb ) {
191235 int numRows = mb .getNumRows ();
192236 if (mb .isInSparseFormat ()) {
@@ -287,5 +331,20 @@ public RSVector map(DoubleUnaryOperator mapper) {
287331 public double reduce (double identity , DoubleBinaryOperator op ) {
288332 return DoubleStream .of (this .rs ).reduce (identity , op );
289333 }
334+
335+ public RSVector add (RSVector that ) {
336+ return new RSVector (IntStream .range (0 , this .size ()).mapToDouble (
337+ idx -> this .get (idx ) + that .get (idx )).toArray ());
338+ }
339+
340+ public RSVector subtract (RSVector that ) {
341+ return new RSVector (IntStream .range (0 , this .size ()).mapToDouble (
342+ idx -> this .get (idx ) - that .get (idx )).toArray ());
343+ }
344+
345+ public RSVector multiply (RSVector that ) {
346+ return new RSVector (IntStream .range (0 , this .size ()).mapToDouble (
347+ idx -> this .get (idx ) * that .get (idx )).toArray ());
348+ }
290349 };
291350};
0 commit comments