Skip to content

Commit c50c769

Browse files
committed
[SYSTEMDS-3948] Implement Row-wise Sparsity Estimator
This commit implements the row-wise sparsity estimator and adds respective test cases. Closes #2466.
1 parent 684531d commit c50c769

10 files changed

Lines changed: 917 additions & 543 deletions

File tree

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.estim;
21+
22+
import org.apache.commons.lang3.ArrayUtils;
23+
import org.apache.commons.lang3.NotImplementedException;
24+
import org.apache.sysds.hops.OptimizerUtils;
25+
import org.apache.sysds.runtime.data.SparseRow;
26+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
27+
import org.apache.sysds.runtime.meta.DataCharacteristics;
28+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
29+
30+
import java.util.stream.DoubleStream;
31+
import java.util.stream.IntStream;
32+
33+
/**
34+
* This estimator implements an approach based on row-wise sparsity estimation,
35+
* introduced in
36+
* Lin, Chunxu, Wensheng Luo, Yixiang Fang, Chenhao Ma, Xilin Liu and Yuchi Ma:
37+
* On Efficient Large Sparse Matrix Chain Multiplication.
38+
* Proceedings of the ACM on Management of Data 2 (2024): 1 - 27.
39+
*/
40+
public class EstimatorRowWise extends SparsityEstimator {
41+
@Override
42+
public DataCharacteristics estim(MMNode root) {
43+
estimInternChain(root);
44+
double sparsity = DoubleStream.of((double[])root.getSynopsis()).average().orElse(0);
45+
46+
DataCharacteristics outputCharacteristics = deriveOutputCharacteristics(root, sparsity);
47+
return root.setDataCharacteristics(outputCharacteristics);
48+
}
49+
50+
@Override
51+
public double estim(MatrixBlock m1, MatrixBlock m2) {
52+
return estim(m1, m2, OpCode.MM);
53+
}
54+
55+
@Override
56+
public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
57+
if( isExactMetadataOp(op, m1.getNumColumns()) ) {
58+
return estimExactMetaData(m1.getDataCharacteristics(),
59+
m2.getDataCharacteristics(), op).getSparsity();
60+
}
61+
62+
double[] rsOut = estimIntern(m1, m2, op);
63+
return DoubleStream.of(rsOut).average().orElse(0);
64+
}
65+
66+
@Override
67+
public double estim(MatrixBlock m1, OpCode op) {
68+
if( isExactMetadataOp(op, m1.getNumColumns()) )
69+
return estimExactMetaData(m1.getDataCharacteristics(), null, op).getSparsity();
70+
71+
double[] rsOut = estimIntern(m1, op);
72+
return DoubleStream.of(rsOut).average().orElse(0);
73+
}
74+
75+
private double[] estimInternChain(MMNode node) {
76+
return estimInternChain(node, null, null);
77+
}
78+
79+
private double[] estimInternChain(MMNode node, double[] rsRightNeighbor, OpCode opRightNeighbor) {
80+
double[] rsOut;
81+
if(node.isLeaf()) {
82+
MatrixBlock mb = node.getData();
83+
if(rsRightNeighbor != null)
84+
rsOut = estimIntern(mb, rsRightNeighbor, opRightNeighbor);
85+
else
86+
rsOut = getRowWiseSparsityVector(mb);
87+
}
88+
else {
89+
MMNode nodeLeft = node.getLeft();
90+
MMNode nodeRight = node.getRight();
91+
switch(node.getOp()) {
92+
case MM:
93+
double[] rsRightMM = estimInternChain(nodeRight, rsRightNeighbor, opRightNeighbor);
94+
rsOut = estimInternChain(nodeLeft, rsRightMM, node.getOp());
95+
break;
96+
case CBIND:
97+
/**
98+
* NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
99+
* the right neighbor cannot be aggregated into a cbind operation when having only row sparsity vectors
100+
*/
101+
double[] rsLeftCBind = estimInternChain(nodeLeft);
102+
double[] rsRightCBind = estimInternChain(nodeRight);
103+
double[] rsCBind = estimInternCBind(rsLeftCBind, rsRightCBind);
104+
if(rsRightNeighbor != null) {
105+
rsOut = estimInternMMFallback(rsCBind, rsRightNeighbor);
106+
if(opRightNeighbor != OpCode.MM)
107+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
108+
"considered for MM operation w/ right neighbor yet.");
109+
}
110+
else
111+
rsOut = rsCBind;
112+
break;
113+
case RBIND:
114+
/**
115+
* NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
116+
* the right neighbor cannot be aggregated into an rbind operation when having only row sparsity vectors
117+
*/
118+
double[] rsLeftRBind = estimInternChain(nodeLeft);
119+
double[] rsRightRBind = estimInternChain(nodeRight);
120+
double[] rsRBind = estimInternRBind(rsLeftRBind, rsRightRBind);
121+
if(rsRightNeighbor != null) {
122+
rsOut = estimInternMMFallback(rsRBind, rsRightNeighbor);
123+
if(opRightNeighbor != OpCode.MM)
124+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
125+
"considered for MM operation w/ right neighbor yet.");
126+
}
127+
else
128+
rsOut = rsRBind;
129+
break;
130+
case PLUS:
131+
/**
132+
* NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
133+
* the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
134+
*/
135+
double[] rsLeftPlus = estimInternChain(nodeLeft);
136+
double[] rsRightPlus = estimInternChain(nodeRight);
137+
double[] rsPlus = estimInternPlus(rsLeftPlus, rsRightPlus);
138+
if(rsRightNeighbor != null) {
139+
rsOut = estimInternMMFallback(rsPlus, rsRightNeighbor);
140+
if(opRightNeighbor != OpCode.MM)
141+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
142+
"considered for MM operation w/ right neighbor yet.");
143+
}
144+
else
145+
rsOut = rsPlus;
146+
break;
147+
case MULT:
148+
/**
149+
* NOTE: considering the current node as new DAG for estimation (cut), since the row sparsity of
150+
* the right neighbor cannot be aggregated into an element-wise operation when having only row sparsity vectors
151+
*/
152+
double[] rsLeftMult = estimInternChain(nodeLeft);
153+
double[] rsRightMult = estimInternChain(nodeRight);
154+
double[] rsMult = estimInternMult(rsLeftMult, rsRightMult);
155+
if(rsRightNeighbor != null) {
156+
rsOut = estimInternMMFallback(rsMult, rsRightNeighbor);
157+
if(opRightNeighbor != OpCode.MM)
158+
throw new NotImplementedException("Fallback sparsity estimation has only been " +
159+
"considered for MM operation w/ right neighbor yet.");
160+
}
161+
else
162+
rsOut = rsMult;
163+
break;
164+
default:
165+
throw new NotImplementedException("Chain estimation for operator " + node.getOp().toString() +
166+
" is not supported yet.");
167+
}
168+
}
169+
node.setSynopsis(rsOut);
170+
node.setDataCharacteristics(deriveOutputCharacteristics(node, DoubleStream.of(rsOut).average().orElse(0)));
171+
return rsOut;
172+
}
173+
174+
private double[] estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
175+
double[] rsM2 = getRowWiseSparsityVector(m2);
176+
return estimIntern(m1, rsM2, op);
177+
}
178+
179+
private double[] estimIntern(MatrixBlock m1, double[] rsM2, OpCode op) {
180+
switch(op) {
181+
case MM:
182+
return estimInternMM(m1, rsM2);
183+
case CBIND:
184+
return estimInternCBind(getRowWiseSparsityVector(m1), rsM2);
185+
case RBIND:
186+
return estimInternRBind(getRowWiseSparsityVector(m1), rsM2);
187+
case PLUS:
188+
return estimInternPlus(getRowWiseSparsityVector(m1), rsM2);
189+
case MULT:
190+
return estimInternMult(getRowWiseSparsityVector(m1), rsM2);
191+
default:
192+
throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
193+
}
194+
}
195+
196+
private double[] estimIntern(MatrixBlock mb, OpCode op) {
197+
switch(op) {
198+
case DIAG:
199+
return estimInternDiag(mb);
200+
default:
201+
throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
202+
}
203+
}
204+
205+
/**
206+
* Corresponds to Algorithm 1 in the publication
207+
*/
208+
private double[] estimInternMM(MatrixBlock m1, double[] rsM2) {
209+
double[] rsOut = new double[m1.getNumRows()];
210+
for(int rIdx = 0; rIdx < m1.getNumRows(); rIdx++) {
211+
double currentVal = 1;
212+
for(int cIdx : getNonZeroColumnIndices(m1, rIdx)) {
213+
currentVal *= 1.0 - rsM2[cIdx];
214+
}
215+
rsOut[rIdx] = 1 - currentVal;
216+
}
217+
return rsOut;
218+
}
219+
220+
/**
221+
* NOTE: fallback estimate using the uniform estimator (aka average-case estimator, Naive Bayes estimator) for
222+
* the case when we are limited to the row sparsity vectors of both inputs
223+
* NOTE: Considering the average of the second matrix would probably not be far off while saving computing time
224+
*/
225+
private double[] estimInternMMFallback(double[] rsM1, double[] rsM2) {
226+
double[] rsOut = new double[rsM1.length];
227+
for(int i = 0; i < rsM1.length; i++) {
228+
double rsM1i = rsM1[i];
229+
if(rsM1i == 0) {
230+
rsOut[i] = 0;
231+
}
232+
else {
233+
double currentVal = 1;
234+
for(int j = 0; j < rsM2.length; j++) {
235+
currentVal *= 1.0 - (rsM1i * rsM2[j]);
236+
}
237+
rsOut[i] = 1.0 - currentVal;
238+
}
239+
}
240+
return rsOut;
241+
}
242+
243+
private double[] estimInternCBind(double[] rsM1, double[] rsM2) {
244+
// FIXME: this estimate assumes that the number of columns is equivalent for both inputs
245+
double[] rsOut = new double[rsM1.length];
246+
for(int idx = 0; idx < rsM1.length; idx++) {
247+
rsOut[idx] = (rsM1[idx] + rsM2[idx]) / 2.0;
248+
}
249+
return rsOut;
250+
}
251+
252+
private double[] estimInternRBind(double[] rsM1, double[] rsM2) {
253+
return ArrayUtils.addAll(rsM1, rsM2);
254+
}
255+
256+
private double[] estimInternPlus(double[] rsM1, double[] rsM2) {
257+
// row-wise average case estimates
258+
// rsM1 + rsM2 - (rsM1 * rsM2)
259+
double[] rsOut = new double[rsM1.length];
260+
for(int idx = 0; idx < rsM1.length; idx++) {
261+
rsOut[idx] = rsM1[idx] + rsM2[idx] - (rsM1[idx] * rsM2[idx]);
262+
}
263+
return rsOut;
264+
}
265+
266+
private double[] estimInternMult(double[] rsM1, double[] rsM2) {
267+
// row-wise average case estimates
268+
// rsM1 * rsM2
269+
double[] rsOut = new double[rsM1.length];
270+
for(int idx = 0; idx < rsM1.length; idx++) {
271+
rsOut[idx] = rsM1[idx] * rsM2[idx];
272+
}
273+
return rsOut;
274+
}
275+
276+
private double[] estimInternDiag(MatrixBlock mb) {
277+
double[] rsOut = new double[mb.getNumRows()];
278+
for(int rIdx = 0; rIdx < mb.getNumRows(); rIdx++) {
279+
rsOut[rIdx] = (mb.get(rIdx, rIdx) == 0) ? 0 : 1;
280+
}
281+
return rsOut;
282+
}
283+
284+
private double[] getRowWiseSparsityVector(MatrixBlock mb) {
285+
int numRows = mb.getNumRows();
286+
double[] rsOut = new double[numRows];
287+
if(mb.isInSparseFormat()) {
288+
for(int rIdx = 0; rIdx < numRows; rIdx++) {
289+
SparseRow sparseRow = mb.getSparseBlock().get(rIdx);
290+
rsOut[rIdx] = (sparseRow == null) ? 0 : (double) sparseRow.size() / mb.getNumColumns();
291+
}
292+
}
293+
else {
294+
for(int rIdx = 0; rIdx < numRows; rIdx++) {
295+
rsOut[rIdx] = (double) mb.getDenseBlock().countNonZeros(rIdx) / mb.getNumColumns();
296+
}
297+
}
298+
return rsOut;
299+
}
300+
301+
private int[] getNonZeroColumnIndices(MatrixBlock mb, final int rIdx) {
302+
int[] nonZeroCols;
303+
if(mb.isInSparseFormat()) {
304+
SparseRow sparseRow = mb.getSparseBlock().get(rIdx);
305+
nonZeroCols = (sparseRow == null) ? new int[0] : sparseRow.indexes();
306+
}
307+
else {
308+
nonZeroCols = IntStream.range(0, mb.getNumColumns())
309+
.filter(cIdx -> mb.get(rIdx, cIdx) != 0).toArray();
310+
}
311+
return nonZeroCols;
312+
}
313+
314+
public static DataCharacteristics deriveOutputCharacteristics(MMNode node, double spOut) {
315+
if(node.isLeaf() ||
316+
(node.getDataCharacteristics() != null && node.getDataCharacteristics().getNonZeros() != -1)) {
317+
return node.getDataCharacteristics();
318+
}
319+
320+
MMNode nodeLeft = node.getLeft();
321+
MMNode nodeRight = node.getRight();
322+
int leftNRow = nodeLeft.getRows();
323+
int leftNCol = nodeLeft.getCols();
324+
int rightNRow = nodeRight.getRows();
325+
int rightNCol = nodeRight.getCols();
326+
switch(node.getOp()) {
327+
case MM:
328+
return new MatrixCharacteristics(leftNRow, rightNCol,
329+
OptimizerUtils.getNnz(leftNRow, rightNCol, spOut));
330+
case MULT:
331+
case PLUS:
332+
case NEQZERO:
333+
case EQZERO:
334+
return new MatrixCharacteristics(leftNRow, leftNCol,
335+
OptimizerUtils.getNnz(leftNRow, leftNCol, spOut));
336+
case RBIND:
337+
return new MatrixCharacteristics(leftNRow+rightNRow, leftNCol,
338+
OptimizerUtils.getNnz(leftNRow+rightNRow, leftNCol, spOut));
339+
case CBIND:
340+
return new MatrixCharacteristics(leftNRow, leftNCol+rightNCol,
341+
OptimizerUtils.getNnz(leftNRow, leftNCol+rightNCol, spOut));
342+
case DIAG:
343+
int ncol = (leftNCol == 1) ? leftNRow : 1;
344+
return new MatrixCharacteristics(leftNRow, ncol,
345+
OptimizerUtils.getNnz(leftNRow, ncol, spOut));
346+
case TRANS:
347+
return new MatrixCharacteristics(leftNCol, leftNRow,
348+
OptimizerUtils.getNnz(leftNCol, leftNRow, spOut));
349+
case RESHAPE:
350+
throw new NotImplementedException("Characteristics derivation for " + node.getOp() +" has not been " +
351+
"implemented yet, but could be implemented similar to EstimatorMatrixHistogram.java");
352+
default:
353+
throw new NotImplementedException();
354+
}
355+
}
356+
};

0 commit comments

Comments
 (0)