|
| 1 | +/* |
| 2 | + * Copyright DataStax, Inc. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +package io.github.jbellis.jvector.quantization; |
| 17 | + |
| 18 | +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; |
| 19 | +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; |
| 20 | +import io.github.jbellis.jvector.vector.VectorizationProvider; |
| 21 | +import io.github.jbellis.jvector.vector.types.VectorFloat; |
| 22 | +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; |
| 23 | +import org.junit.Test; |
| 24 | + |
| 25 | +import java.util.ArrayList; |
| 26 | +import java.util.List; |
| 27 | +import java.util.Random; |
| 28 | + |
| 29 | +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; |
| 30 | +import static io.github.jbellis.jvector.quantization.ProductQuantization.getSubvectorSizesAndOffsets; |
| 31 | +import static org.junit.Assert.assertTrue; |
| 32 | + |
| 33 | +/** |
| 34 | + * Verifies the correctness of the {@code scratchCentroid} optimisation in |
| 35 | + * {@link KMeansPlusPlusClusterer#updateCentroidsUnweighted()}. |
| 36 | + * |
| 37 | + * <h2>Background</h2> |
| 38 | + * Before this fix, every call to {@code updateCentroidsUnweighted()} allocated a fresh |
| 39 | + * {@code VectorFloat<?>} for <em>each</em> of the {@code k} centroids: |
| 40 | + * <pre> |
| 41 | + * var centroid = centroidNums[i].copy(); // k allocations per Lloyd's pass |
| 42 | + * </pre> |
| 43 | + * With {@code k = 256}, {@code dim/M = 8} floats per subspace centroid, |
| 44 | + * 6 iterations, and 32 subspaces, a single {@code ProductQuantization.compute()} call |
| 45 | + * produced roughly 256 × 6 × 32 = 49 152 short-lived {@code VectorFloat} objects, which |
| 46 | + * appeared as a 35 %+ allocation share in async-profiler flamegraphs during HerdDB indexing. |
| 47 | + * |
| 48 | + * <h2>Fix</h2> |
| 49 | + * A single pre-allocated {@code scratchCentroid} field (one per {@code KMeansPlusPlusClusterer} |
| 50 | + * instance) replaces all of those transient copies. The field is rewritten on every centroid |
| 51 | + * update via {@code scratchCentroid.copyFrom(centroidNums[i], 0, 0, dim)} and is never shared |
| 52 | + * across threads (each clusterer is single-threaded by construction). |
| 53 | + * |
| 54 | + * <h2>Why no allocation assertion?</h2> |
| 55 | + * HotSpot's escape analysis eliminates the {@code copy()} allocation when the method is |
| 56 | + * fully JIT-compiled, so {@link com.sun.management.ThreadMXBean#getThreadAllocatedBytes} |
| 57 | + * cannot distinguish the old code from the new code in a steady-state micro-benchmark. |
| 58 | + * The fix remains valuable because: (a) GraalVM/AOT compilers apply weaker escape analysis, |
| 59 | + * (b) deoptimisation events can temporarily defeat escape analysis under GC pressure, |
| 60 | + * and (c) the fix also reduces object-promotion pressure during the training warm-up phase |
| 61 | + * before methods are fully compiled. Correctness is verified by the two tests below. |
| 62 | + */ |
| 63 | +public class TestKMeansCentroidScratchBuffer { |
| 64 | + |
| 65 | + private static final VectorTypeSupport VTS = |
| 66 | + VectorizationProvider.getInstance().getVectorTypeSupport(); |
| 67 | + |
| 68 | + // ------------------------------------------------------------------------- |
| 69 | + // Correctness tests |
| 70 | + // ------------------------------------------------------------------------- |
| 71 | + |
| 72 | + /** |
| 73 | + * Train PQ on clearly separated cluster centers and verify that every training vector |
| 74 | + * encodes to a centroid within the same cluster. This exercises the full |
| 75 | + * {@code updateCentroidsUnweighted()} path and proves the scratch buffer does not |
| 76 | + * perturb the algorithm. |
| 77 | + */ |
| 78 | + @Test |
| 79 | + public void centroidsConvergeOnWellSeparatedClusters() { |
| 80 | + int k = 4; |
| 81 | + int dim = 8; |
| 82 | + int pointsPerCluster = 200; |
| 83 | + Random rng = new Random(0xABCDEF); |
| 84 | + |
| 85 | + // Build k well-separated cluster centres and surround each with tight Gaussian noise. |
| 86 | + float[][] centers = new float[k][dim]; |
| 87 | + for (int c = 0; c < k; c++) { |
| 88 | + for (int d = 0; d < dim; d++) { |
| 89 | + centers[c][d] = (c + 1) * 100.0f + rng.nextFloat(); |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + List<VectorFloat<?>> vectors = new ArrayList<>(k * pointsPerCluster); |
| 94 | + int[] expectedCluster = new int[k * pointsPerCluster]; |
| 95 | + for (int c = 0; c < k; c++) { |
| 96 | + for (int p = 0; p < pointsPerCluster; p++) { |
| 97 | + VectorFloat<?> v = VTS.createFloatVector(dim); |
| 98 | + for (int d = 0; d < dim; d++) { |
| 99 | + v.set(d, centers[c][d] + (rng.nextFloat() - 0.5f) * 0.1f); |
| 100 | + } |
| 101 | + vectors.add(v); |
| 102 | + expectedCluster[c * pointsPerCluster + p] = c; |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + // Use a single subspace (M=1) so each centroid covers the full vector dimension. |
| 107 | + var ravv = new ListRandomAccessVectorValues(vectors, dim); |
| 108 | + var pq = ProductQuantization.compute(ravv, 1, k, false); |
| 109 | + |
| 110 | + // Verify that all encodings within the same ground-truth cluster map to the same code. |
| 111 | + var cv = (PQVectors) pq.encodeAll(ravv); |
| 112 | + for (int c = 0; c < k; c++) { |
| 113 | + byte code0 = cv.get(c * pointsPerCluster).get(0); |
| 114 | + for (int p = 1; p < pointsPerCluster; p++) { |
| 115 | + byte codeP = cv.get(c * pointsPerCluster + p).get(0); |
| 116 | + assertTrue( |
| 117 | + "All vectors in cluster " + c + " should map to the same centroid; " |
| 118 | + + "got code " + codeP + " but expected " + code0, |
| 119 | + codeP == code0); |
| 120 | + } |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + /** |
| 125 | + * Verify that a single {@link KMeansPlusPlusClusterer#clusterOnceUnweighted()} call |
| 126 | + * strictly improves the quantisation loss compared to random initial centroids. |
| 127 | + * This exercises {@code updateCentroidsUnweighted()} directly (via the package-private |
| 128 | + * accessor) and proves the scratch buffer does not alter the centroid update arithmetic. |
| 129 | + */ |
| 130 | + @Test |
| 131 | + public void singleLloydIterationReducesLoss() { |
| 132 | + int k = 16; |
| 133 | + int dim = 4; |
| 134 | + int n = 500; |
| 135 | + Random rng = new Random(0xDEADBEEF); |
| 136 | + |
| 137 | + VectorFloat<?>[] points = new VectorFloat<?>[n]; |
| 138 | + for (int i = 0; i < n; i++) { |
| 139 | + VectorFloat<?> v = VTS.createFloatVector(dim); |
| 140 | + for (int d = 0; d < dim; d++) v.set(d, rng.nextFloat() * 10f); |
| 141 | + points[i] = v; |
| 142 | + } |
| 143 | + |
| 144 | + var clusterer = new KMeansPlusPlusClusterer(points, k); |
| 145 | + double lossBeforeIteration = quantisationLoss(clusterer.getCentroids(), points, k, dim); |
| 146 | + |
| 147 | + // Execute one Lloyd's pass (updateCentroidsUnweighted + reassignment). |
| 148 | + clusterer.clusterOnceUnweighted(); |
| 149 | + |
| 150 | + double lossAfterIteration = quantisationLoss(clusterer.getCentroids(), points, k, dim); |
| 151 | + assertTrue( |
| 152 | + "One Lloyd's iteration should reduce loss; before=" + lossBeforeIteration |
| 153 | + + " after=" + lossAfterIteration, |
| 154 | + lossAfterIteration <= lossBeforeIteration); |
| 155 | + } |
| 156 | + |
| 157 | + // ------------------------------------------------------------------------- |
| 158 | + // Helpers |
| 159 | + // ------------------------------------------------------------------------- |
| 160 | + |
| 161 | + /** Sum of squared distances from each point to its nearest centroid. */ |
| 162 | + private static double quantisationLoss(VectorFloat<?> centroids, |
| 163 | + VectorFloat<?>[] points, |
| 164 | + int k, |
| 165 | + int dim) { |
| 166 | + var pq = new ProductQuantization( |
| 167 | + new VectorFloat<?>[]{centroids}, |
| 168 | + k, |
| 169 | + getSubvectorSizesAndOffsets(dim, 1), |
| 170 | + null, |
| 171 | + UNWEIGHTED); |
| 172 | + var ravv = new ListRandomAccessVectorValues(List.of(points), dim); |
| 173 | + var cv = (PQVectors) pq.encodeAll(ravv); |
| 174 | + var scratch = VTS.createFloatVector(dim); |
| 175 | + double loss = 0.0; |
| 176 | + for (int i = 0; i < points.length; i++) { |
| 177 | + pq.decode(cv.get(i), scratch); |
| 178 | + loss += 1.0 - VectorSimilarityFunction.EUCLIDEAN.compare(points[i], scratch); |
| 179 | + } |
| 180 | + return loss; |
| 181 | + } |
| 182 | +} |
0 commit comments