Skip to content

Commit 2361593

Browse files
committed
refactor(main/hops/estim/EstimatorRowWise.java): remove wrapper class for row-wise sparsity vector and apply the corresponding operations directly in the code instead
1 parent 252317a commit 2361593

1 file changed

Lines changed: 55 additions & 102 deletions

File tree

src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java

Lines changed: 55 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ public class EstimatorRowWise extends SparsityEstimator {
4343
@Override
4444
public DataCharacteristics estim(MMNode root) {
4545
estimInternChain(root);
46-
double sparsity = ((RSVector)root.getSynopsis()).avg();
46+
double sparsity = DoubleStream.of((double[])root.getSynopsis()).average().orElse(0);
4747

4848
DataCharacteristics outputCharacteristics = deriveOutputCharacteristics(root, sparsity);
4949
return root.setDataCharacteristics(outputCharacteristics);
@@ -61,25 +61,25 @@ public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
6161
m2.getDataCharacteristics(), op).getSparsity();
6262
}
6363

64-
RSVector rsOut = estimIntern(m1, m2, op);
65-
return rsOut.avg();
64+
double[] rsOut = estimIntern(m1, m2, op);
65+
return DoubleStream.of(rsOut).average().orElse(0);
6666
}
6767

6868
@Override
6969
public double estim(MatrixBlock m1, OpCode op) {
7070
if( isExactMetadataOp(op, m1.getNumColumns()) )
7171
return estimExactMetaData(m1.getDataCharacteristics(), null, op).getSparsity();
7272

73-
RSVector rsOut = estimIntern(m1, op);
74-
return rsOut.avg();
73+
double[] rsOut = estimIntern(m1, op);
74+
return DoubleStream.of(rsOut).average().orElse(0);
7575
}
7676

7777
private void estimInternChain(MMNode node) {
7878
estimInternChain(node, null, null);
7979
}
8080

81-
private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRightNeighbor) {
82-
RSVector rsOut;
81+
private void estimInternChain(MMNode node, double[] rsRightNeighbor, OpCode opRightNeighbor) {
82+
double[] rsOut;
8383
if(node.isLeaf()) {
8484
MatrixBlock mb = node.getData();
8585
if(rsRightNeighbor != null)
@@ -91,89 +91,89 @@ private void estimInternChain(MMNode node, RSVector rsRightNeighbor, OpCode opRi
9191
switch(node.getOp()) {
9292
case MM:
9393
estimInternChain(node.getRight(), rsRightNeighbor, opRightNeighbor);
94-
estimInternChain(node.getLeft(), (RSVector)(node.getRight().getSynopsis()), node.getOp());
95-
rsOut = (RSVector)node.getLeft().getSynopsis();
94+
estimInternChain(node.getLeft(), (double[])(node.getRight().getSynopsis()), node.getOp());
95+
rsOut = (double[])node.getLeft().getSynopsis();
9696
break;
9797
case CBIND:
9898
/** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
9999
* the right neighbor cannot be aggregated into a cbind operation when having only row sparsity vectors
100100
*/
101101
estimInternChain(node.getLeft());
102102
estimInternChain(node.getRight());
103-
RSVector rsCBind = estimInternCBind((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
103+
double[] rsCBind = estimInternCBind((double[])(node.getLeft().getSynopsis()), (double[])(node.getRight().getSynopsis()));
104104
if(rsRightNeighbor != null) {
105-
rsOut = (RSVector)estimInternMMFallback(rsCBind, rsRightNeighbor);
105+
rsOut = (double[])estimInternMMFallback(rsCBind, rsRightNeighbor);
106106
if(opRightNeighbor != OpCode.MM)
107107
throw new NotImplementedException("Fallback sparsity estimation has only been " +
108108
"considered for MM operation w/ right neighbor yet.");
109109
}
110110
else
111-
rsOut = (RSVector)rsCBind;
111+
rsOut = (double[])rsCBind;
112112
break;
113113
case RBIND:
114114
/** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
115115
* the right neighbor cannot be aggregated into an rbind operation when having only row sparsity vectors
116116
*/
117117
estimInternChain(node.getLeft());
118118
estimInternChain(node.getRight());
119-
RSVector rsRBind = estimInternRBind((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
119+
double[] rsRBind = estimInternRBind((double[])(node.getLeft().getSynopsis()), (double[])(node.getRight().getSynopsis()));
120120
if(rsRightNeighbor != null) {
121-
rsOut = (RSVector)estimInternMMFallback(rsRBind, rsRightNeighbor);
121+
rsOut = (double[])estimInternMMFallback(rsRBind, rsRightNeighbor);
122122
if(opRightNeighbor != OpCode.MM)
123123
throw new NotImplementedException("Fallback sparsity estimation has only been " +
124124
"considered for MM operation w/ right neighbor yet.");
125125
}
126126
else
127-
rsOut = (RSVector)rsRBind;
127+
rsOut = (double[])rsRBind;
128128
break;
129129
case PLUS:
130130
/** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
131131
* the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
132132
*/
133133
estimInternChain(node.getLeft());
134134
estimInternChain(node.getRight());
135-
RSVector rsPlus = estimInternPlus((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
135+
double[] rsPlus = estimInternPlus((double[])(node.getLeft().getSynopsis()), (double[])(node.getRight().getSynopsis()));
136136
if(rsRightNeighbor != null) {
137-
rsOut = (RSVector)estimInternMMFallback(rsPlus, rsRightNeighbor);
137+
rsOut = (double[])estimInternMMFallback(rsPlus, rsRightNeighbor);
138138
if(opRightNeighbor != OpCode.MM)
139139
throw new NotImplementedException("Fallback sparsity estimation has only been " +
140140
"considered for MM operation w/ right neighbor yet.");
141141
}
142142
else
143-
rsOut = (RSVector)rsPlus;
143+
rsOut = (double[])rsPlus;
144144
break;
145145
case MULT:
146146
/** NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
147147
* the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
148148
*/
149149
estimInternChain(node.getLeft());
150150
estimInternChain(node.getRight());
151-
RSVector rsMult = estimInternMult((RSVector)(node.getLeft().getSynopsis()), (RSVector)(node.getRight().getSynopsis()));
151+
double[] rsMult = estimInternMult((double[])(node.getLeft().getSynopsis()), (double[])(node.getRight().getSynopsis()));
152152
if(rsRightNeighbor != null) {
153-
rsOut = (RSVector)estimInternMMFallback(rsMult, rsRightNeighbor);
153+
rsOut = (double[])estimInternMMFallback(rsMult, rsRightNeighbor);
154154
if(opRightNeighbor != OpCode.MM)
155155
throw new NotImplementedException("Fallback sparsity estimation has only been " +
156156
"considered for MM operation w/ right neighbor yet.");
157157
}
158158
else
159-
rsOut = (RSVector)rsMult;
159+
rsOut = (double[])rsMult;
160160
break;
161161
default:
162162
throw new NotImplementedException("Chain estimation for operator " + node.getOp().toString() +
163163
" is not supported yet.");
164164
}
165165
}
166166
node.setSynopsis(rsOut);
167-
node.setDataCharacteristics(deriveOutputCharacteristics(node, rsOut.avg()));
167+
node.setDataCharacteristics(deriveOutputCharacteristics(node, DoubleStream.of(rsOut).average().orElse(0)));
168168
return;
169169
}
170170

171-
private RSVector estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
172-
RSVector rsM2 = getRowWiseSparsityVector(m2);
171+
private double[] estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
172+
double[] rsM2 = getRowWiseSparsityVector(m2);
173173
return estimIntern(m1, rsM2, op);
174174
}
175175

176-
private RSVector estimIntern(MatrixBlock m1, RSVector rsM2, OpCode op) {
176+
private double[] estimIntern(MatrixBlock m1, double[] rsM2, OpCode op) {
177177
switch(op) {
178178
case MM:
179179
return estimInternMM(m1, rsM2);
@@ -190,7 +190,7 @@ private RSVector estimIntern(MatrixBlock m1, RSVector rsM2, OpCode op) {
190190
}
191191
}
192192

193-
private RSVector estimIntern(MatrixBlock mb, OpCode op) {
193+
private double[] estimIntern(MatrixBlock mb, OpCode op) {
194194
switch(op) {
195195
case DIAG:
196196
return estimInternDiag(mb);
@@ -200,68 +200,72 @@ private RSVector estimIntern(MatrixBlock mb, OpCode op) {
200200
}
201201

202202
// Corresponds to Algorithm 1 in the publication
203-
private RSVector estimInternMM(MatrixBlock m1, RSVector rsM2) {
204-
RSVector rsOut = new RSVector(IntStream.range(0, m1.getNumRows()).mapToDouble(
203+
private double[] estimInternMM(MatrixBlock m1, double[] rsM2) {
204+
double[] rsOut = IntStream.range(0, m1.getNumRows()).mapToDouble(
205205
r -> (double) 1 - IntStream.of(getNonZeroColumnIndices(m1, r)).mapToDouble(
206-
c -> (double) 1 - rsM2.get(c)
206+
c -> (double) 1 - rsM2[c]
207207
).reduce((double) 1, (currentVal, val) -> currentVal * val))
208-
.toArray());
208+
.toArray();
209209
return rsOut;
210210
}
211211

212212
// NOTE: this is the best estimation possible when we only have the two row sparsity vectors
213-
private RSVector estimInternMMFallback(RSVector rsM1, RSVector rsM2) {
213+
private double[] estimInternMMFallback(double[] rsM1, double[] rsM2) {
214214
// NOTE: Considering the average would probably not be far off while saving computing time
215215
// double avgRsM2 = DoubleStream.of(rsM2).average().orElse(0);
216-
// RSVector rsOut = DoubleStream.of(rsM1).map(
216+
// double[] rsOut = DoubleStream.of(rsM1).map(
217217
// rsM1I -> (double) 1 - Math.pow((double) 1 - (rsM1I * avgRsM2), rsM2.length)).toArray();
218-
RSVector rsOut = rsM1.map(
219-
rsM1I -> (double) 1 - rsM2.reduce((double) 1,
220-
(currentVal, rsM2J) -> currentVal * ((double) 1 - (rsM1I * rsM2J))));
218+
double[] rsOut = DoubleStream.of(rsM1).map(
219+
rsM1I -> (double) 1 - DoubleStream.of(rsM2).reduce((double) 1,
220+
(currentVal, rsM2J) -> currentVal * ((double) 1 - (rsM1I * rsM2J)))).toArray();
221221
return rsOut;
222222
}
223223

224-
private RSVector estimInternCBind(RSVector rsM1, RSVector rsM2) {
225-
return new RSVector(IntStream.range(0, rsM1.size()).mapToDouble(
226-
idx -> (rsM1.get(idx) + rsM2.get(idx)) / (double) 2).toArray());
224+
private double[] estimInternCBind(double[] rsM1, double[] rsM2) {
225+
// FIXME: this assumes that the number of columns is equivalent for both inputs
226+
return IntStream.range(0, rsM1.length).mapToDouble(
227+
idx -> (rsM1[idx] + rsM2[idx]) / (double) 2).toArray();
227228
}
228229

229-
private RSVector estimInternRBind(RSVector rsM1, RSVector rsM2) {
230-
return rsM1.append(rsM2);
230+
private double[] estimInternRBind(double[] rsM1, double[] rsM2) {
231+
return ArrayUtils.addAll(rsM1, rsM2);
231232
}
232233

233-
private RSVector estimInternPlus(RSVector rsM1, RSVector rsM2) {
234+
private double[] estimInternPlus(double[] rsM1, double[] rsM2) {
234235
// row-wise average case estimates
235236
// rsM1 + rsM2 - (rsM1 * rsM2)
236-
return rsM1.add(rsM2).subtract(rsM1.multiply(rsM2));
237+
return IntStream.range(0, rsM1.length).mapToDouble(
238+
idx -> rsM1[idx] + rsM2[idx] - (rsM1[idx] * rsM2[idx])).toArray();
237239
}
238240

239-
private RSVector estimInternMult(RSVector rsM1, RSVector rsM2) {
241+
private double[] estimInternMult(double[] rsM1, double[] rsM2) {
240242
// row-wise average case estimates
241243
// rsM1 * rsM2
242-
return rsM1.multiply(rsM2);
244+
return IntStream.range(0, rsM1.length).mapToDouble(
245+
idx -> rsM1[idx] * rsM2[idx]).toArray();
243246
}
244247

245-
private RSVector estimInternDiag(MatrixBlock mb) {
246-
RSVector rsOut = new RSVector(IntStream.range(0, mb.getNumRows()).mapToDouble(
248+
private double[] estimInternDiag(MatrixBlock mb) {
249+
double[] rsOut = IntStream.range(0, mb.getNumRows()).mapToDouble(
247250
rIdx -> (mb.get(rIdx, rIdx) == 0) ? 0d : 1d)
248-
.toArray());
251+
.toArray();
249252
return rsOut;
250253
}
251254

252-
private RSVector getRowWiseSparsityVector(MatrixBlock mb) {
255+
private double[] getRowWiseSparsityVector(MatrixBlock mb) {
253256
int numRows = mb.getNumRows();
254257
if(mb.isInSparseFormat()) {
255258
double[] rsArray = new double[numRows];
256259
for(int counter = 0; counter < numRows; counter++) {
257260
SparseRow sparseRow = mb.getSparseBlock().get(counter);
258261
rsArray[counter] = (sparseRow == null) ? 0 : (double) sparseRow.size() / mb.getNumColumns();
259262
}
260-
return new RSVector(rsArray);
263+
return rsArray;
261264
}
262265
else {
263-
return new RSVector(IntStream.range(0, numRows).mapToDouble(
264-
rIdx -> (double) mb.getDenseBlock().countNonZeros(rIdx) / mb.getNumColumns()).toArray());
266+
return IntStream.range(0, numRows).mapToDouble(
267+
rIdx -> (double) mb.getDenseBlock().countNonZeros(rIdx) / mb.getNumColumns())
268+
.toArray();
265269
}
266270
}
267271

@@ -320,55 +324,4 @@ public static DataCharacteristics deriveOutputCharacteristics(MMNode node, doubl
320324
throw new NotImplementedException();
321325
}
322326
}
323-
324-
public static class RSVector {
325-
private final double[] rs;
326-
327-
public RSVector(double[] rs) {
328-
this.rs = rs;
329-
}
330-
331-
public double[] get() {
332-
return this.rs;
333-
}
334-
335-
public double get(int idx) {
336-
return this.rs[idx];
337-
}
338-
339-
public int size() {
340-
return this.rs.length;
341-
}
342-
343-
public double avg() {
344-
return DoubleStream.of(this.rs).average().orElse(0);
345-
}
346-
347-
public RSVector append(RSVector that) {
348-
return new RSVector(ArrayUtils.addAll(this.rs, that.get()));
349-
}
350-
351-
public RSVector map(DoubleUnaryOperator mapper) {
352-
return new RSVector(DoubleStream.of(this.rs).map(mapper).toArray());
353-
}
354-
355-
public double reduce(double identity, DoubleBinaryOperator op) {
356-
return DoubleStream.of(this.rs).reduce(identity, op);
357-
}
358-
359-
public RSVector add(RSVector that) {
360-
return new RSVector(IntStream.range(0, this.size()).mapToDouble(
361-
idx -> this.get(idx) + that.get(idx)).toArray());
362-
}
363-
364-
public RSVector subtract(RSVector that) {
365-
return new RSVector(IntStream.range(0, this.size()).mapToDouble(
366-
idx -> this.get(idx) - that.get(idx)).toArray());
367-
}
368-
369-
public RSVector multiply(RSVector that) {
370-
return new RSVector(IntStream.range(0, this.size()).mapToDouble(
371-
idx -> this.get(idx) * that.get(idx)).toArray());
372-
}
373-
};
374327
};

0 commit comments

Comments
 (0)