@@ -43,7 +43,7 @@ public class EstimatorRowWise extends SparsityEstimator {
4343 @ Override
4444 public DataCharacteristics estim (MMNode root ) {
4545 estimInternChain (root );
46- double sparsity = (( RSVector )root .getSynopsis ()).avg ( );
46+ double sparsity = DoubleStream . of (( double [] )root .getSynopsis ()).average (). orElse ( 0 );
4747
4848 DataCharacteristics outputCharacteristics = deriveOutputCharacteristics (root , sparsity );
4949 return root .setDataCharacteristics (outputCharacteristics );
@@ -61,25 +61,25 @@ public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
6161 m2 .getDataCharacteristics (), op ).getSparsity ();
6262 }
6363
64- RSVector rsOut = estimIntern (m1 , m2 , op );
65- return rsOut . avg ( );
64+ double [] rsOut = estimIntern (m1 , m2 , op );
65+ return DoubleStream . of ( rsOut ). average (). orElse ( 0 );
6666 }
6767
6868 @ Override
6969 public double estim (MatrixBlock m1 , OpCode op ) {
7070 if ( isExactMetadataOp (op , m1 .getNumColumns ()) )
7171 return estimExactMetaData (m1 .getDataCharacteristics (), null , op ).getSparsity ();
7272
73- RSVector rsOut = estimIntern (m1 , op );
74- return rsOut . avg ( );
73+ double [] rsOut = estimIntern (m1 , op );
74+ return DoubleStream . of ( rsOut ). average (). orElse ( 0 );
7575 }
7676
7777 private void estimInternChain (MMNode node ) {
7878 estimInternChain (node , null , null );
7979 }
8080
81- private void estimInternChain (MMNode node , RSVector rsRightNeighbor , OpCode opRightNeighbor ) {
82- RSVector rsOut ;
81+ private void estimInternChain (MMNode node , double [] rsRightNeighbor , OpCode opRightNeighbor ) {
82+ double [] rsOut ;
8383 if (node .isLeaf ()) {
8484 MatrixBlock mb = node .getData ();
8585 if (rsRightNeighbor != null )
@@ -91,89 +91,89 @@ private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRi
9191 switch (node .getOp ()) {
9292 case MM :
9393 estimInternChain (node .getRight (), rsRightNeighbor , opRightNeighbor );
94- estimInternChain (node .getLeft (), (RSVector )(node .getRight ().getSynopsis ()), node .getOp ());
95- rsOut = (RSVector )node .getLeft ().getSynopsis ();
94+ estimInternChain (node .getLeft (), (double [] )(node .getRight ().getSynopsis ()), node .getOp ());
95+ rsOut = (double [] )node .getLeft ().getSynopsis ();
9696 break ;
9797 case CBIND :
9898 /** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
9999 * the right neighbor cannot be aggregated into a cbind operation when having only row sparsity vectors
100100 */
101101 estimInternChain (node .getLeft ());
102102 estimInternChain (node .getRight ());
103- RSVector rsCBind = estimInternCBind ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
103+ double [] rsCBind = estimInternCBind ((double [] )(node .getLeft ().getSynopsis ()), (double [] )(node .getRight ().getSynopsis ()));
104104 if (rsRightNeighbor != null ) {
105- rsOut = (RSVector )estimInternMMFallback (rsCBind , rsRightNeighbor );
105+ rsOut = (double [] )estimInternMMFallback (rsCBind , rsRightNeighbor );
106106 if (opRightNeighbor != OpCode .MM )
107107 throw new NotImplementedException ("Fallback sparsity estimation has only been " +
108108 "considered for MM operation w/ right neighbor yet." );
109109 }
110110 else
111- rsOut = (RSVector )rsCBind ;
111+ rsOut = (double [] )rsCBind ;
112112 break ;
113113 case RBIND :
114114 /** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
115115 * the right neighbor cannot be aggregated into an rbind operation when having only row sparsity vectors
116116 */
117117 estimInternChain (node .getLeft ());
118118 estimInternChain (node .getRight ());
119- RSVector rsRBind = estimInternRBind ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
119+ double [] rsRBind = estimInternRBind ((double [] )(node .getLeft ().getSynopsis ()), (double [] )(node .getRight ().getSynopsis ()));
120120 if (rsRightNeighbor != null ) {
121- rsOut = (RSVector )estimInternMMFallback (rsRBind , rsRightNeighbor );
121+ rsOut = (double [] )estimInternMMFallback (rsRBind , rsRightNeighbor );
122122 if (opRightNeighbor != OpCode .MM )
123123 throw new NotImplementedException ("Fallback sparsity estimation has only been " +
124124 "considered for MM operation w/ right neighbor yet." );
125125 }
126126 else
127- rsOut = (RSVector )rsRBind ;
127+ rsOut = (double [] )rsRBind ;
128128 break ;
129129 case PLUS :
130130 /** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
131131 * the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
132132 */
133133 estimInternChain (node .getLeft ());
134134 estimInternChain (node .getRight ());
135- RSVector rsPlus = estimInternPlus ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
135+ double [] rsPlus = estimInternPlus ((double [] )(node .getLeft ().getSynopsis ()), (double [] )(node .getRight ().getSynopsis ()));
136136 if (rsRightNeighbor != null ) {
137- rsOut = (RSVector )estimInternMMFallback (rsPlus , rsRightNeighbor );
137+ rsOut = (double [] )estimInternMMFallback (rsPlus , rsRightNeighbor );
138138 if (opRightNeighbor != OpCode .MM )
139139 throw new NotImplementedException ("Fallback sparsity estimation has only been " +
140140 "considered for MM operation w/ right neighbor yet." );
141141 }
142142 else
143- rsOut = (RSVector )rsPlus ;
143+ rsOut = (double [] )rsPlus ;
144144 break ;
145145 case MULT :
146146 /** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
147147 * the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
148148 */
149149 estimInternChain (node .getLeft ());
150150 estimInternChain (node .getRight ());
151- RSVector rsMult = estimInternMult ((RSVector )(node .getLeft ().getSynopsis ()), (RSVector )(node .getRight ().getSynopsis ()));
151+ double [] rsMult = estimInternMult ((double [] )(node .getLeft ().getSynopsis ()), (double [] )(node .getRight ().getSynopsis ()));
152152 if (rsRightNeighbor != null ) {
153- rsOut = (RSVector )estimInternMMFallback (rsMult , rsRightNeighbor );
153+ rsOut = (double [] )estimInternMMFallback (rsMult , rsRightNeighbor );
154154 if (opRightNeighbor != OpCode .MM )
155155 throw new NotImplementedException ("Fallback sparsity estimation has only been " +
156156 "considered for MM operation w/ right neighbor yet." );
157157 }
158158 else
159- rsOut = (RSVector )rsMult ;
159+ rsOut = (double [] )rsMult ;
160160 break ;
161161 default :
162162 throw new NotImplementedException ("Chain estimation for operator " + node .getOp ().toString () +
163163 " is not supported yet." );
164164 }
165165 }
166166 node .setSynopsis (rsOut );
167- node .setDataCharacteristics (deriveOutputCharacteristics (node , rsOut . avg ( )));
167+ node .setDataCharacteristics (deriveOutputCharacteristics (node , DoubleStream . of ( rsOut ). average (). orElse ( 0 )));
168168 return ;
169169 }
170170
171- private RSVector estimIntern (MatrixBlock m1 , MatrixBlock m2 , OpCode op ) {
172- RSVector rsM2 = getRowWiseSparsityVector (m2 );
171+ private double [] estimIntern (MatrixBlock m1 , MatrixBlock m2 , OpCode op ) {
172+ double [] rsM2 = getRowWiseSparsityVector (m2 );
173173 return estimIntern (m1 , rsM2 , op );
174174 }
175175
176- private RSVector estimIntern (MatrixBlock m1 , RSVector rsM2 , OpCode op ) {
176+ private double [] estimIntern (MatrixBlock m1 , double [] rsM2 , OpCode op ) {
177177 switch (op ) {
178178 case MM :
179179 return estimInternMM (m1 , rsM2 );
@@ -190,7 +190,7 @@ private RSVector estimIntern(MatrixBlock m1, RSVector rsM2, OpCode op) {
190190 }
191191 }
192192
193- private RSVector estimIntern (MatrixBlock mb , OpCode op ) {
193+ private double [] estimIntern (MatrixBlock mb , OpCode op ) {
194194 switch (op ) {
195195 case DIAG :
196196 return estimInternDiag (mb );
@@ -200,68 +200,72 @@ private RSVector estimIntern(MatrixBlock mb, OpCode op) {
200200 }
201201
202202 // Corresponds to Algorithm 1 in the publication
203- private RSVector estimInternMM (MatrixBlock m1 , RSVector rsM2 ) {
204- RSVector rsOut = new RSVector ( IntStream .range (0 , m1 .getNumRows ()).mapToDouble (
203+ private double [] estimInternMM (MatrixBlock m1 , double [] rsM2 ) {
204+ double [] rsOut = IntStream .range (0 , m1 .getNumRows ()).mapToDouble (
205205 r -> (double ) 1 - IntStream .of (getNonZeroColumnIndices (m1 , r )).mapToDouble (
206- c -> (double ) 1 - rsM2 . get ( c )
206+ c -> (double ) 1 - rsM2 [ c ]
207207 ).reduce ((double ) 1 , (currentVal , val ) -> currentVal * val ))
208- .toArray ()) ;
208+ .toArray ();
209209 return rsOut ;
210210 }
211211
212212 // NOTE: this is the best estimation possible when we only have the two row sparsity vectors
213- private RSVector estimInternMMFallback (RSVector rsM1 , RSVector rsM2 ) {
213+ private double [] estimInternMMFallback (double [] rsM1 , double [] rsM2 ) {
214214 // NOTE: Considering the average would probably not be far off while saving computing time
215215 // double avgRsM2 = DoubleStream.of(rsM2).average().orElse(0);
216- // RSVector rsOut = DoubleStream.of(rsM1).map(
216+ // double[] rsOut = DoubleStream.of(rsM1).map(
217217 // rsM1I -> (double) 1 - Math.pow((double) 1 - (rsM1I * avgRsM2), rsM2.length)).toArray();
218- RSVector rsOut = rsM1 .map (
219- rsM1I -> (double ) 1 - rsM2 .reduce ((double ) 1 ,
220- (currentVal , rsM2J ) -> currentVal * ((double ) 1 - (rsM1I * rsM2J ))));
218+ double [] rsOut = DoubleStream . of ( rsM1 ) .map (
219+ rsM1I -> (double ) 1 - DoubleStream . of ( rsM2 ) .reduce ((double ) 1 ,
220+ (currentVal , rsM2J ) -> currentVal * ((double ) 1 - (rsM1I * rsM2J )))). toArray () ;
221221 return rsOut ;
222222 }
223223
224- private RSVector estimInternCBind (RSVector rsM1 , RSVector rsM2 ) {
225- return new RSVector (IntStream .range (0 , rsM1 .size ()).mapToDouble (
226- idx -> (rsM1 .get (idx ) + rsM2 .get (idx )) / (double ) 2 ).toArray ());
224+ private double [] estimInternCBind (double [] rsM1 , double [] rsM2 ) {
225+ // FIXME: this assumes that the number of columns is equivalent for both inputs
226+ return IntStream .range (0 , rsM1 .length ).mapToDouble (
227+ idx -> (rsM1 [idx ] + rsM2 [idx ]) / (double ) 2 ).toArray ();
227228 }
228229
229- private RSVector estimInternRBind (RSVector rsM1 , RSVector rsM2 ) {
230- return rsM1 . append ( rsM2 );
230+ private double [] estimInternRBind (double [] rsM1 , double [] rsM2 ) {
231+ return ArrayUtils . addAll ( rsM1 , rsM2 );
231232 }
232233
233- private RSVector estimInternPlus (RSVector rsM1 , RSVector rsM2 ) {
234+ private double [] estimInternPlus (double [] rsM1 , double [] rsM2 ) {
234235 // row-wise average case estimates
235236 // rsM1 + rsM2 - (rsM1 * rsM2)
236- return rsM1 .add (rsM2 ).subtract (rsM1 .multiply (rsM2 ));
237+ return IntStream .range (0 , rsM1 .length ).mapToDouble (
238+ idx -> rsM1 [idx ] + rsM2 [idx ] - (rsM1 [idx ] * rsM2 [idx ])).toArray ();
237239 }
238240
239- private RSVector estimInternMult (RSVector rsM1 , RSVector rsM2 ) {
241+ private double [] estimInternMult (double [] rsM1 , double [] rsM2 ) {
240242 // row-wise average case estimates
241243 // rsM1 * rsM2
242- return rsM1 .multiply (rsM2 );
244+ return IntStream .range (0 , rsM1 .length ).mapToDouble (
245+ idx -> rsM1 [idx ] * rsM2 [idx ]).toArray ();
243246 }
244247
245- private RSVector estimInternDiag (MatrixBlock mb ) {
246- RSVector rsOut = new RSVector ( IntStream .range (0 , mb .getNumRows ()).mapToDouble (
248+ private double [] estimInternDiag (MatrixBlock mb ) {
249+ double [] rsOut = IntStream .range (0 , mb .getNumRows ()).mapToDouble (
247250 rIdx -> (mb .get (rIdx , rIdx ) == 0 ) ? 0d : 1d )
248- .toArray ()) ;
251+ .toArray ();
249252 return rsOut ;
250253 }
251254
252- private RSVector getRowWiseSparsityVector (MatrixBlock mb ) {
255+ private double [] getRowWiseSparsityVector (MatrixBlock mb ) {
253256 int numRows = mb .getNumRows ();
254257 if (mb .isInSparseFormat ()) {
255258 double [] rsArray = new double [numRows ];
256259 for (int counter = 0 ; counter < numRows ; counter ++) {
257260 SparseRow sparseRow = mb .getSparseBlock ().get (counter );
258261 rsArray [counter ] = (sparseRow == null ) ? 0 : (double ) sparseRow .size () / mb .getNumColumns ();
259262 }
260- return new RSVector ( rsArray ) ;
263+ return rsArray ;
261264 }
262265 else {
263- return new RSVector (IntStream .range (0 , numRows ).mapToDouble (
264- rIdx -> (double ) mb .getDenseBlock ().countNonZeros (rIdx ) / mb .getNumColumns ()).toArray ());
266+ return IntStream .range (0 , numRows ).mapToDouble (
267+ rIdx -> (double ) mb .getDenseBlock ().countNonZeros (rIdx ) / mb .getNumColumns ())
268+ .toArray ();
265269 }
266270 }
267271
@@ -320,55 +324,4 @@ public static DataCharacteristics deriveOutputCharacteristics(MMNode node, doubl
320324 throw new NotImplementedException ();
321325 }
322326 }
323-
324- public static class RSVector {
325- private final double [] rs ;
326-
327- public RSVector (double [] rs ) {
328- this .rs = rs ;
329- }
330-
331- public double [] get () {
332- return this .rs ;
333- }
334-
335- public double get (int idx ) {
336- return this .rs [idx ];
337- }
338-
339- public int size () {
340- return this .rs .length ;
341- }
342-
343- public double avg () {
344- return DoubleStream .of (this .rs ).average ().orElse (0 );
345- }
346-
347- public RSVector append (RSVector that ) {
348- return new RSVector (ArrayUtils .addAll (this .rs , that .get ()));
349- }
350-
351- public RSVector map (DoubleUnaryOperator mapper ) {
352- return new RSVector (DoubleStream .of (this .rs ).map (mapper ).toArray ());
353- }
354-
355- public double reduce (double identity , DoubleBinaryOperator op ) {
356- return DoubleStream .of (this .rs ).reduce (identity , op );
357- }
358-
359- public RSVector add (RSVector that ) {
360- return new RSVector (IntStream .range (0 , this .size ()).mapToDouble (
361- idx -> this .get (idx ) + that .get (idx )).toArray ());
362- }
363-
364- public RSVector subtract (RSVector that ) {
365- return new RSVector (IntStream .range (0 , this .size ()).mapToDouble (
366- idx -> this .get (idx ) - that .get (idx )).toArray ());
367- }
368-
369- public RSVector multiply (RSVector that ) {
370- return new RSVector (IntStream .range (0 , this .size ()).mapToDouble (
371- idx -> this .get (idx ) * that .get (idx )).toArray ());
372- }
373- };
374327};
0 commit comments