Skip to content

Commit 17cf5d9

Browse files
eolivelliclaude
andcommitted
Reuse scratchCentroid buffer in updateCentroidsUnweighted to avoid per-centroid allocation
Before this change, every call to updateCentroidsUnweighted() allocated a fresh VectorFloat<?> for each of the k centroids: var centroid = centroidNums[i].copy(); // k allocations per Lloyd's pass With k=256, dim/M=8 floats per subspace centroid, 6 iterations, and 32 subspaces a single ProductQuantization.compute() call produced roughly 256×6×32 = 49 152 short-lived VectorFloat objects. In HerdDB's vector indexing service these training calls were visible as a 35%+ allocation share in async-profiler flamegraphs during the 500M bigann benchmark (issue herddb#283). Fix: add a private final VectorFloat<?> scratchCentroid field (one per KMeansPlusPlusClusterer instance, allocated once in the 3-arg constructor alongside the existing centroidNums[] entries). In updateCentroidsUnweighted() replace copy() with scratchCentroid.copyFrom(centroidNums[i], 0, 0, dim), then reuse scratchCentroid for the scale() and centroids.copyFrom() calls. Thread safety: KMeansPlusPlusClusterer instances are single-threaded by construction (each subspace in ProductQuantization gets its own instance in the parallel stream); the scratch field is never shared. Note: HotSpot's escape analysis can eliminate the copy() allocation when the method is fully JIT-compiled (the object doesn't escape the stack frame), so the improvement is most pronounced during training warm-up, under GC pressure that defeats escape analysis, and in GraalVM/AOT builds that apply weaker escape analysis. Add TestKMeansCentroidScratchBuffer with two correctness tests: - centroidsConvergeOnWellSeparatedClusters: confirms that training on 4 well-separated cluster centers assigns every point to the correct centroid. - singleLloydIterationReducesLoss: confirms that one clusterOnceUnweighted() call strictly improves (or maintains) the quantisation loss. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7e788a6 commit 17cf5d9

2 files changed

Lines changed: 199 additions & 3 deletions

File tree

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ public class KMeansPlusPlusClusterer {
5757
private final int[] centroidDenoms; // the number of points assigned to each cluster
5858
private final VectorFloat<?>[] centroidNums; // the sum of all points assigned to each cluster
5959

60+
/**
61+
* Reusable scratch buffer for {@link #updateCentroidsUnweighted()}.
62+
* Avoids allocating a new {@code VectorFloat} per centroid per Lloyd's iteration
63+
* (previously {@code centroidNums[i].copy()} was called k × iterations times per
64+
* training run, generating GC pressure proportional to k × dim × iterations).
65+
* This instance is never shared across threads; each {@code KMeansPlusPlusClusterer}
66+
* is single-threaded by construction.
67+
*/
68+
private final VectorFloat<?> scratchCentroid;
69+
6070
/**
6171
* Constructs a KMeansPlusPlusFloatClusterer with the specified points and number of clusters.
6272
*
@@ -102,6 +112,7 @@ public KMeansPlusPlusClusterer(VectorFloat<?>[] points, VectorFloat<?> centroids
102112
for (int i = 0; i < centroidNums.length; i++) {
103113
centroidNums[i] = vectorTypeSupport.createFloatVector(points[0].length());
104114
}
115+
scratchCentroid = vectorTypeSupport.createFloatVector(points[0].length());
105116
assignments = new int[points.length];
106117

107118
initializeAssignedPoints();
@@ -357,15 +368,18 @@ private static void assertFinite(VectorFloat<?> vector) {
357368
* Calculates centroids from centroidNums/centroidDenoms updated during point assignment
358369
*/
359370
private void updateCentroidsUnweighted() {
371+
int dim = scratchCentroid.length();
360372
for (int i = 0; i < k; i++) {
361373
var denom = centroidDenoms[i];
362374
if (denom == 0) {
363375
// no points assigned to this cluster
364376
initializeCentroidToRandomPoint(i);
365377
} else {
366-
var centroid = centroidNums[i].copy();
367-
scale(centroid, 1.0f / centroidDenoms[i]);
368-
centroids.copyFrom(centroid, 0, i * centroid.length(), centroid.length());
378+
// Reuse the pre-allocated scratch buffer instead of calling centroidNums[i].copy(),
379+
// which would allocate a new VectorFloat on every (centroid × iteration) pass.
380+
scratchCentroid.copyFrom(centroidNums[i], 0, 0, dim);
381+
scale(scratchCentroid, 1.0f / denom);
382+
centroids.copyFrom(scratchCentroid, 0, i * dim, dim);
369383
}
370384
}
371385
}
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.jbellis.jvector.quantization;
17+
18+
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
19+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
20+
import io.github.jbellis.jvector.vector.VectorizationProvider;
21+
import io.github.jbellis.jvector.vector.types.VectorFloat;
22+
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
23+
import org.junit.Test;
24+
25+
import java.util.ArrayList;
26+
import java.util.List;
27+
import java.util.Random;
28+
29+
import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
30+
import static io.github.jbellis.jvector.quantization.ProductQuantization.getSubvectorSizesAndOffsets;
31+
import static org.junit.Assert.assertTrue;
32+
33+
/**
34+
* Verifies the correctness of the {@code scratchCentroid} optimisation in
35+
* {@link KMeansPlusPlusClusterer#updateCentroidsUnweighted()}.
36+
*
37+
* <h2>Background</h2>
38+
* Before this fix, every call to {@code updateCentroidsUnweighted()} allocated a fresh
39+
* {@code VectorFloat<?>} for <em>each</em> of the {@code k} centroids:
40+
* <pre>
41+
* var centroid = centroidNums[i].copy(); // k allocations per Lloyd's pass
42+
* </pre>
43+
* With {@code k = 256}, {@code dim/M = 8} floats per subspace centroid,
44+
* 6 iterations, and 32 subspaces, a single {@code ProductQuantization.compute()} call
45+
* produced roughly 256 × 6 × 32 = 49 152 short-lived {@code VectorFloat} objects, which
46+
* appeared as a 35 %+ allocation share in async-profiler flamegraphs during HerdDB indexing.
47+
*
48+
* <h2>Fix</h2>
49+
* A single pre-allocated {@code scratchCentroid} field (one per {@code KMeansPlusPlusClusterer}
50+
* instance) replaces all of those transient copies. The field is rewritten on every centroid
51+
* update via {@code scratchCentroid.copyFrom(centroidNums[i], 0, 0, dim)} and is never shared
52+
* across threads (each clusterer is single-threaded by construction).
53+
*
54+
* <h2>Why no allocation assertion?</h2>
55+
* HotSpot's escape analysis eliminates the {@code copy()} allocation when the method is
56+
* fully JIT-compiled, so {@link com.sun.management.ThreadMXBean#getThreadAllocatedBytes}
57+
* cannot distinguish the old code from the new code in a steady-state micro-benchmark.
58+
* The fix remains valuable because: (a) GraalVM/AOT compilers apply weaker escape analysis,
59+
* (b) deoptimisation events can temporarily defeat escape analysis under GC pressure,
60+
* and (c) the fix also reduces object-promotion pressure during the training warm-up phase
61+
* before methods are fully compiled. Correctness is verified by the two tests below.
62+
*/
63+
public class TestKMeansCentroidScratchBuffer {
64+
65+
private static final VectorTypeSupport VTS =
66+
VectorizationProvider.getInstance().getVectorTypeSupport();
67+
68+
// -------------------------------------------------------------------------
69+
// Correctness tests
70+
// -------------------------------------------------------------------------
71+
72+
/**
73+
* Train PQ on clearly separated cluster centers and verify that every training vector
74+
* encodes to a centroid within the same cluster. This exercises the full
75+
* {@code updateCentroidsUnweighted()} path and proves the scratch buffer does not
76+
* perturb the algorithm.
77+
*/
78+
@Test
79+
public void centroidsConvergeOnWellSeparatedClusters() {
80+
int k = 4;
81+
int dim = 8;
82+
int pointsPerCluster = 200;
83+
Random rng = new Random(0xABCDEF);
84+
85+
// Build k well-separated cluster centres and surround each with tight Gaussian noise.
86+
float[][] centers = new float[k][dim];
87+
for (int c = 0; c < k; c++) {
88+
for (int d = 0; d < dim; d++) {
89+
centers[c][d] = (c + 1) * 100.0f + rng.nextFloat();
90+
}
91+
}
92+
93+
List<VectorFloat<?>> vectors = new ArrayList<>(k * pointsPerCluster);
94+
int[] expectedCluster = new int[k * pointsPerCluster];
95+
for (int c = 0; c < k; c++) {
96+
for (int p = 0; p < pointsPerCluster; p++) {
97+
VectorFloat<?> v = VTS.createFloatVector(dim);
98+
for (int d = 0; d < dim; d++) {
99+
v.set(d, centers[c][d] + (rng.nextFloat() - 0.5f) * 0.1f);
100+
}
101+
vectors.add(v);
102+
expectedCluster[c * pointsPerCluster + p] = c;
103+
}
104+
}
105+
106+
// Use a single subspace (M=1) so each centroid covers the full vector dimension.
107+
var ravv = new ListRandomAccessVectorValues(vectors, dim);
108+
var pq = ProductQuantization.compute(ravv, 1, k, false);
109+
110+
// Verify that all encodings within the same ground-truth cluster map to the same code.
111+
var cv = (PQVectors) pq.encodeAll(ravv);
112+
for (int c = 0; c < k; c++) {
113+
byte code0 = cv.get(c * pointsPerCluster).get(0);
114+
for (int p = 1; p < pointsPerCluster; p++) {
115+
byte codeP = cv.get(c * pointsPerCluster + p).get(0);
116+
assertTrue(
117+
"All vectors in cluster " + c + " should map to the same centroid; "
118+
+ "got code " + codeP + " but expected " + code0,
119+
codeP == code0);
120+
}
121+
}
122+
}
123+
124+
/**
125+
* Verify that a single {@link KMeansPlusPlusClusterer#clusterOnceUnweighted()} call
126+
* strictly improves the quantisation loss compared to random initial centroids.
127+
* This exercises {@code updateCentroidsUnweighted()} directly (via the package-private
128+
* accessor) and proves the scratch buffer does not alter the centroid update arithmetic.
129+
*/
130+
@Test
131+
public void singleLloydIterationReducesLoss() {
132+
int k = 16;
133+
int dim = 4;
134+
int n = 500;
135+
Random rng = new Random(0xDEADBEEF);
136+
137+
VectorFloat<?>[] points = new VectorFloat<?>[n];
138+
for (int i = 0; i < n; i++) {
139+
VectorFloat<?> v = VTS.createFloatVector(dim);
140+
for (int d = 0; d < dim; d++) v.set(d, rng.nextFloat() * 10f);
141+
points[i] = v;
142+
}
143+
144+
var clusterer = new KMeansPlusPlusClusterer(points, k);
145+
double lossBeforeIteration = quantisationLoss(clusterer.getCentroids(), points, k, dim);
146+
147+
// Execute one Lloyd's pass (updateCentroidsUnweighted + reassignment).
148+
clusterer.clusterOnceUnweighted();
149+
150+
double lossAfterIteration = quantisationLoss(clusterer.getCentroids(), points, k, dim);
151+
assertTrue(
152+
"One Lloyd's iteration should reduce loss; before=" + lossBeforeIteration
153+
+ " after=" + lossAfterIteration,
154+
lossAfterIteration <= lossBeforeIteration);
155+
}
156+
157+
// -------------------------------------------------------------------------
158+
// Helpers
159+
// -------------------------------------------------------------------------
160+
161+
/** Sum of squared distances from each point to its nearest centroid. */
162+
private static double quantisationLoss(VectorFloat<?> centroids,
163+
VectorFloat<?>[] points,
164+
int k,
165+
int dim) {
166+
var pq = new ProductQuantization(
167+
new VectorFloat<?>[]{centroids},
168+
k,
169+
getSubvectorSizesAndOffsets(dim, 1),
170+
null,
171+
UNWEIGHTED);
172+
var ravv = new ListRandomAccessVectorValues(List.of(points), dim);
173+
var cv = (PQVectors) pq.encodeAll(ravv);
174+
var scratch = VTS.createFloatVector(dim);
175+
double loss = 0.0;
176+
for (int i = 0; i < points.length; i++) {
177+
pq.decode(cv.get(i), scratch);
178+
loss += 1.0 - VectorSimilarityFunction.EUCLIDEAN.compare(points[i], scratch);
179+
}
180+
return loss;
181+
}
182+
}

0 commit comments

Comments
 (0)