Skip to content

Commit 85916ba

Browse files
committed
Add coverage tests for compressed matmul fast paths and exec-type decisions
Re-enable the betterIfDecompressed gate in CLALibRightMultBy so the decompressing right-multiply path stays reachable, while still excluding ASDC/ASDCZero column groups from forced decompression. Add targeted and end-to-end tests covering the recently tuned paths: - CLALibMMChainTest: the public CLALibTSMM.leftMultByTransposeSelf overload (wide/narrow/uncompressed/empty/reuse/null) and the XtXv mm-chain fast path, including a tile-then-recompress wide-chain case. - CLALibRightMultBySDCTest: right multiply on ASDC/ASDCZero inputs is not forced to decompress, single-threaded and parallel. - DecoderCompositeTest: parallel and single-thread composite decode, exercising the dummycode+recode ordering dependency. - SparkTransitiveExecTypeTest with DML scripts: UnaryOp/BinaryOp/Hop transitive Spark exec-type pulling under a constrained memory budget. - CompressedTestBase: two parameterized e2e cases that validate the TSMM overload and the wide XtXv fast path against uncompressed results across all compression configurations.
1 parent 688aecd commit 85916ba

8 files changed

Lines changed: 739 additions & 4 deletions

File tree

src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc
7373
if(m2 instanceof CompressedMatrixBlock)
7474
m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);
7575

76-
// if(betterIfDecompressed(m1)) {
77-
// // perform uncompressed multiplication.
78-
// return decompressingMatrixMult(m1, m2, k);
79-
// }
76+
if(betterIfDecompressed(m1)) {
77+
// perform uncompressed multiplication.
78+
return decompressingMatrixMult(m1, m2, k);
79+
}
8080

8181
if(!allowOverlap) {
8282
LOG.trace("Overlapping output not allowed in call to Right MM");

src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
6161
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
6262
import org.apache.sysds.runtime.compress.lib.CLALibCBind;
63+
import org.apache.sysds.runtime.compress.lib.CLALibTSMM;
6364
import org.apache.sysds.runtime.functionobjects.Builtin;
6465
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
6566
import org.apache.sysds.runtime.functionobjects.Divide;
@@ -503,6 +504,51 @@ public void testMatrixMultChain(ChainType ctype) {
503504
}
504505
}
505506

507+
@Test
508+
public void testTransposeSelfLeftMultOverload() {
509+
// Exercises the package-public CLALibTSMM.leftMultByTransposeSelf(cmb, k) entry point (used by the
510+
// XtXv mm-chain fast path) across all compression configurations.
511+
if(!(cmb instanceof CompressedMatrixBlock))
512+
return;
513+
try {
514+
MatrixBlock ret2 = CLALibTSMM.leftMultByTransposeSelf((CompressedMatrixBlock) cmb, _k);
515+
MatrixBlock ucRet2 = mb.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, _k);
516+
compareResultMatrices(ucRet2, ret2, overlappingType != OverLapping.NONE ? 256 : 2);
517+
}
518+
catch(Exception e) {
519+
e.printStackTrace();
520+
throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e);
521+
}
522+
}
523+
524+
@Test
525+
public void testMatrixMultChainXtXvWide() {
526+
// Widen the input beyond 30 columns so the XtXv fast path triggers, validating it against the
527+
// uncompressed result for whatever compression the current configuration produces.
528+
if(!(cmb instanceof CompressedMatrixBlock))
529+
return;
530+
try {
531+
final int nCol = mb.getNumColumns();
532+
final int reps = (int) Math.ceil(31.0 / nCol) + 1;
533+
MatrixBlock wide = mb;
534+
for(int i = 1; i < reps; i++)
535+
wide = wide.append(mb, new MatrixBlock(), true);
536+
537+
MatrixBlock wideC = CompressedMatrixBlockFactory.compress(wide, _k).getLeft();
538+
if(!(wideC instanceof CompressedMatrixBlock))
539+
return; // not compressible in this configuration
540+
541+
MatrixBlock vector1 = TestUtils.generateTestMatrixBlock(wide.getNumColumns(), 1, 0.9, 1.5, 1.0, 3);
542+
MatrixBlock ucRet2 = wide.chainMatrixMultOperations(vector1, null, new MatrixBlock(), ChainType.XtXv, _k);
543+
MatrixBlock ret2 = wideC.chainMatrixMultOperations(vector1, null, new MatrixBlock(), ChainType.XtXv, _k);
544+
compareResultMatricesPercentDistance(ucRet2, ret2, 0.99, 0.99);
545+
}
546+
catch(Exception e) {
547+
e.printStackTrace();
548+
throw new RuntimeException(bufferedToString + "\n" + e.getMessage(), e);
549+
}
550+
}
551+
506552
@Test
507553
public void testVectorMatrixMult() {
508554
MatrixBlock vector = TestUtils.generateTestMatrixBlock(1, rows, 0, 5, 1.0, 3);
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.compress.lib;
21+
22+
import static org.junit.Assert.assertEquals;
23+
import static org.junit.Assert.assertTrue;
24+
import static org.junit.Assert.fail;
25+
26+
import java.util.ArrayList;
27+
import java.util.List;
28+
import java.util.Random;
29+
30+
import org.apache.commons.logging.Log;
31+
import org.apache.commons.logging.LogFactory;
32+
import org.apache.sysds.lops.MapMultChain.ChainType;
33+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
34+
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
35+
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
36+
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
37+
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
38+
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
39+
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
40+
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
41+
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
42+
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
43+
import org.apache.sysds.runtime.compress.lib.CLALibTSMM;
44+
import org.apache.sysds.lops.MMTSJ.MMTSJType;
45+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
46+
import org.apache.sysds.test.TestUtils;
47+
import org.apache.sysds.test.component.compress.mapping.MappingTestUtil;
48+
import org.junit.BeforeClass;
49+
import org.junit.Test;
50+
51+
/**
52+
* Targeted tests for the compressed transpose-self multiply ({@link CLALibTSMM}) and the XtXv mm-chain fast path that
53+
* was added in {@code CLALibMMChain}. The fast path triggers when the input has fewer than five column groups and more
54+
* than thirty columns, in which case the chain is computed as {@code (t(X) %*% X) %*% v}.
55+
*/
56+
public class CLALibMMChainTest {
57+
protected static final Log LOG = LogFactory.getLog(CLALibMMChainTest.class.getName());
58+
59+
@BeforeClass
60+
public static void setup() {
61+
Thread.currentThread().setName("main_test_" + Thread.currentThread().getId());
62+
}
63+
64+
/**
65+
* Build a compressed matrix backed by a single DDC column group spanning all {@code nCol} columns. This guarantees a
66+
* single (non-uncompressed) column group, which is what triggers the mm-chain fast path for wide enough matrices.
67+
*/
68+
private static CompressedMatrixBlock singleDDC(int nRow, int nCol, int nVal, int seed) {
69+
Random r = new Random(seed);
70+
double[] dictValues = new double[nVal * nCol];
71+
for(int i = 0; i < dictValues.length; i++)
72+
dictValues[i] = Math.round(r.nextDouble() * 20 - 10);
73+
IDictionary dict = Dictionary.create(dictValues);
74+
AMapToData data = MappingTestUtil.createRandomMap(nRow, nVal, r);
75+
AColGroup g = ColGroupDDC.create(ColIndexFactory.create(nCol), dict, data, null);
76+
CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol);
77+
cmb.allocateColGroup(g);
78+
cmb.recomputeNonZeros();
79+
return cmb;
80+
}
81+
82+
private static CompressedMatrixBlock uncompressedGroup(int nRow, int nCol, int seed) {
83+
MatrixBlock mb = TestUtils.round(TestUtils.generateTestMatrixBlock(nRow, nCol, -10, 10, 1.0, seed));
84+
CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol);
85+
cmb.allocateColGroup(ColGroupUncompressed.create(mb));
86+
cmb.recomputeNonZeros();
87+
return cmb;
88+
}
89+
90+
private static CompressedMatrixBlock empty(int nRow, int nCol) {
91+
CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol);
92+
cmb.allocateColGroup(new ColGroupEmpty(ColIndexFactory.create(nCol)));
93+
cmb.recomputeNonZeros();
94+
return cmb;
95+
}
96+
97+
@Test
98+
public void tsmmWideSingleThread() {
99+
execTSMM(singleDDC(200, 40, 6, 1), 1);
100+
}
101+
102+
@Test
103+
public void tsmmWideParallel() {
104+
execTSMM(singleDDC(200, 40, 6, 2), 4);
105+
}
106+
107+
@Test
108+
public void tsmmNarrowSingleThread() {
109+
execTSMM(singleDDC(200, 8, 4, 3), 1);
110+
}
111+
112+
@Test
113+
public void tsmmNarrowParallel() {
114+
execTSMM(singleDDC(200, 8, 4, 4), 4);
115+
}
116+
117+
@Test
118+
public void tsmmUncompressedGroupSingleThread() {
119+
// A compressed block holding an uncompressed column group must fall back to the dense tsmm path.
120+
execTSMM(uncompressedGroup(150, 12, 5), 1);
121+
}
122+
123+
@Test
124+
public void tsmmUncompressedGroupParallel() {
125+
execTSMM(uncompressedGroup(150, 12, 6), 4);
126+
}
127+
128+
@Test
129+
public void tsmmEmpty() {
130+
CompressedMatrixBlock cmb = empty(100, 13);
131+
MatrixBlock ret = CLALibTSMM.leftMultByTransposeSelf(cmb, 1);
132+
assertEquals(13, ret.getNumRows());
133+
assertEquals(13, ret.getNumColumns());
134+
assertTrue("empty input must produce an empty result", ret.isEmptyBlock(false));
135+
}
136+
137+
@Test
138+
public void tsmmRetReused() {
139+
// A non-null ret must be reset and reused, producing the same result as a fresh allocation.
140+
CompressedMatrixBlock cmb = singleDDC(120, 36, 5, 7);
141+
MatrixBlock preAllocated = new MatrixBlock(3, 3, 99.0);
142+
preAllocated.allocateDenseBlock();
143+
MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, preAllocated, 4);
144+
MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb)
145+
.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, 4);
146+
TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0);
147+
}
148+
149+
@Test
150+
public void tsmmRetNull() {
151+
// Explicitly exercise the null-ret allocation branch of the helper.
152+
CompressedMatrixBlock cmb = singleDDC(120, 36, 5, 8);
153+
MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, null, 1);
154+
MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb)
155+
.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, 1);
156+
TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0);
157+
}
158+
159+
private static void execTSMM(CompressedMatrixBlock cmb, int k) {
160+
try {
161+
MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, k);
162+
MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb)
163+
.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, k);
164+
assertEquals(cmb.getNumColumns(), cRet.getNumRows());
165+
assertEquals(cmb.getNumColumns(), cRet.getNumColumns());
166+
TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0);
167+
}
168+
catch(Exception e) {
169+
e.printStackTrace();
170+
fail(e.getMessage());
171+
}
172+
}
173+
174+
@Test
175+
public void mmChainFastPathSingleThread() {
176+
// 40 columns, single column group -> XtXv fast path.
177+
execMMChain(singleDDC(200, 40, 6, 11), 1);
178+
}
179+
180+
@Test
181+
public void mmChainFastPathParallel() {
182+
execMMChain(singleDDC(200, 40, 6, 12), 4);
183+
}
184+
185+
@Test
186+
public void mmChainFastPathFewGroups() {
187+
// Two column groups (< 5) over 40 columns still triggers the fast path.
188+
execMMChain(twoGroups(200, 40, 13), 4);
189+
}
190+
191+
@Test
192+
public void mmChainRegularPathNarrow() {
193+
// Only 20 columns -> below the width threshold, exercises the regular (non fast) chain path.
194+
execMMChain(singleDDC(200, 20, 6, 14), 4);
195+
}
196+
197+
private static CompressedMatrixBlock twoGroups(int nRow, int nCol, int seed) {
198+
final int half = nCol / 2;
199+
Random r = new Random(seed);
200+
List<AColGroup> gs = new ArrayList<>();
201+
gs.add(ddcGroup(nRow, ColIndexFactory.create(0, half), 5, r));
202+
gs.add(ddcGroup(nRow, ColIndexFactory.create(half, nCol), 5, r));
203+
CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol);
204+
cmb.allocateColGroupList(gs);
205+
cmb.recomputeNonZeros();
206+
return cmb;
207+
}
208+
209+
private static AColGroup ddcGroup(int nRow, org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex cols,
210+
int nVal, Random r) {
211+
int nCol = cols.size();
212+
double[] dictValues = new double[nVal * nCol];
213+
for(int i = 0; i < dictValues.length; i++)
214+
dictValues[i] = Math.round(r.nextDouble() * 20 - 10);
215+
IDictionary dict = Dictionary.create(dictValues);
216+
AMapToData data = MappingTestUtil.createRandomMap(nRow, nVal, r);
217+
return ColGroupDDC.create(cols, dict, data, null);
218+
}
219+
220+
@Test
221+
public void mmChainWideRecompressedDDC() {
222+
// Mirrors the e2e CompressedTestBase#testMatrixMultChainXtXvWide flow: tile a narrow matrix until it
223+
// exceeds the 30-column fast-path threshold, recompress it, then validate XtXv against uncompressed.
224+
execMMChainWide(TestUtils.round(TestUtils.generateTestMatrixBlock(300, 4, -10, 10, 1.0, 21)), 1);
225+
}
226+
227+
@Test
228+
public void mmChainWideRecompressedSparse() {
229+
execMMChainWide(TestUtils.round(TestUtils.generateTestMatrixBlock(300, 3, 1, 5, 0.2, 22)), 4);
230+
}
231+
232+
private static void execMMChainWide(MatrixBlock base, int k) {
233+
try {
234+
final int nCol = base.getNumColumns();
235+
final int reps = (int) Math.ceil(31.0 / nCol) + 1;
236+
MatrixBlock wide = base;
237+
for(int i = 1; i < reps; i++)
238+
wide = wide.append(base, new MatrixBlock(), true);
239+
assertTrue("widened matrix must exceed the fast-path threshold", wide.getNumColumns() > 30);
240+
241+
MatrixBlock wideC = CompressedMatrixBlockFactory.compress(wide, k).getLeft();
242+
assertTrue("tiled matrix should compress", wideC instanceof CompressedMatrixBlock);
243+
244+
MatrixBlock v = TestUtils.generateTestMatrixBlock(wide.getNumColumns(), 1, 0.9, 1.5, 1.0, 3);
245+
MatrixBlock uRet = wide.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k);
246+
MatrixBlock cRet = wideC.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k);
247+
TestUtils.compareMatrices(uRet, cRet, 1e-6, "wide recompressed mm-chain result mismatch");
248+
}
249+
catch(Exception e) {
250+
e.printStackTrace();
251+
fail(e.getMessage());
252+
}
253+
}
254+
255+
private static void execMMChain(CompressedMatrixBlock cmb, int k) {
256+
try {
257+
final int cols = cmb.getNumColumns();
258+
MatrixBlock v = TestUtils.round(TestUtils.generateTestMatrixBlock(cols, 1, -3, 3, 1.0, 42));
259+
MatrixBlock uncompressed = CompressedMatrixBlock.getUncompressed(cmb);
260+
261+
MatrixBlock cRet = cmb.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k);
262+
MatrixBlock uRet = uncompressed.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k);
263+
264+
assertEquals(cols, cRet.getNumRows());
265+
assertEquals(1, cRet.getNumColumns());
266+
TestUtils.compareMatrices(uRet, cRet, 1e-6, "mm-chain result mismatch");
267+
}
268+
catch(Exception e) {
269+
e.printStackTrace();
270+
fail(e.getMessage());
271+
}
272+
}
273+
}

0 commit comments

Comments
 (0)