|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +package org.apache.sysds.test.component.compress.lib; |
| 21 | + |
| 22 | +import static org.junit.Assert.assertEquals; |
| 23 | +import static org.junit.Assert.assertTrue; |
| 24 | +import static org.junit.Assert.fail; |
| 25 | + |
| 26 | +import java.util.ArrayList; |
| 27 | +import java.util.List; |
| 28 | +import java.util.Random; |
| 29 | + |
| 30 | +import org.apache.commons.logging.Log; |
| 31 | +import org.apache.commons.logging.LogFactory; |
| 32 | +import org.apache.sysds.lops.MapMultChain.ChainType; |
| 33 | +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; |
| 34 | +import org.apache.sysds.runtime.compress.colgroup.AColGroup; |
| 35 | +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; |
| 36 | +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; |
| 37 | +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; |
| 38 | +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; |
| 39 | +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; |
| 40 | +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; |
| 41 | +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; |
| 42 | +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; |
| 43 | +import org.apache.sysds.runtime.compress.lib.CLALibTSMM; |
| 44 | +import org.apache.sysds.lops.MMTSJ.MMTSJType; |
| 45 | +import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| 46 | +import org.apache.sysds.test.TestUtils; |
| 47 | +import org.apache.sysds.test.component.compress.mapping.MappingTestUtil; |
| 48 | +import org.junit.BeforeClass; |
| 49 | +import org.junit.Test; |
| 50 | + |
| 51 | +/** |
| 52 | + * Targeted tests for the compressed transpose-self multiply ({@link CLALibTSMM}) and the XtXv mm-chain fast path that |
| 53 | + * was added in {@code CLALibMMChain}. The fast path triggers when the input has fewer than five column groups and more |
| 54 | + * than thirty columns, in which case the chain is computed as {@code (t(X) %*% X) %*% v}. |
| 55 | + */ |
| 56 | +public class CLALibMMChainTest { |
| 57 | + protected static final Log LOG = LogFactory.getLog(CLALibMMChainTest.class.getName()); |
| 58 | + |
| 59 | + @BeforeClass |
| 60 | + public static void setup() { |
| 61 | + Thread.currentThread().setName("main_test_" + Thread.currentThread().getId()); |
| 62 | + } |
| 63 | + |
| 64 | + /** |
| 65 | + * Build a compressed matrix backed by a single DDC column group spanning all {@code nCol} columns. This guarantees a |
| 66 | + * single (non-uncompressed) column group, which is what triggers the mm-chain fast path for wide enough matrices. |
| 67 | + */ |
| 68 | + private static CompressedMatrixBlock singleDDC(int nRow, int nCol, int nVal, int seed) { |
| 69 | + Random r = new Random(seed); |
| 70 | + double[] dictValues = new double[nVal * nCol]; |
| 71 | + for(int i = 0; i < dictValues.length; i++) |
| 72 | + dictValues[i] = Math.round(r.nextDouble() * 20 - 10); |
| 73 | + IDictionary dict = Dictionary.create(dictValues); |
| 74 | + AMapToData data = MappingTestUtil.createRandomMap(nRow, nVal, r); |
| 75 | + AColGroup g = ColGroupDDC.create(ColIndexFactory.create(nCol), dict, data, null); |
| 76 | + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); |
| 77 | + cmb.allocateColGroup(g); |
| 78 | + cmb.recomputeNonZeros(); |
| 79 | + return cmb; |
| 80 | + } |
| 81 | + |
| 82 | + private static CompressedMatrixBlock uncompressedGroup(int nRow, int nCol, int seed) { |
| 83 | + MatrixBlock mb = TestUtils.round(TestUtils.generateTestMatrixBlock(nRow, nCol, -10, 10, 1.0, seed)); |
| 84 | + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); |
| 85 | + cmb.allocateColGroup(ColGroupUncompressed.create(mb)); |
| 86 | + cmb.recomputeNonZeros(); |
| 87 | + return cmb; |
| 88 | + } |
| 89 | + |
| 90 | + private static CompressedMatrixBlock empty(int nRow, int nCol) { |
| 91 | + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); |
| 92 | + cmb.allocateColGroup(new ColGroupEmpty(ColIndexFactory.create(nCol))); |
| 93 | + cmb.recomputeNonZeros(); |
| 94 | + return cmb; |
| 95 | + } |
| 96 | + |
| 97 | + @Test |
| 98 | + public void tsmmWideSingleThread() { |
| 99 | + execTSMM(singleDDC(200, 40, 6, 1), 1); |
| 100 | + } |
| 101 | + |
| 102 | + @Test |
| 103 | + public void tsmmWideParallel() { |
| 104 | + execTSMM(singleDDC(200, 40, 6, 2), 4); |
| 105 | + } |
| 106 | + |
| 107 | + @Test |
| 108 | + public void tsmmNarrowSingleThread() { |
| 109 | + execTSMM(singleDDC(200, 8, 4, 3), 1); |
| 110 | + } |
| 111 | + |
| 112 | + @Test |
| 113 | + public void tsmmNarrowParallel() { |
| 114 | + execTSMM(singleDDC(200, 8, 4, 4), 4); |
| 115 | + } |
| 116 | + |
| 117 | + @Test |
| 118 | + public void tsmmUncompressedGroupSingleThread() { |
| 119 | + // A compressed block holding an uncompressed column group must fall back to the dense tsmm path. |
| 120 | + execTSMM(uncompressedGroup(150, 12, 5), 1); |
| 121 | + } |
| 122 | + |
| 123 | + @Test |
| 124 | + public void tsmmUncompressedGroupParallel() { |
| 125 | + execTSMM(uncompressedGroup(150, 12, 6), 4); |
| 126 | + } |
| 127 | + |
| 128 | + @Test |
| 129 | + public void tsmmEmpty() { |
| 130 | + CompressedMatrixBlock cmb = empty(100, 13); |
| 131 | + MatrixBlock ret = CLALibTSMM.leftMultByTransposeSelf(cmb, 1); |
| 132 | + assertEquals(13, ret.getNumRows()); |
| 133 | + assertEquals(13, ret.getNumColumns()); |
| 134 | + assertTrue("empty input must produce an empty result", ret.isEmptyBlock(false)); |
| 135 | + } |
| 136 | + |
| 137 | + @Test |
| 138 | + public void tsmmRetReused() { |
| 139 | + // A non-null ret must be reset and reused, producing the same result as a fresh allocation. |
| 140 | + CompressedMatrixBlock cmb = singleDDC(120, 36, 5, 7); |
| 141 | + MatrixBlock preAllocated = new MatrixBlock(3, 3, 99.0); |
| 142 | + preAllocated.allocateDenseBlock(); |
| 143 | + MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, preAllocated, 4); |
| 144 | + MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb) |
| 145 | + .transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, 4); |
| 146 | + TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0); |
| 147 | + } |
| 148 | + |
| 149 | + @Test |
| 150 | + public void tsmmRetNull() { |
| 151 | + // Explicitly exercise the null-ret allocation branch of the helper. |
| 152 | + CompressedMatrixBlock cmb = singleDDC(120, 36, 5, 8); |
| 153 | + MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, null, 1); |
| 154 | + MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb) |
| 155 | + .transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, 1); |
| 156 | + TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0); |
| 157 | + } |
| 158 | + |
| 159 | + private static void execTSMM(CompressedMatrixBlock cmb, int k) { |
| 160 | + try { |
| 161 | + MatrixBlock cRet = CLALibTSMM.leftMultByTransposeSelf(cmb, k); |
| 162 | + MatrixBlock uRet = CompressedMatrixBlock.getUncompressed(cmb) |
| 163 | + .transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.LEFT, k); |
| 164 | + assertEquals(cmb.getNumColumns(), cRet.getNumRows()); |
| 165 | + assertEquals(cmb.getNumColumns(), cRet.getNumColumns()); |
| 166 | + TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0); |
| 167 | + } |
| 168 | + catch(Exception e) { |
| 169 | + e.printStackTrace(); |
| 170 | + fail(e.getMessage()); |
| 171 | + } |
| 172 | + } |
| 173 | + |
| 174 | + @Test |
| 175 | + public void mmChainFastPathSingleThread() { |
| 176 | + // 40 columns, single column group -> XtXv fast path. |
| 177 | + execMMChain(singleDDC(200, 40, 6, 11), 1); |
| 178 | + } |
| 179 | + |
| 180 | + @Test |
| 181 | + public void mmChainFastPathParallel() { |
| 182 | + execMMChain(singleDDC(200, 40, 6, 12), 4); |
| 183 | + } |
| 184 | + |
| 185 | + @Test |
| 186 | + public void mmChainFastPathFewGroups() { |
| 187 | + // Two column groups (< 5) over 40 columns still triggers the fast path. |
| 188 | + execMMChain(twoGroups(200, 40, 13), 4); |
| 189 | + } |
| 190 | + |
| 191 | + @Test |
| 192 | + public void mmChainRegularPathNarrow() { |
| 193 | + // Only 20 columns -> below the width threshold, exercises the regular (non fast) chain path. |
| 194 | + execMMChain(singleDDC(200, 20, 6, 14), 4); |
| 195 | + } |
| 196 | + |
| 197 | + private static CompressedMatrixBlock twoGroups(int nRow, int nCol, int seed) { |
| 198 | + final int half = nCol / 2; |
| 199 | + Random r = new Random(seed); |
| 200 | + List<AColGroup> gs = new ArrayList<>(); |
| 201 | + gs.add(ddcGroup(nRow, ColIndexFactory.create(0, half), 5, r)); |
| 202 | + gs.add(ddcGroup(nRow, ColIndexFactory.create(half, nCol), 5, r)); |
| 203 | + CompressedMatrixBlock cmb = new CompressedMatrixBlock(nRow, nCol); |
| 204 | + cmb.allocateColGroupList(gs); |
| 205 | + cmb.recomputeNonZeros(); |
| 206 | + return cmb; |
| 207 | + } |
| 208 | + |
| 209 | + private static AColGroup ddcGroup(int nRow, org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex cols, |
| 210 | + int nVal, Random r) { |
| 211 | + int nCol = cols.size(); |
| 212 | + double[] dictValues = new double[nVal * nCol]; |
| 213 | + for(int i = 0; i < dictValues.length; i++) |
| 214 | + dictValues[i] = Math.round(r.nextDouble() * 20 - 10); |
| 215 | + IDictionary dict = Dictionary.create(dictValues); |
| 216 | + AMapToData data = MappingTestUtil.createRandomMap(nRow, nVal, r); |
| 217 | + return ColGroupDDC.create(cols, dict, data, null); |
| 218 | + } |
| 219 | + |
| 220 | + @Test |
| 221 | + public void mmChainWideRecompressedDDC() { |
| 222 | + // Mirrors the e2e CompressedTestBase#testMatrixMultChainXtXvWide flow: tile a narrow matrix until it |
| 223 | + // exceeds the 30-column fast-path threshold, recompress it, then validate XtXv against uncompressed. |
| 224 | + execMMChainWide(TestUtils.round(TestUtils.generateTestMatrixBlock(300, 4, -10, 10, 1.0, 21)), 1); |
| 225 | + } |
| 226 | + |
| 227 | + @Test |
| 228 | + public void mmChainWideRecompressedSparse() { |
| 229 | + execMMChainWide(TestUtils.round(TestUtils.generateTestMatrixBlock(300, 3, 1, 5, 0.2, 22)), 4); |
| 230 | + } |
| 231 | + |
| 232 | + private static void execMMChainWide(MatrixBlock base, int k) { |
| 233 | + try { |
| 234 | + final int nCol = base.getNumColumns(); |
| 235 | + final int reps = (int) Math.ceil(31.0 / nCol) + 1; |
| 236 | + MatrixBlock wide = base; |
| 237 | + for(int i = 1; i < reps; i++) |
| 238 | + wide = wide.append(base, new MatrixBlock(), true); |
| 239 | + assertTrue("widened matrix must exceed the fast-path threshold", wide.getNumColumns() > 30); |
| 240 | + |
| 241 | + MatrixBlock wideC = CompressedMatrixBlockFactory.compress(wide, k).getLeft(); |
| 242 | + assertTrue("tiled matrix should compress", wideC instanceof CompressedMatrixBlock); |
| 243 | + |
| 244 | + MatrixBlock v = TestUtils.generateTestMatrixBlock(wide.getNumColumns(), 1, 0.9, 1.5, 1.0, 3); |
| 245 | + MatrixBlock uRet = wide.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); |
| 246 | + MatrixBlock cRet = wideC.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); |
| 247 | + TestUtils.compareMatrices(uRet, cRet, 1e-6, "wide recompressed mm-chain result mismatch"); |
| 248 | + } |
| 249 | + catch(Exception e) { |
| 250 | + e.printStackTrace(); |
| 251 | + fail(e.getMessage()); |
| 252 | + } |
| 253 | + } |
| 254 | + |
| 255 | + private static void execMMChain(CompressedMatrixBlock cmb, int k) { |
| 256 | + try { |
| 257 | + final int cols = cmb.getNumColumns(); |
| 258 | + MatrixBlock v = TestUtils.round(TestUtils.generateTestMatrixBlock(cols, 1, -3, 3, 1.0, 42)); |
| 259 | + MatrixBlock uncompressed = CompressedMatrixBlock.getUncompressed(cmb); |
| 260 | + |
| 261 | + MatrixBlock cRet = cmb.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); |
| 262 | + MatrixBlock uRet = uncompressed.chainMatrixMultOperations(v, null, new MatrixBlock(), ChainType.XtXv, k); |
| 263 | + |
| 264 | + assertEquals(cols, cRet.getNumRows()); |
| 265 | + assertEquals(1, cRet.getNumColumns()); |
| 266 | + TestUtils.compareMatrices(uRet, cRet, 1e-6, "mm-chain result mismatch"); |
| 267 | + } |
| 268 | + catch(Exception e) { |
| 269 | + e.printStackTrace(); |
| 270 | + fail(e.getMessage()); |
| 271 | + } |
| 272 | + } |
| 273 | +} |
0 commit comments