Skip to content

Commit da4c2c5

Browse files
committed
feat(test/component/estim): add unit tests for row wise sparsity estimator with element-wise and single operations
1 parent ef4a026 commit da4c2c5

3 files changed

Lines changed: 55 additions & 11 deletions

File tree

src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
2626
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
2727
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
28+
import org.apache.sysds.hops.estim.EstimatorRowWise;
2829
import org.apache.sysds.hops.estim.EstimatorDensityMap;
2930
import org.apache.sysds.hops.estim.MMNode;
3031
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -118,8 +119,18 @@ public void testLGCasemult() {
118119
public void testLGCaseplus() {
119120
runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n, sparsity, plus);
120121
}
121-
122-
122+
123+
// Row Wise Sparsity Estimator
124+
@Test
125+
public void testRowWiseCaseMult() {
126+
runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, mult);
127+
}
128+
129+
@Test
130+
public void testRowWiseCasePlus() {
131+
runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, plus);
132+
}
133+
123134
private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int n, double[] sp, OpCode op) {
124135
MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp[0], 1, 1, "uniform", 3);
125136
MatrixBlock m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 5);

src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
2626
import org.apache.sysds.hops.estim.EstimatorDensityMap;
2727
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
28+
import org.apache.sysds.hops.estim.EstimatorRowWise;
2829
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
2930
import org.apache.sysds.hops.estim.EstimatorSample;
3031
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -128,7 +129,18 @@ public void testSampleMult() {
128129
public void testSamplePlus() {
129130
runSparsityEstimateTest(new EstimatorSample(), m, n, sparsity, plus);
130131
}
131-
132+
133+
// Row Wise Sparsity Estimator
134+
@Test
135+
public void testRowWiseMult() {
136+
runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, mult);
137+
}
138+
139+
@Test
140+
public void testRowWisePlus() {
141+
runSparsityEstimateTest(new EstimatorRowWise(), m, n, sparsity, plus);
142+
}
143+
132144
private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int n, double[] sp, OpCode op) {
133145
MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp[0], 1, 1, "uniform", 3);
134146
MatrixBlock m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 7);

src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
2727
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
2828
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
29+
import org.apache.sysds.hops.estim.EstimatorRowWise;
2930
import org.apache.sysds.hops.estim.SparsityEstimator;
3031
import org.apache.sysds.hops.estim.SparsityEstimator.OpCode;
3132
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -40,7 +41,7 @@ public class OpSingleTest extends AutomatedTestBase
4041
private final static int m = 600;
4142
private final static int k = 300;
4243
private final static double sparsity = 0.2;
43-
// private final static OpCode eqzero = OpCode.EQZERO;
44+
// private final static OpCode eqzero = OpCode.EQZERO;
4445
private final static OpCode diag = OpCode.DIAG;
4546
private final static OpCode neqzero = OpCode.NEQZERO;
4647
private final static OpCode trans = OpCode.TRANS;
@@ -237,7 +238,33 @@ public void testLGCasetrans() {
237238
// public void testSampleCasereshape() {
238239
// runSparsityEstimateTest(new EstimatorSample(), m, k, sparsity, reshape);
239240
// }
240-
241+
242+
// Row Wise Sparsity Estimator
243+
// @Test
244+
// public void testRowWiseEqzero() {
245+
// runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, eqzero);
246+
// }
247+
248+
// @Test
249+
// public void testRowWiseDiag() {
250+
// runSparsityEstimateTest(new EstimatorRowWise(), m, m, sparsity, diag);
251+
// }
252+
253+
@Test
254+
public void testRowWiseNeqzero() {
255+
runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, neqzero);
256+
}
257+
258+
@Test
259+
public void testRowWiseTrans() {
260+
runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, trans);
261+
}
262+
263+
@Test
264+
public void testRowWiseReshape() {
265+
runSparsityEstimateTest(new EstimatorRowWise(), m, k, sparsity, reshape);
266+
}
267+
241268
private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, double sp, OpCode op) {
242269
MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp, 1, 1, "uniform", 3);
243270
MatrixBlock m2 = new MatrixBlock();
@@ -252,13 +279,7 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int
252279
est = estim.estim(m1, op);
253280
break;
254281
case NEQZERO:
255-
m2 = m1;
256-
est = estim.estim(m1, op);
257-
break;
258282
case TRANS:
259-
m2 = m1;
260-
est = estim.estim(m1, op);
261-
break;
262283
case RESHAPE:
263284
m2 = m1;
264285
est = estim.estim(m1, op);

0 commit comments

Comments
 (0)