diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8f892031fc11..e7a87b6b45c0 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -76,6 +76,11 @@ API Changes New Features --------------------- +* GITHUB#XXXXX: TurboQuant vector quantization codec — data-oblivious rotation-based quantization + with near-optimal distortion rates (Zandieh et al., ICLR 2026). Supports 2/3/4/8 bits per + coordinate, dimensions up to 16384, and byte-copy merge via global rotation seed. Located in + lucene/sandbox module as TurboQuantHnswVectorsFormat. + * GITHUB#15505: Upgrade snowball to 2d2e312df56f2ede014a4ffb3e91e6dea43c24be. New stemmer: PolishStemmer (and PolishSnowballAnalyzer in the stempel package) (Justas Sakalauskas, Dawid Weiss) diff --git a/lucene/benchmark-jmh/build.gradle b/lucene/benchmark-jmh/build.gradle index 6f874e410b9b..78018c95916d 100644 --- a/lucene/benchmark-jmh/build.gradle +++ b/lucene/benchmark-jmh/build.gradle @@ -19,6 +19,7 @@ description = 'Lucene JMH micro-benchmarking module' dependencies { moduleImplementation project(':lucene:core') + moduleImplementation project(':lucene:codecs') moduleImplementation project(':lucene:expressions') moduleImplementation project(':lucene:sandbox') moduleTestImplementation project(':lucene:test-framework') diff --git a/lucene/benchmark-jmh/src/java/module-info.java b/lucene/benchmark-jmh/src/java/module-info.java index 0a283644a35c..1999ed990e2d 100644 --- a/lucene/benchmark-jmh/src/java/module-info.java +++ b/lucene/benchmark-jmh/src/java/module-info.java @@ -23,6 +23,7 @@ requires jmh.core; requires jdk.unsupported; requires org.apache.lucene.core; + requires org.apache.lucene.codecs; requires org.apache.lucene.expressions; requires org.apache.lucene.sandbox; diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/TurboQuantBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/TurboQuantBenchmark.java new file mode 100644 index 000000000000..f34fb076523d --- 
/dev/null +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/TurboQuantBenchmark.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.benchmark.jmh; + +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.apache.lucene.sandbox.codecs.turboquant.BetaCodebook; +import org.apache.lucene.sandbox.codecs.turboquant.HadamardRotation; +import org.apache.lucene.sandbox.codecs.turboquant.TurboQuantBitPacker; +import org.apache.lucene.sandbox.codecs.turboquant.TurboQuantEncoding; +import org.apache.lucene.sandbox.codecs.turboquant.TurboQuantScoringUtil; +import org.openjdk.jmh.annotations.*; + +/** JMH benchmarks for TurboQuant core operations. 
*/ +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +@State(Scope.Thread) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +@Fork(1) +public class TurboQuantBenchmark { + + @Param({"4096"}) + int dim; + + @Param({"4"}) + int bits; + + private float[] vector; + private float[] rotated; + private float[] query; + private byte[] indices; + private byte[] packed; + private float[] centroids; + private HadamardRotation rotation; + + @Setup + public void setup() { + Random rng = new Random(42); + TurboQuantEncoding enc = + TurboQuantEncoding.fromWireNumber( + switch (bits) { + case 2 -> 0; + case 3 -> 1; + case 4 -> 2; + case 8 -> 3; + default -> throw new IllegalArgumentException(); + }) + .orElseThrow(); + + vector = new float[dim]; + float norm = 0; + for (int i = 0; i < dim; i++) { + vector[i] = (float) rng.nextGaussian(); + norm += vector[i] * vector[i]; + } + norm = (float) Math.sqrt(norm); + for (int i = 0; i < dim; i++) vector[i] /= norm; + + rotation = HadamardRotation.create(dim, 12345L); + rotated = new float[dim]; + rotation.rotate(vector, rotated); + + centroids = BetaCodebook.centroids(dim, bits); + float[] boundaries = BetaCodebook.boundaries(dim, bits); + + indices = new byte[dim]; + for (int i = 0; i < dim; i++) { + indices[i] = (byte) BetaCodebook.quantize(rotated[i], boundaries); + } + + packed = new byte[enc.getPackedByteLength(dim)]; + TurboQuantBitPacker.pack(indices, dim, bits, packed); + + query = new float[dim]; + for (int i = 0; i < dim; i++) query[i] = (float) rng.nextGaussian() / (float) Math.sqrt(dim); + } + + @Benchmark + public void hadamardRotation() { + rotation.rotate(vector, rotated); + } + + @Benchmark + public float dotProductScoring() { + return TurboQuantScoringUtil.dotProduct(query, packed, centroids, bits, dim); + } + + @Benchmark + public void quantize() { + float[] boundaries = BetaCodebook.boundaries(dim, bits); + for (int i = 0; i < dim; i++) { + indices[i] = (byte) 
BetaCodebook.quantize(rotated[i], boundaries); + } + TurboQuantBitPacker.pack(indices, dim, bits, packed); + } +} diff --git a/lucene/sandbox/src/java/module-info.java b/lucene/sandbox/src/java/module-info.java index ee9be3227de2..ab2c2488a96c 100644 --- a/lucene/sandbox/src/java/module-info.java +++ b/lucene/sandbox/src/java/module-info.java @@ -25,6 +25,7 @@ exports org.apache.lucene.sandbox.codecs.faiss; exports org.apache.lucene.sandbox.codecs.idversion; exports org.apache.lucene.sandbox.codecs.quantization; + exports org.apache.lucene.sandbox.codecs.turboquant; exports org.apache.lucene.sandbox.document; exports org.apache.lucene.sandbox.queries; exports org.apache.lucene.sandbox.search; @@ -41,5 +42,6 @@ provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.sandbox.codecs.idversion.IDVersionPostingsFormat; provides org.apache.lucene.codecs.KnnVectorsFormat with - org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat; + org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat, + org.apache.lucene.sandbox.codecs.turboquant.TurboQuantHnswVectorsFormat; } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/BetaCodebook.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/BetaCodebook.java new file mode 100644 index 000000000000..bd6bef5e533d --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/BetaCodebook.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +/** + * Precomputed Lloyd-Max optimal centroids for Gaussian-distributed coordinates. After random + * rotation, each coordinate of a unit vector in ℝᵈ follows approximately N(0, 1/d) for d ≥ 64. + * Canonical centroids are computed for N(0,1) and scaled by 1/√d at runtime. + */ +public final class BetaCodebook { + + private BetaCodebook() {} + + // Canonical Lloyd-Max optimal centroids for N(0,1), computed via Lloyd's algorithm. + // Symmetric around 0. Scaled by 1/√d at runtime. + + // @formatter:off + private static final float[] CENTROIDS_2 = { + -1.510418f, -0.452780f, 0.452780f, 1.510418f + }; + + private static final float[] CENTROIDS_3 = { + -2.151946f, -1.343909f, -0.756005f, -0.245094f, + 0.245094f, 0.756005f, 1.343909f, 2.151946f + }; + + private static final float[] CENTROIDS_4 = { + -2.732590f, -2.069017f, -1.618046f, -1.256231f, + -0.942340f, -0.656759f, -0.388048f, -0.128395f, + 0.128395f, 0.388048f, 0.656759f, 0.942340f, + 1.256231f, 1.618046f, 2.069017f, 2.732590f + }; + + private static final float[] CENTROIDS_8 = { + -4.035480f, -3.565625f, -3.268187f, -3.045475f, -2.865491f, -2.713551f, -2.581644f, -2.464895f, + -2.360107f, -2.265066f, -2.178166f, -2.098206f, -2.024257f, -1.955584f, -1.891595f, -1.831799f, + -1.775785f, -1.723203f, -1.673751f, -1.627164f, -1.583207f, -1.541672f, -1.502368f, -1.465126f, + -1.429789f, -1.396212f, -1.364264f, -1.333822f, -1.304772f, -1.277010f, -1.250438f, -1.224965f, + -1.200508f, -1.176989f, -1.154335f, -1.132480f, -1.111361f, -1.090923f, 
-1.071113f, -1.051883f, + -1.033188f, -1.014988f, -0.997247f, -0.979930f, -0.963006f, -0.946448f, -0.930229f, -0.914327f, + -0.898719f, -0.883388f, -0.868315f, -0.853484f, -0.838881f, -0.824492f, -0.810305f, -0.796310f, + -0.782495f, -0.768852f, -0.755371f, -0.742046f, -0.728869f, -0.715832f, -0.702931f, -0.690157f, + -0.677508f, -0.664976f, -0.652557f, -0.640248f, -0.628042f, -0.615938f, -0.603930f, -0.592014f, + -0.580189f, -0.568449f, -0.556793f, -0.545217f, -0.533718f, -0.522294f, -0.510941f, -0.499658f, + -0.488442f, -0.477290f, -0.466201f, -0.455172f, -0.444200f, -0.433285f, -0.422424f, -0.411614f, + -0.400855f, -0.390145f, -0.379481f, -0.368862f, -0.358286f, -0.347752f, -0.337259f, -0.326803f, + -0.316386f, -0.306003f, -0.295655f, -0.285340f, -0.275057f, -0.264803f, -0.254579f, -0.244382f, + -0.234211f, -0.224066f, -0.213944f, -0.203846f, -0.193768f, -0.183712f, -0.173674f, -0.163654f, + -0.153652f, -0.143665f, -0.133694f, -0.123736f, -0.113791f, -0.103857f, -0.093934f, -0.084021f, + -0.074116f, -0.064219f, -0.054328f, -0.044443f, -0.034562f, -0.024685f, -0.014810f, -0.004936f, + 0.004936f, 0.014810f, 0.024685f, 0.034562f, 0.044443f, 0.054328f, 0.064219f, 0.074116f, + 0.084021f, 0.093934f, 0.103857f, 0.113791f, 0.123736f, 0.133694f, 0.143665f, 0.153652f, + 0.163654f, 0.173674f, 0.183712f, 0.193768f, 0.203846f, 0.213944f, 0.224066f, 0.234211f, + 0.244382f, 0.254579f, 0.264803f, 0.275057f, 0.285340f, 0.295655f, 0.306003f, 0.316386f, + 0.326803f, 0.337259f, 0.347752f, 0.358286f, 0.368862f, 0.379481f, 0.390145f, 0.400855f, + 0.411614f, 0.422424f, 0.433285f, 0.444200f, 0.455172f, 0.466201f, 0.477290f, 0.488442f, + 0.499658f, 0.510941f, 0.522294f, 0.533718f, 0.545217f, 0.556793f, 0.568449f, 0.580189f, + 0.592014f, 0.603930f, 0.615938f, 0.628042f, 0.640248f, 0.652557f, 0.664976f, 0.677508f, + 0.690157f, 0.702931f, 0.715832f, 0.728869f, 0.742046f, 0.755371f, 0.768852f, 0.782495f, + 0.796310f, 0.810305f, 0.824492f, 0.838881f, 0.853484f, 0.868315f, 0.883388f, 
0.898719f, + 0.914327f, 0.930229f, 0.946448f, 0.963006f, 0.979930f, 0.997247f, 1.014988f, 1.033188f, + 1.051883f, 1.071113f, 1.090923f, 1.111361f, 1.132480f, 1.154335f, 1.176989f, 1.200508f, + 1.224965f, 1.250438f, 1.277010f, 1.304772f, 1.333822f, 1.364264f, 1.396212f, 1.429789f, + 1.465126f, 1.502368f, 1.541672f, 1.583207f, 1.627164f, 1.673751f, 1.723203f, 1.775785f, + 1.831799f, 1.891595f, 1.955584f, 2.024257f, 2.098206f, 2.178166f, 2.265066f, 2.360107f, + 2.464895f, 2.581644f, 2.713551f, 2.865491f, 3.045475f, 3.268187f, 3.565625f, 4.035480f + }; + // @formatter:on + + private static float[] canonicalCentroids(int b) { + return switch (b) { + case 2 -> CENTROIDS_2; + case 3 -> CENTROIDS_3; + case 4 -> CENTROIDS_4; + case 8 -> CENTROIDS_8; + default -> throw new IllegalArgumentException("Unsupported bit-width: " + b); + }; + } + + /** + * Returns 2^b centroid values scaled by 1/√d for the given dimension and bit-width. These are the + * reconstruction values for quantized coordinates after Hadamard rotation. + */ + public static float[] centroids(int d, int b) { + float[] canonical = canonicalCentroids(b); + float scale = (float) (1.0 / Math.sqrt(d)); + float[] result = new float[canonical.length]; + for (int i = 0; i < canonical.length; i++) { + result[i] = canonical[i] * scale; + } + return result; + } + + /** + * Returns 2^b + 1 decision boundary values scaled by 1/√d. Boundaries are midpoints between + * adjacent centroids, with first = -∞ (represented as {@code -Float.MAX_VALUE}) and last = +∞ + * (represented as {@code Float.MAX_VALUE}). + */ + public static float[] boundaries(int d, int b) { + float[] c = centroids(d, b); + float[] bd = new float[c.length + 1]; + bd[0] = -Float.MAX_VALUE; + bd[c.length] = Float.MAX_VALUE; + for (int i = 0; i < c.length - 1; i++) { + bd[i + 1] = (c[i] + c[i + 1]) / 2; + } + return bd; + } + + /** + * Quantizes a single coordinate value to the nearest centroid index using binary search on + * boundaries. 
+ */ + public static int quantize(float value, float[] boundaries) { + // Binary search for the bin + int lo = 1, hi = boundaries.length - 2; + while (lo <= hi) { + int mid = (lo + hi) >>> 1; + if (value < boundaries[mid]) { + hi = mid - 1; + } else { + lo = mid + 1; + } + } + return hi; + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/HadamardRotation.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/HadamardRotation.java new file mode 100644 index 000000000000..cb7fbc9c1a4a --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/HadamardRotation.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.util.Random; + +/** + * Randomized Hadamard rotation for TurboQuant. Applies Π = BlockHadamard · Permutation · SignFlip + * to decorrelate vector coordinates before scalar quantization. + * + *

<p>For power-of-2 dimensions (e.g., d=4096), this is a single Hadamard transform with random + * sign flips. For non-power-of-2 dimensions (e.g., d=768), a block-diagonal Hadamard is used with + * blocks determined by the binary decomposition of d, preceded by a random permutation. + * + * <p>
The rotation is orthogonal, so it preserves all distances and inner products. + */ +public final class HadamardRotation { + + private final int d; + private final int[] blockSizes; + private final int[] permutation; + private final int[] inversePermutation; + private final boolean[] signs; // true = negative + + private HadamardRotation(int d, int[] blockSizes, int[] permutation, boolean[] signs) { + this.d = d; + this.blockSizes = blockSizes; + this.permutation = permutation; + this.signs = signs; + this.inversePermutation = new int[d]; + for (int i = 0; i < d; i++) { + inversePermutation[permutation[i]] = i; + } + } + + /** + * Creates a HadamardRotation for the given dimension and seed. The rotation is deterministic for + * a given (d, seed) pair. + */ + public static HadamardRotation create(int d, long seed) { + if (d < 1) { + throw new IllegalArgumentException("Dimension must be >= 1, got " + d); + } + int[] blockSizes = decomposeBlocks(d); + Random rng = new Random(seed); + + // Fisher-Yates shuffle for random permutation + int[] permutation = new int[d]; + for (int i = 0; i < d; i++) { + permutation[i] = i; + } + for (int i = d - 1; i > 0; i--) { + int j = rng.nextInt(i + 1); + int tmp = permutation[i]; + permutation[i] = permutation[j]; + permutation[j] = tmp; + } + + // Random sign flips + boolean[] signs = new boolean[d]; + for (int i = 0; i < d; i++) { + signs[i] = rng.nextBoolean(); + } + + return new HadamardRotation(d, blockSizes, permutation, signs); + } + + /** + * Decomposes d into power-of-2 block sizes (binary representation). The blocks are returned in + * descending order and sum to d. 
+ */ + static int[] decomposeBlocks(int d) { + if (d < 1) { + throw new IllegalArgumentException("d must be >= 1, got " + d); + } + int bitCount = Integer.bitCount(d); + int[] blocks = new int[bitCount]; + int idx = 0; + for (int bit = 30; bit >= 0; bit--) { + if ((d & (1 << bit)) != 0) { + blocks[idx++] = 1 << bit; + } + } + return blocks; + } + + /** + * Applies the rotation: out = BlockHadamard(Permute(SignFlip(x))). The output is normalized so + * that ||out|| = ||x||. + */ + public void rotate(float[] x, float[] out) { + // Step 1: Sign flip + for (int i = 0; i < d; i++) { + out[i] = signs[i] ? -x[i] : x[i]; + } + + // Step 2: Permute (out[permutation[i]] = signFlipped[i], but we need to reorder) + // We need a temp buffer for the permutation step + float[] temp = new float[d]; + for (int i = 0; i < d; i++) { + temp[permutation[i]] = out[i]; + } + + // Step 3: Block-diagonal Hadamard + int offset = 0; + for (int blockSize : blockSizes) { + fwht(temp, offset, blockSize); + offset += blockSize; + } + + System.arraycopy(temp, 0, out, 0, d); + } + + /** + * Applies the inverse rotation: out = SignFlip⁻¹(Permute⁻¹(BlockHadamard⁻¹(y))). Since + * Hadamard is self-inverse (up to scaling) and we normalize, this exactly inverts rotate(). + */ + public void inverseRotate(float[] y, float[] out) { + // Step 1: Inverse block-diagonal Hadamard (same as forward — Hadamard is self-inverse) + float[] temp = new float[d]; + System.arraycopy(y, 0, temp, 0, d); + int offset = 0; + for (int blockSize : blockSizes) { + fwht(temp, offset, blockSize); + offset += blockSize; + } + + // Step 2: Inverse permute + for (int i = 0; i < d; i++) { + out[i] = temp[permutation[i]]; + } + + // Step 3: Inverse sign flip (same as forward — signs are self-inverse) + for (int i = 0; i < d; i++) { + if (signs[i]) { + out[i] = -out[i]; + } + } + } + + /** + * In-place Fast Walsh-Hadamard Transform on a contiguous block of the array. 
The transform is + * normalized by 1/√blockSize so that it preserves the L2 norm. + */ + private static void fwht(float[] data, int offset, int n) { + for (int len = 1; len < n; len <<= 1) { + for (int i = 0; i < n; i += len << 1) { + for (int j = 0; j < len; j++) { + int u = offset + i + j; + int v = u + len; + float a = data[u]; + float b = data[v]; + data[u] = a + b; + data[v] = a - b; + } + } + } + // Normalize by 1/√n to preserve L2 norm + float scale = (float) (1.0 / Math.sqrt(n)); + for (int i = 0; i < n; i++) { + data[offset + i] *= scale; + } + } + + /** Returns the dimension this rotation operates on. */ + public int dimension() { + return d; + } + + /** Returns the block sizes used in the block-diagonal Hadamard. */ + public int[] blockSizes() { + return blockSizes.clone(); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/OffHeapTurboQuantVectorValues.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/OffHeapTurboQuantVectorValues.java new file mode 100644 index 000000000000..4e66aed36adc --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/OffHeapTurboQuantVectorValues.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.quantization.BaseQuantizedByteVectorValues; + +/** + * Off-heap random access to TurboQuant quantized vectors stored in a mmap'd {@code .vetq} file. + * Each vector is stored as packed b-bit indices followed by a float32 norm. + */ +public class OffHeapTurboQuantVectorValues extends BaseQuantizedByteVectorValues { + + private final int dimension; + private final int size; + private final int bitsPerCoordinate; + private final int packedBytesPerVector; + private final int bytesPerVector; // packedBytes + 4 (float norm) + private final long dataOffset; + private final IndexInput data; + private final float[] centroids; + private final HadamardRotation rotation; + private final byte[] packedBuffer; + + /** Creates off-heap quantized vector values. 
*/ + public OffHeapTurboQuantVectorValues( + int dimension, + int size, + TurboQuantEncoding encoding, + long dataOffset, + IndexInput data, + float[] centroids, + HadamardRotation rotation) { + this.dimension = dimension; + this.size = size; + this.bitsPerCoordinate = encoding.bitsPerCoordinate; + this.packedBytesPerVector = encoding.getPackedByteLength(dimension); + this.bytesPerVector = packedBytesPerVector + Float.BYTES; + this.dataOffset = dataOffset; + this.data = data; + this.centroids = centroids; + this.rotation = rotation; + this.packedBuffer = new byte[packedBytesPerVector]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + long offset = dataOffset + (long) ord * bytesPerVector; + data.seek(offset); + byte[] buf = new byte[packedBytesPerVector]; + data.readBytes(buf, 0, packedBytesPerVector); + return buf; + } + + /** Returns the stored norm for the given ordinal. */ + public float getNorm(int ord) throws IOException { + long offset = dataOffset + (long) ord * bytesPerVector + packedBytesPerVector; + data.seek(offset); + return Float.intBitsToFloat(data.readInt()); + } + + /** Returns the precomputed centroids scaled for this field's dimension. */ + public float[] getCentroids() { + return centroids; + } + + /** Returns the Hadamard rotation for this field. */ + public HadamardRotation getRotation() { + return rotation; + } + + /** Returns the bits per coordinate for this encoding. 
*/ + public int getBitsPerCoordinate() { + return bitsPerCoordinate; + } + + @Override + public OffHeapTurboQuantVectorValues copy() throws IOException { + return new OffHeapTurboQuantVectorValues( + dimension, + size, + TurboQuantEncoding.fromWireNumber( + switch (bitsPerCoordinate) { + case 2 -> 0; + case 3 -> 1; + case 4 -> 2; + case 8 -> 3; + default -> throw new IllegalStateException(); + }) + .orElseThrow(), + dataOffset, + data.clone(), + centroids, + rotation); + } + + @Override + public VectorEncoding getEncoding() { + return VectorEncoding.BYTE; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public IndexInput getSlice() { + return data; + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantBitPacker.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantBitPacker.java new file mode 100644 index 000000000000..0b136b79f4b4 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantBitPacker.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.turboquant; + +/** + * Packs and unpacks b-bit quantization indices into byte arrays. Optimized paths for b=2 (4 per + * byte), b=3 (8 indices per 3 bytes), b=4 (2 per byte / nibble), and b=8 (1 per byte / no-op). + */ +public final class TurboQuantBitPacker { + + private TurboQuantBitPacker() {} + + /** Packs b-bit indices into a byte array. */ + public static void pack(byte[] indices, int d, int b, byte[] out) { + switch (b) { + case 2 -> pack2(indices, d, out); + case 3 -> pack3(indices, d, out); + case 4 -> pack4(indices, d, out); + case 8 -> System.arraycopy(indices, 0, out, 0, d); + default -> throw new IllegalArgumentException("Unsupported bit-width: " + b); + } + } + + /** Unpacks b-bit indices from a byte array. */ + public static void unpack(byte[] packed, int b, int d, byte[] out) { + switch (b) { + case 2 -> unpack2(packed, d, out); + case 3 -> unpack3(packed, d, out); + case 4 -> unpack4(packed, d, out); + case 8 -> System.arraycopy(packed, 0, out, 0, d); + default -> throw new IllegalArgumentException("Unsupported bit-width: " + b); + } + } + + // b=2: 4 indices per byte, MSB first + private static void pack2(byte[] indices, int d, byte[] out) { + int outIdx = 0; + int i = 0; + for (; i + 3 < d; i += 4) { + out[outIdx++] = + (byte) + (((indices[i] & 0x03) << 6) + | ((indices[i + 1] & 0x03) << 4) + | ((indices[i + 2] & 0x03) << 2) + | (indices[i + 3] & 0x03)); + } + // Handle remainder + if (i < d) { + int val = 0; + for (int shift = 6; i < d; i++, shift -= 2) { + val |= (indices[i] & 0x03) << shift; + } + out[outIdx] = (byte) val; + } + } + + private static void unpack2(byte[] packed, int d, byte[] out) { + int pIdx = 0; + int i = 0; + for (; i + 3 < d; i += 4) { + int b = packed[pIdx++] & 0xFF; + out[i] = (byte) ((b >> 6) & 0x03); + out[i + 1] = (byte) ((b >> 4) & 0x03); + out[i + 2] = (byte) ((b >> 2) & 0x03); + out[i + 3] = (byte) (b & 0x03); + } + if (i < d) { + int b = packed[pIdx] & 0xFF; + for (int 
shift = 6; i < d; i++, shift -= 2) { + out[i] = (byte) ((b >> shift) & 0x03); + } + } + } + + // b=3: 8 indices per 3 bytes + private static void pack3(byte[] indices, int d, byte[] out) { + int outIdx = 0; + int i = 0; + for (; i + 7 < d; i += 8) { + // Pack 8 3-bit values into 3 bytes (24 bits) + int bits = + ((indices[i] & 0x07) << 21) + | ((indices[i + 1] & 0x07) << 18) + | ((indices[i + 2] & 0x07) << 15) + | ((indices[i + 3] & 0x07) << 12) + | ((indices[i + 4] & 0x07) << 9) + | ((indices[i + 5] & 0x07) << 6) + | ((indices[i + 6] & 0x07) << 3) + | (indices[i + 7] & 0x07); + out[outIdx++] = (byte) (bits >> 16); + out[outIdx++] = (byte) (bits >> 8); + out[outIdx++] = (byte) bits; + } + // Handle remainder + if (i < d) { + int bits = 0; + int shift = 21; + for (int j = i; j < d; j++, shift -= 3) { + bits |= (indices[j] & 0x07) << shift; + } + out[outIdx++] = (byte) (bits >> 16); + if (outIdx < out.length) out[outIdx++] = (byte) (bits >> 8); + if (outIdx < out.length) out[outIdx] = (byte) bits; + } + } + + private static void unpack3(byte[] packed, int d, byte[] out) { + int pIdx = 0; + int i = 0; + for (; i + 7 < d; i += 8) { + int bits = + ((packed[pIdx] & 0xFF) << 16) + | ((packed[pIdx + 1] & 0xFF) << 8) + | (packed[pIdx + 2] & 0xFF); + pIdx += 3; + out[i] = (byte) ((bits >> 21) & 0x07); + out[i + 1] = (byte) ((bits >> 18) & 0x07); + out[i + 2] = (byte) ((bits >> 15) & 0x07); + out[i + 3] = (byte) ((bits >> 12) & 0x07); + out[i + 4] = (byte) ((bits >> 9) & 0x07); + out[i + 5] = (byte) ((bits >> 6) & 0x07); + out[i + 6] = (byte) ((bits >> 3) & 0x07); + out[i + 7] = (byte) (bits & 0x07); + } + if (i < d) { + int bits = + ((pIdx < packed.length ? packed[pIdx] & 0xFF : 0) << 16) + | ((pIdx + 1 < packed.length ? packed[pIdx + 1] & 0xFF : 0) << 8) + | (pIdx + 2 < packed.length ? 
packed[pIdx + 2] & 0xFF : 0); + for (int shift = 21; i < d; i++, shift -= 3) { + out[i] = (byte) ((bits >> shift) & 0x07); + } + } + } + + // b=4: 2 indices per byte (nibble packing) + private static void pack4(byte[] indices, int d, byte[] out) { + int outIdx = 0; + int i = 0; + for (; i + 1 < d; i += 2) { + out[outIdx++] = (byte) (((indices[i] & 0x0F) << 4) | (indices[i + 1] & 0x0F)); + } + if (i < d) { + out[outIdx] = (byte) ((indices[i] & 0x0F) << 4); + } + } + + private static void unpack4(byte[] packed, int d, byte[] out) { + int pIdx = 0; + int i = 0; + for (; i + 1 < d; i += 2) { + int b = packed[pIdx++] & 0xFF; + out[i] = (byte) ((b >> 4) & 0x0F); + out[i + 1] = (byte) (b & 0x0F); + } + if (i < d) { + out[i] = (byte) ((packed[pIdx] >> 4) & 0x0F); + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantEncoding.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantEncoding.java new file mode 100644 index 000000000000..8ab8b32aa57f --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantEncoding.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.util.Optional; + +/** + * Bit-width encoding for TurboQuant vector quantization. Each coordinate of a rotated vector is + * quantized to this many bits using precomputed Beta-distribution-optimal Lloyd-Max centroids. + */ +public enum TurboQuantEncoding { + /** 2 bits per coordinate, 16x compression, aggressive. */ + BITS_2(0, 2), + /** 3 bits per coordinate, ~10.7x compression. */ + BITS_3(1, 3), + /** 4 bits per coordinate, 8x compression, default, best recall/compression trade-off. */ + BITS_4(2, 4), + /** 8 bits per coordinate, 4x compression, near-lossless. */ + BITS_8(3, 8); + + private final int wireNumber; + + /** Number of bits used per coordinate. */ + public final int bitsPerCoordinate; + + TurboQuantEncoding(int wireNumber, int bitsPerCoordinate) { + this.wireNumber = wireNumber; + this.bitsPerCoordinate = bitsPerCoordinate; + } + + /** Returns the wire number used for serialization. */ + public int getWireNumber() { + return wireNumber; + } + + /** + * Returns the number of bytes required to store a packed quantized vector of the given + * dimensionality. + */ + public int getPackedByteLength(int d) { + return (d * bitsPerCoordinate + 7) / 8; + } + + /** + * Returns the number of dimensions rounded up so that the packed representation fills whole + * bytes. + */ + public int getDiscreteDimensions(int d) { + // Smallest d' >= d such that d' * bitsPerCoordinate is a multiple of 8 bits. The previous + // byte-rounding formula broke for 3-bit encoding (e.g. d=10 returned 10 -> 30 bits, which is + // not a whole number of bytes). The step is 8 / gcd(bitsPerCoordinate, 8): + // 4 for 2-bit, 8 for 3-bit, 2 for 4-bit, 1 for 8-bit. + int a = bitsPerCoordinate; + int r = 8; + while (r != 0) { + int t = a % r; + a = r; + r = t; + } + int step = 8 / a; + return (d + step - 1) / step * step; + } + + /** Returns the encoding for the given wire number, or empty if unknown. 
*/ + public static Optional fromWireNumber(int wireNumber) { + for (TurboQuantEncoding encoding : values()) { + if (encoding.wireNumber == wireNumber) { + return Optional.of(encoding); + } + } + return Optional.empty(); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsFormat.java new file mode 100644 index 000000000000..083ea629a27d --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsFormat.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * TurboQuant flat vectors format. Stores quantized vectors using rotation-based data-oblivious + * quantization with precomputed Beta-distribution-optimal Lloyd-Max centroids. + * + *

This format stores both raw float32 vectors (delegated to {@link Lucene99FlatVectorsFormat}) + * and quantized vectors in separate files. The quantized vectors use unique extensions {@code .vetq} + * (data) and {@code .vemtq} (metadata). + */ +public class TurboQuantFlatVectorsFormat extends FlatVectorsFormat { + + public static final String NAME = "TurboQuantFlatVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "TurboQuantVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "TurboQuantVectorsFormatData"; + static final String META_EXTENSION = "vemtq"; + static final String VECTOR_DATA_EXTENSION = "vetq"; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private final TurboQuantEncoding encoding; + private final Long rotationSeed; + private final FlatVectorsScorer scorer; + + /** Creates a new instance with default BITS_4 encoding. */ + public TurboQuantFlatVectorsFormat() { + this(TurboQuantEncoding.BITS_4); + } + + /** Creates a new instance with the given encoding. */ + public TurboQuantFlatVectorsFormat(TurboQuantEncoding encoding) { + this(encoding, null); + } + + /** + * Creates a new instance with the given encoding and optional explicit rotation seed. 
+ * + * @param encoding the quantization bit-width + * @param rotationSeed explicit rotation seed, or null to derive from field name + */ + public TurboQuantFlatVectorsFormat(TurboQuantEncoding encoding, Long rotationSeed) { + super(NAME); + this.encoding = encoding; + this.rotationSeed = rotationSeed; + this.scorer = new TurboQuantVectorsScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new TurboQuantFlatVectorsWriter( + state, encoding, rotationSeed, rawVectorFormat.fieldsWriter(state), scorer); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new TurboQuantFlatVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 16384; + } + + @Override + public String toString() { + return "TurboQuantFlatVectorsFormat(name=" + + NAME + + ", encoding=" + + encoding + + ", rotationSeed=" + + rotationSeed + + ")"; + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsReader.java new file mode 100644 index 000000000000..0f20ee40926f --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsReader.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.BaseQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedVectorsReader; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +/** + * Reader for TurboQuant quantized vectors. Reads quantized data from {@code .vetq} and metadata + * from {@code .vemtq}, delegating raw vector access to the underlying {@link + * org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat} reader. 
+ */ +public class TurboQuantFlatVectorsReader extends FlatVectorsReader + implements QuantizedVectorsReader { + + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TurboQuantFlatVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + + public TurboQuantFlatVectorsReader( + SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer) + throws IOException { + super(scorer); + this.rawVectorsReader = rawVectorsReader; + + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + TurboQuantFlatVectorsFormat.META_EXTENSION); + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + CodecUtil.checkIndexHeader( + meta, + TurboQuantFlatVectorsFormat.META_CODEC_NAME, + TurboQuantFlatVectorsFormat.VERSION_START, + TurboQuantFlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + + String vectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + TurboQuantFlatVectorsFormat.VECTOR_DATA_EXTENSION); + try { + quantizedVectorData = + state.directory.openInput(vectorDataFileName, state.context); + CodecUtil.checkIndexHeader( + quantizedVectorData, + TurboQuantFlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + TurboQuantFlatVectorsFormat.VERSION_START, + TurboQuantFlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = 
meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + int dimension = meta.readInt(); + int vectorCount = meta.readInt(); + int encodingWire = meta.readInt(); + int simOrdinal = meta.readInt(); + long rotationSeed = meta.readLong(); + long vectorDataOffset = meta.readLong(); + long vectorDataLength = meta.readLong(); + + TurboQuantEncoding encoding = + TurboQuantEncoding.fromWireNumber(encodingWire) + .orElseThrow( + () -> + new CorruptIndexException( + "Unknown TurboQuant encoding wire number: " + encodingWire, meta)); + + VectorSimilarityFunction similarityFunction = + VectorSimilarityFunction.values()[simOrdinal]; + + fields.put( + info.name, + new FieldEntry( + dimension, + vectorCount, + encoding, + similarityFunction, + rotationSeed, + vectorDataOffset, + vectorDataLength)); + } + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public org.apache.lucene.index.ByteVectorValues getByteVectorValues(String field) + throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry entry = fields.get(field); + if (entry == null) { + return null; + } + OffHeapTurboQuantVectorValues quantizedValues = getQuantizedValues(field, entry); + return vectorScorer.getRandomVectorScorer( + entry.similarityFunction, quantizedValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public BaseQuantizedByteVectorValues getQuantizedVectorValues(String fieldName) + throws IOException { + 
FieldEntry entry = fields.get(fieldName); + if (entry == null) { + return null; + } + return getQuantizedValues(fieldName, entry); + } + + @Override + public ScalarQuantizer getQuantizationState(String fieldName) { + // TurboQuant doesn't use ScalarQuantizer + return null; + } + + private OffHeapTurboQuantVectorValues getQuantizedValues(String field, FieldEntry entry) + throws IOException { + HadamardRotation rotation = HadamardRotation.create(entry.dimension, entry.rotationSeed); + float[] centroids = BetaCodebook.centroids(entry.dimension, entry.encoding.bitsPerCoordinate); + return new OffHeapTurboQuantVectorValues( + entry.dimension, + entry.vectorCount, + entry.encoding, + entry.vectorDataOffset, + quantizedVectorData.clone(), + centroids, + rotation); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_SIZE; + total += RamUsageEstimator.sizeOfMap(fields); + total += rawVectorsReader.ramBytesUsed(); + return total; + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + Map result = new HashMap<>(rawVectorsReader.getOffHeapByteSize(fieldInfo)); + FieldEntry entry = fields.get(fieldInfo.name); + if (entry != null) { + result.put(TurboQuantFlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength); + } + return result; + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + /** Per-field metadata read from .vemtq. 
*/ + private record FieldEntry( + int dimension, + int vectorCount, + TurboQuantEncoding encoding, + VectorSimilarityFunction similarityFunction, + long rotationSeed, + long vectorDataOffset, + long vectorDataLength) {} +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsWriter.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsWriter.java new file mode 100644 index 000000000000..6b2ef9fe8eed --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantFlatVectorsWriter.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** + * Writer for TurboQuant quantized vectors. Delegates raw vector storage to {@link + * org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat} and writes quantized data to {@code + * .vetq} and metadata to {@code .vemtq}. 
+ */ +public class TurboQuantFlatVectorsWriter extends FlatVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(TurboQuantFlatVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, quantizedVectorData; + private final TurboQuantEncoding encoding; + private final Long rotationSeed; + private final FlatVectorsWriter rawVectorDelegate; + private boolean finished; + + public TurboQuantFlatVectorsWriter( + SegmentWriteState state, + TurboQuantEncoding encoding, + Long rotationSeed, + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer) + throws IOException { + super(scorer); + this.encoding = encoding; + this.rotationSeed = rotationSeed; + this.segmentWriteState = state; + this.rawVectorDelegate = rawVectorDelegate; + + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + TurboQuantFlatVectorsFormat.META_EXTENSION); + String vectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + TurboQuantFlatVectorsFormat.VECTOR_DATA_EXTENSION); + try { + meta = state.directory.createOutput(metaFileName, state.context); + quantizedVectorData = state.directory.createOutput(vectorDataFileName, state.context); + CodecUtil.writeIndexHeader( + meta, + TurboQuantFlatVectorsFormat.META_CODEC_NAME, + TurboQuantFlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + quantizedVectorData, + TurboQuantFlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + TurboQuantFlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, this); + throw t; + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter 
rawFieldWriter = rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = + new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawFieldWriter); + fields.add(fieldWriter); + return fieldWriter; + } + return rawFieldWriter; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + if (vectorCount == 0) { + continue; + } + int d = field.fieldInfo.getVectorDimension(); + long seed = getRotationSeed(field.fieldInfo); + HadamardRotation rotation = HadamardRotation.create(d, seed); + float[] boundaries = BetaCodebook.boundaries(d, encoding.bitsPerCoordinate); + + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + + List vectors = field.flatFieldVectorsWriter.getVectors(); + float[] rotated = new float[d]; + byte[] indices = new byte[d]; + byte[] packed = new byte[encoding.getPackedByteLength(d)]; + + for (float[] vector : vectors) { + writeQuantizedVector(vector, d, rotation, boundaries, indices, rotated, packed); + } + + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + writeMeta(field.fieldInfo, vectorDataOffset, vectorDataLength, vectorCount, seed); + field.finish(); + } + } + + private void writeQuantizedVector( + float[] vector, + int d, + HadamardRotation rotation, + float[] boundaries, + byte[] indices, + float[] rotated, + byte[] packed) + throws IOException { + float norm = 0; + for (int i = 0; i < d; i++) norm += vector[i] * vector[i]; + norm = (float) Math.sqrt(norm); + + float[] normalized = new float[d]; + if (norm > 0) { + for (int i = 0; i < d; i++) normalized[i] = vector[i] / norm; + } + + rotation.rotate(normalized, rotated); + + for (int i = 0; i < d; i++) { + indices[i] = (byte) 
BetaCodebook.quantize(rotated[i], boundaries); + } + + TurboQuantBitPacker.pack(indices, d, encoding.bitsPerCoordinate, packed); + quantizedVectorData.writeBytes(packed, packed.length); + quantizedVectorData.writeInt(Float.floatToIntBits(norm)); + } + + private void writeMeta( + FieldInfo fieldInfo, + long vectorDataOffset, + long vectorDataLength, + int vectorCount, + long rotSeed) + throws IOException { + meta.writeInt(fieldInfo.number); + meta.writeInt(fieldInfo.getVectorDimension()); + meta.writeInt(vectorCount); + meta.writeInt(encoding.getWireNumber()); + meta.writeInt(fieldInfo.getVectorSimilarityFunction().ordinal()); + meta.writeLong(rotSeed); + meta.writeLong(vectorDataOffset); + meta.writeLong(vectorDataLength); + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (!fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + + int d = fieldInfo.getVectorDimension(); + long seed = getRotationSeed(fieldInfo); + HadamardRotation rotation = HadamardRotation.create(d, seed); + float[] centroids = BetaCodebook.centroids(d, encoding.bitsPerCoordinate); + float[] boundaries = BetaCodebook.boundaries(d, encoding.bitsPerCoordinate); + + // Write quantized vectors to a temp file + IndexOutput tempOutput = + segmentWriteState.directory.createTempOutput( + quantizedVectorData.getName(), "temp", segmentWriteState.context); + String tempName = tempOutput.getName(); + + int vectorCount = 0; + float[] rotated = new float[d]; + byte[] indices = new byte[d]; + byte[] packed = new byte[encoding.getPackedByteLength(d)]; + + try { + FloatVectorValues mergedVectors = + 
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + KnnVectorValues.DocIndexIterator iter = mergedVectors.iterator(); + while (iter.nextDoc() != KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) { + float[] vector = mergedVectors.vectorValue(iter.index()); + float norm = 0; + for (int i = 0; i < d; i++) norm += vector[i] * vector[i]; + norm = (float) Math.sqrt(norm); + + float[] normalized = new float[d]; + if (norm > 0) { + for (int i = 0; i < d; i++) normalized[i] = vector[i] / norm; + } + + rotation.rotate(normalized, rotated); + for (int i = 0; i < d; i++) { + indices[i] = (byte) BetaCodebook.quantize(rotated[i], boundaries); + } + TurboQuantBitPacker.pack(indices, d, encoding.bitsPerCoordinate, packed); + tempOutput.writeBytes(packed, packed.length); + tempOutput.writeInt(Float.floatToIntBits(norm)); + vectorCount++; + } + CodecUtil.writeFooter(tempOutput); + } finally { + IOUtils.close(tempOutput); + } + + // Copy temp data to the real output + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + IndexInput tempInput = + segmentWriteState.directory.openInput(tempName, segmentWriteState.context); + try { + quantizedVectorData.copyBytes(tempInput, tempInput.length() - CodecUtil.footerLength()); + } catch (Throwable t) { + IOUtils.closeWhileSuppressingExceptions(t, tempInput); + throw t; + } + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, vectorCount, seed); + + // Use the temp file for the scorer (the real .vetq is still open for writing) + final int finalVectorCount = vectorCount; + OffHeapTurboQuantVectorValues quantizedValues = + new OffHeapTurboQuantVectorValues( + d, + finalVectorCount, + encoding, + 0, // temp file starts at 0 + tempInput, + centroids, + rotation); + + RandomVectorScorerSupplier scorerSupplier = + vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), 
quantizedValues); + + return new TurboQuantCloseableScorerSupplier(scorerSupplier, () -> { + IOUtils.close(tempInput); + segmentWriteState.directory.deleteFile(tempName); + }, finalVectorCount); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + meta.writeInt(-1); // sentinel + CodecUtil.writeFooter(meta); + CodecUtil.writeFooter(quantizedVectorData); + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + total += rawVectorDelegate.ramBytesUsed(); + return total; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, quantizedVectorData, rawVectorDelegate); + } + + private long getRotationSeed(FieldInfo fieldInfo) { + if (rotationSeed != null) { + return rotationSeed; + } + return murmurhash3(fieldInfo.name); + } + + private static long murmurhash3(String key) { + byte[] bytes = key.getBytes(java.nio.charset.StandardCharsets.UTF_8); + long h = 0xcafebabe; + for (byte b : bytes) { + h ^= b; + h *= 0x5bd1e9955bd1e995L; + h ^= h >>> 47; + } + return h; + } + + /** Per-field writer that wraps the raw delegate. 
*/ + private static class FieldWriter extends FlatFieldVectorsWriter { + final FieldInfo fieldInfo; + final FlatFieldVectorsWriter flatFieldVectorsWriter; + private boolean isFinished = false; + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter delegate) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = delegate; + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + } + + @Override + public float[] copyValue(float[] vectorValue) { + return flatFieldVectorsWriter.copyValue(vectorValue); + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (isFinished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + isFinished = true; + } + + @Override + public boolean isFinished() { + return isFinished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public long ramBytesUsed() { + return flatFieldVectorsWriter.ramBytesUsed() + + RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + } + } + + /** Closeable scorer supplier for merge. 
*/ + private static class TurboQuantCloseableScorerSupplier + implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier delegate; + private final java.io.Closeable toClose; + private final int totalVectorCount; + + TurboQuantCloseableScorerSupplier( + RandomVectorScorerSupplier delegate, java.io.Closeable toClose, int totalVectorCount) { + this.delegate = delegate; + this.toClose = toClose; + this.totalVectorCount = totalVectorCount; + } + + @Override + public org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer scorer() throws IOException { + return delegate.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return delegate.copy(); + } + + @Override + public int totalVectorCount() { + return totalVectorCount; + } + + @Override + public void close() throws IOException { + IOUtils.close(toClose); + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantHnswVectorsFormat.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantHnswVectorsFormat.java new file mode 100644 index 000000000000..7d64215e51bc --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantHnswVectorsFormat.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.HNSW_GRAPH_THRESHOLD; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; + +/** + * Convenience format composing HNSW graph with TurboQuant flat vector quantization. This is the + * primary user-facing format for TurboQuant vector search. 
+ */ +public class TurboQuantHnswVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "TurboQuantHnswVectorsFormat"; + + private final int maxConn; + private final int beamWidth; + private final TurboQuantFlatVectorsFormat flatVectorsFormat; + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + private final int tinySegmentsThreshold; + + /** Constructs with default parameters: BITS_4, maxConn=16, beamWidth=100. */ + public TurboQuantHnswVectorsFormat() { + this(TurboQuantEncoding.BITS_4, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH); + } + + /** Constructs with the given encoding and default HNSW parameters. */ + public TurboQuantHnswVectorsFormat(TurboQuantEncoding encoding, int maxConn, int beamWidth) { + this(encoding, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null); + } + + /** + * Full constructor with all parameters. + * + * @param encoding quantization bit-width + * @param maxConn maximum connections per node in HNSW graph + * @param beamWidth beam width for graph construction + * @param numMergeWorkers number of merge workers (1 = single-threaded) + * @param mergeExec executor for parallel merge, or null for single-threaded + * @param rotationSeed explicit rotation seed, or null to derive from field name + */ + public TurboQuantHnswVectorsFormat( + TurboQuantEncoding encoding, + int maxConn, + int beamWidth, + int numMergeWorkers, + ExecutorService mergeExec, + Long rotationSeed) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and <= " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and <= " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth); + } + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread 
to merge"); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + this.flatVectorsFormat = new TurboQuantFlatVectorsFormat(encoding, rotationSeed); + this.numMergeWorkers = numMergeWorkers; + this.tinySegmentsThreshold = HNSW_GRAPH_THRESHOLD; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec, + tinySegmentsThreshold); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 16384; + } + + @Override + public String toString() { + return "TurboQuantHnswVectorsFormat(name=" + + NAME + + ", maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorsFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantScoringUtil.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantScoringUtil.java new file mode 100644 index 000000000000..8c35a1b59b63 --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/TurboQuantScoringUtil.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Optimized scoring utilities for TurboQuant quantized vectors. Uses LUT-based approach where
 * centroid values are gathered via index lookup, enabling JVM auto-vectorization.
 *
 * <p>Bit layout (established by the packer and mirrored by every helper here): indices are packed
 * MSB-first — b=4 uses the high nibble first, b=2 uses bits 7..6 first, and b=3 packs 8 indices
 * into each 3-byte group starting at bit 23.
 */
public final class TurboQuantScoringUtil {

  private TurboQuantScoringUtil() {}

  /**
   * Computes dot product between a float query vector (already rotated) and a quantized document
   * vector stored as packed b-bit indices.
   *
   * @param query rotated query vector
   * @param packedIndices packed b-bit quantization indices
   * @param centroids centroid values (2^b entries, scaled by 1/√d)
   * @param b bits per coordinate (2, 3, 4, or 8)
   * @param d dimension
   * @return dot product in rotated space
   * @throws IllegalArgumentException if {@code b} is not 2, 3, 4 or 8
   */
  public static float dotProduct(
      float[] query, byte[] packedIndices, float[] centroids, int b, int d) {
    return switch (b) {
      case 4 -> dotProduct4(query, packedIndices, centroids, d);
      case 8 -> dotProduct8(query, packedIndices, centroids, d);
      case 2 -> dotProduct2(query, packedIndices, centroids, d);
      case 3 -> dotProduct3(query, packedIndices, centroids, d);
      default -> throw new IllegalArgumentException("Unsupported bit-width: " + b);
    };
  }

  /**
   * Computes squared Euclidean distance between a float query vector and a quantized document
   * vector.
   *
   * @param docNorm original norm of the document vector, used to un-normalize the reconstructed
   *     centroids
   * @throws IllegalArgumentException if {@code b} is not 2, 3, 4 or 8
   */
  public static float squareDistance(
      float[] query, byte[] packedIndices, float[] centroids, int b, int d, float docNorm) {
    return switch (b) {
      case 4 -> squareDistance4(query, packedIndices, centroids, d, docNorm);
      case 8 -> squareDistance8(query, packedIndices, centroids, d, docNorm);
      case 2 -> squareDistance2(query, packedIndices, centroids, d, docNorm);
      case 3 -> squareDistance3(query, packedIndices, centroids, d, docNorm);
      default -> throw new IllegalArgumentException("Unsupported bit-width: " + b);
    };
  }

  // b=4: 2 indices per byte (nibble-packed, high nibble first), 16-entry LUT
  private static float dotProduct4(float[] query, byte[] packed, float[] centroids, int d) {
    float sum = 0;
    int qi = 0;
    for (int i = 0; i < packed.length && qi < d; i++) {
      int b = packed[i] & 0xFF;
      sum += query[qi] * centroids[(b >> 4) & 0x0F];
      qi++;
      if (qi < d) {
        sum += query[qi] * centroids[b & 0x0F];
        qi++;
      }
    }
    return sum;
  }

  private static float squareDistance4(
      float[] query, byte[] packed, float[] centroids, int d, float docNorm) {
    float sum = 0;
    int qi = 0;
    for (int i = 0; i < packed.length && qi < d; i++) {
      int b = packed[i] & 0xFF;
      float diff = query[qi] - centroids[(b >> 4) & 0x0F] * docNorm;
      sum += diff * diff;
      qi++;
      if (qi < d) {
        diff = query[qi] - centroids[b & 0x0F] * docNorm;
        sum += diff * diff;
        qi++;
      }
    }
    return sum;
  }

  // b=8: 1 index per byte, 256-entry LUT
  private static float dotProduct8(float[] query, byte[] packed, float[] centroids, int d) {
    float sum = 0;
    for (int i = 0; i < d; i++) {
      sum += query[i] * centroids[packed[i] & 0xFF];
    }
    return sum;
  }

  private static float squareDistance8(
      float[] query, byte[] packed, float[] centroids, int d, float docNorm) {
    float sum = 0;
    for (int i = 0; i < d; i++) {
      float diff = query[i] - centroids[packed[i] & 0xFF] * docNorm;
      sum += diff * diff;
    }
    return sum;
  }

  // b=2: 4 indices per byte, MSB pair first
  private static float dotProduct2(float[] query, byte[] packed, float[] centroids, int d) {
    float sum = 0;
    int qi = 0;
    for (int i = 0; i < packed.length && qi < d; i++) {
      int b = packed[i] & 0xFF;
      sum += query[qi++] * centroids[(b >> 6) & 0x03];
      if (qi < d) {
        sum += query[qi++] * centroids[(b >> 4) & 0x03];
      }
      if (qi < d) {
        sum += query[qi++] * centroids[(b >> 2) & 0x03];
      }
      if (qi < d) {
        sum += query[qi++] * centroids[b & 0x03];
      }
    }
    return sum;
  }

  private static float squareDistance2(
      float[] query, byte[] packed, float[] centroids, int d, float docNorm) {
    float sum = 0;
    int qi = 0;
    for (int i = 0; i < packed.length && qi < d; i++) {
      int b = packed[i] & 0xFF;
      for (int shift = 6; shift >= 0 && qi < d; shift -= 2) {
        float diff = query[qi++] - centroids[(b >> shift) & 0x03] * docNorm;
        sum += diff * diff;
      }
    }
    return sum;
  }

  // b=3: 8 indices per 3 bytes, first index in bits 23..21
  private static float dotProduct3(float[] query, byte[] packed, float[] centroids, int d) {
    float sum = 0;
    int qi = 0;
    int pi = 0;
    while (qi + 7 < d && pi + 2 < packed.length) {
      int bits =
          ((packed[pi] & 0xFF) << 16) | ((packed[pi + 1] & 0xFF) << 8) | (packed[pi + 2] & 0xFF);
      pi += 3;
      sum += query[qi++] * centroids[(bits >> 21) & 0x07];
      sum += query[qi++] * centroids[(bits >> 18) & 0x07];
      sum += query[qi++] * centroids[(bits >> 15) & 0x07];
      sum += query[qi++] * centroids[(bits >> 12) & 0x07];
      sum += query[qi++] * centroids[(bits >> 9) & 0x07];
      sum += query[qi++] * centroids[(bits >> 6) & 0x07];
      sum += query[qi++] * centroids[(bits >> 3) & 0x07];
      sum += query[qi++] * centroids[bits & 0x07];
    }
    // Handle remainder: fewer than 8 coordinates left, packed into at most 3 trailing bytes.
    if (qi < d && pi < packed.length) {
      int bits =
          ((pi < packed.length ? packed[pi] & 0xFF : 0) << 16)
              | ((pi + 1 < packed.length ? packed[pi + 1] & 0xFF : 0) << 8)
              | (pi + 2 < packed.length ? packed[pi + 2] & 0xFF : 0);
      for (int shift = 21; qi < d; shift -= 3) {
        sum += query[qi++] * centroids[(bits >> shift) & 0x07];
      }
    }
    return sum;
  }

  private static float squareDistance3(
      float[] query, byte[] packed, float[] centroids, int d, float docNorm) {
    // Inline 3-bit extraction mirroring dotProduct3's layout, instead of the former
    // TurboQuantBitPacker.unpack path: that variant allocated a d-byte scratch array on every
    // call, which is wasteful in the scoring hot loop.
    float sum = 0;
    int qi = 0;
    int pi = 0;
    while (qi + 7 < d && pi + 2 < packed.length) {
      int bits =
          ((packed[pi] & 0xFF) << 16) | ((packed[pi + 1] & 0xFF) << 8) | (packed[pi + 2] & 0xFF);
      pi += 3;
      for (int shift = 21; shift >= 0; shift -= 3) {
        float diff = query[qi++] - centroids[(bits >> shift) & 0x07] * docNorm;
        sum += diff * diff;
      }
    }
    // Handle remainder exactly as dotProduct3 does.
    if (qi < d && pi < packed.length) {
      int bits =
          ((pi < packed.length ? packed[pi] & 0xFF : 0) << 16)
              | ((pi + 1 < packed.length ? packed[pi + 1] & 0xFF : 0) << 8)
              | (pi + 2 < packed.length ? packed[pi + 2] & 0xFF : 0);
      for (int shift = 21; qi < d; shift -= 3) {
        float diff = query[qi++] - centroids[(bits >> shift) & 0x07] * docNorm;
        sum += diff * diff;
      }
    }
    return sum;
  }
}
+ */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +/** + * Scorer for TurboQuant quantized vectors. Rotates the query vector once, then computes distances + * in the rotated space against quantized candidate vectors. + * + *

This is a naive (non-SIMD) implementation for correctness. Phase 3 replaces it with + * LUT-based SIMD scoring. + */ +public class TurboQuantVectorsScorer implements FlatVectorsScorer { + + private final FlatVectorsScorer rawScorer; + + public TurboQuantVectorsScorer(FlatVectorsScorer rawScorer) { + this.rawScorer = rawScorer; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof OffHeapTurboQuantVectorValues quantizedValues) { + return new TurboQuantScorerSupplier(similarityFunction, quantizedValues); + } + return rawScorer.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof OffHeapTurboQuantVectorValues quantizedValues) { + return new TurboQuantQueryScorer(similarityFunction, quantizedValues, target); + } + return rawScorer.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + throw new UnsupportedOperationException("TurboQuant only supports float32 vectors"); + } + + @Override + public String toString() { + return "TurboQuantVectorsScorer(rawScorer=" + rawScorer + ")"; + } + + /** Scorer for a single query against quantized vectors. 
*/ + private static class TurboQuantQueryScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final VectorSimilarityFunction similarityFunction; + private final OffHeapTurboQuantVectorValues quantizedValues; + private final float[] rotatedQuery; + + TurboQuantQueryScorer( + VectorSimilarityFunction similarityFunction, + OffHeapTurboQuantVectorValues quantizedValues, + float[] target) { + super(quantizedValues); + this.similarityFunction = similarityFunction; + this.quantizedValues = quantizedValues; + + // Rotate query once + HadamardRotation rotation = quantizedValues.getRotation(); + int d = target.length; + + // Normalize for cosine similarity + float[] normalized; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + normalized = new float[d]; + float norm = 0; + for (int i = 0; i < d; i++) norm += target[i] * target[i]; + norm = (float) Math.sqrt(norm); + if (norm > 0) { + for (int i = 0; i < d; i++) normalized[i] = target[i] / norm; + } + } else { + normalized = target; + } + + this.rotatedQuery = new float[d]; + rotation.rotate(normalized, rotatedQuery); + } + + @Override + public float score(int node) throws IOException { + float[] centroids = quantizedValues.getCentroids(); + int d = quantizedValues.dimension(); + byte[] packedIndices = quantizedValues.vectorValue(node); + int b = quantizedValues.getBitsPerCoordinate(); + float docNorm = quantizedValues.getNorm(node); + + return switch (similarityFunction) { + case DOT_PRODUCT -> { + float dot = TurboQuantScoringUtil.dotProduct(rotatedQuery, packedIndices, centroids, b, d); + // DOT_PRODUCT expects unit vectors; dot already approximates true dot product + yield Math.max((1 + dot) / 2, 0); + } + case MAXIMUM_INNER_PRODUCT -> { + float dot = TurboQuantScoringUtil.dotProduct(rotatedQuery, packedIndices, centroids, b, d); + // Reconstruct unnormalized dot product: query is already unnormalized, doc was normalized + float rawDot = dot * docNorm; + yield 
VectorUtil.scaleMaxInnerProductScore(rawDot); + } + case COSINE -> { + float dot = TurboQuantScoringUtil.dotProduct(rotatedQuery, packedIndices, centroids, b, d); + yield Math.max((1 + dot) / 2, 0); + } + case EUCLIDEAN -> { + float dist = + TurboQuantScoringUtil.squareDistance( + rotatedQuery, packedIndices, centroids, b, d, docNorm); + yield 1 / (1 + dist); + } + }; + } + } + + /** Supplier for graph-building scorers (doc-vs-doc scoring). */ + private static class TurboQuantScorerSupplier implements RandomVectorScorerSupplier { + private final VectorSimilarityFunction similarityFunction; + private final OffHeapTurboQuantVectorValues quantizedValues; + + TurboQuantScorerSupplier( + VectorSimilarityFunction similarityFunction, + OffHeapTurboQuantVectorValues quantizedValues) { + this.similarityFunction = similarityFunction; + this.quantizedValues = quantizedValues; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + OffHeapTurboQuantVectorValues copy = quantizedValues.copy(); + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(copy) { + private byte[] currentIndices; + private float currentNorm; + + @Override + public void setScoringOrdinal(int ord) throws IOException { + currentIndices = copy.vectorValue(ord); + currentNorm = copy.getNorm(ord); + } + + @Override + public float score(int node) throws IOException { + float[] centroids = copy.getCentroids(); + int d = copy.dimension(); + int b = copy.getBitsPerCoordinate(); + byte[] nodePackedIndices = copy.vectorValue(node); + float nodeNorm = copy.getNorm(node); + + byte[] nodeIndices = new byte[d]; + TurboQuantBitPacker.unpack(nodePackedIndices, b, d, nodeIndices); + byte[] curIndices = new byte[d]; + TurboQuantBitPacker.unpack(currentIndices, b, d, curIndices); + + // Approximate distance using quantized centroids + float dot = 0; + for (int i = 0; i < d; i++) { + dot += centroids[curIndices[i] & 0xFF] * centroids[nodeIndices[i] & 0xFF]; + } + return 
switch (similarityFunction) { + case DOT_PRODUCT -> + Math.max((1 + dot) / 2, 0); + case MAXIMUM_INNER_PRODUCT -> + VectorUtil.scaleMaxInnerProductScore(dot * currentNorm * nodeNorm); + case COSINE -> Math.max((1 + dot) / 2, 0); + case EUCLIDEAN -> { + float dist = 0; + for (int i = 0; i < d; i++) { + float a = centroids[curIndices[i] & 0xFF] * currentNorm; + float bv = centroids[nodeIndices[i] & 0xFF] * nodeNorm; + float diff = a - bv; + dist += diff * diff; + } + yield 1 / (1 + dist); + } + }; + } + + }; + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new TurboQuantScorerSupplier( + similarityFunction, quantizedValues.copy()); + } + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/package-info.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/package-info.java new file mode 100644 index 000000000000..3519550e284c --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/turboquant/package-info.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * TurboQuant vector quantization codec for Apache Lucene. + * + *

Implements the TurboQuant algorithm (Zandieh et al., ICLR 2026) as a {@link + * org.apache.lucene.codecs.hnsw.FlatVectorsFormat} for near-optimal data-oblivious vector + * quantization. + * + *

Algorithm

+ * + *
    + *
  1. Store original norm {@code ||x||} as float32 + *
  2. Normalize: {@code x̂ = x / ||x||} + *
  3. Random rotation: {@code y = Π · x̂} (shared globally via deterministic seed) + *
  4. Scalar quantize each coordinate using precomputed Beta-distribution-optimal Lloyd-Max + * centroids → b-bit index per coordinate + *
+ * + *

File Format

+ * + * + * + * + * + * + *
TurboQuant file extensions
ExtensionContents
{@code .vetq}Packed b-bit indices + float32 norms, contiguous per-doc, off-heap
{@code .vemtq}Metadata: dimension, encoding, vector count, rotation seed, similarity
+ * + *

Raw vectors ({@code .vec}) and HNSW graph ({@code .vex}) are delegated to existing formats. + * + *

When to Use TurboQuant

+ * + * + * + *

Limitations

+ * + * + * + * @see org.apache.lucene.sandbox.codecs.turboquant.TurboQuantHnswVectorsFormat + * @see org.apache.lucene.sandbox.codecs.turboquant.TurboQuantFlatVectorsFormat + * @see org.apache.lucene.sandbox.codecs.turboquant.TurboQuantEncoding + */ +package org.apache.lucene.sandbox.codecs.turboquant; diff --git a/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 29a44d2ecfa8..c5d12abf067f 100644 --- a/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/lucene/sandbox/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -14,3 +14,4 @@ # limitations under the License. org.apache.lucene.sandbox.codecs.faiss.FaissKnnVectorsFormat +org.apache.lucene.sandbox.codecs.turboquant.TurboQuantHnswVectorsFormat diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestBetaCodebook.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestBetaCodebook.java new file mode 100644 index 000000000000..e678d44eded1 --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestBetaCodebook.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestBetaCodebook extends LuceneTestCase { + + public void testCentroidsSymmetric() { + for (int b : new int[] {2, 3, 4, 8}) { + float[] c = BetaCodebook.centroids(4096, b); + assertEquals(1 << b, c.length); + for (int i = 0; i < c.length; i++) { + assertEquals( + "b=" + b + " centroid[" + i + "] not symmetric", + -c[c.length - 1 - i], + c[i], + 1e-6f); + } + } + } + + public void testCentroidsCount() { + assertEquals(4, BetaCodebook.centroids(4096, 2).length); + assertEquals(8, BetaCodebook.centroids(4096, 3).length); + assertEquals(16, BetaCodebook.centroids(4096, 4).length); + assertEquals(256, BetaCodebook.centroids(4096, 8).length); + } + + public void testCentroidsScaling() { + // Centroids at d=1 should be the canonical values (scale = 1/√1 = 1) + float[] c1 = BetaCodebook.centroids(1, 2); + // Centroids at d=4 should be half (scale = 1/√4 = 0.5) + float[] c4 = BetaCodebook.centroids(4, 2); + for (int i = 0; i < c1.length; i++) { + assertEquals(c1[i] * 0.5f, c4[i], 1e-6f); + } + } + + public void testCentroidsReferenceValues() { + // Verify b=2 canonical centroids match reference implementation within 1e-4 + float[] c = BetaCodebook.centroids(1, 2); // d=1 → no scaling + assertEquals(-1.5104f, c[0], 1e-3f); + assertEquals(-0.4528f, c[1], 1e-3f); + assertEquals(0.4528f, c[2], 1e-3f); + assertEquals(1.5104f, c[3], 1e-3f); + } + + public void testBoundariesCount() { + for (int b : new int[] {2, 3, 4, 8}) { + float[] bd = BetaCodebook.boundaries(4096, b); + assertEquals((1 << b) + 1, bd.length); + assertEquals(-Float.MAX_VALUE, bd[0], 0f); + assertEquals(Float.MAX_VALUE, bd[bd.length - 1], 0f); + } + } + + public void testBoundariesAreMidpoints() { + float[] c = BetaCodebook.centroids(4096, 4); + float[] bd = 
BetaCodebook.boundaries(4096, 4); + for (int i = 0; i < c.length - 1; i++) { + float expected = (c[i] + c[i + 1]) / 2; + assertEquals(expected, bd[i + 1], 1e-7f); + } + } + + public void testQuantize() { + float[] bd = BetaCodebook.boundaries(4096, 2); + // Very negative → index 0 + assertEquals(0, BetaCodebook.quantize(-10f, bd)); + // Very positive → last index + assertEquals(3, BetaCodebook.quantize(10f, bd)); + // Zero → middle (index 1 or 2 depending on boundary) + int idx = BetaCodebook.quantize(0f, bd); + assertTrue(idx == 1 || idx == 2); + } + + public void testMseDistortionBits4() { + // Empirical MSE distortion test at d=4096, b=4 + // Generate random unit vectors, quantize, measure MSE + int d = 4096; + int b = 4; + int numVectors = 1000; + float[] centroids = BetaCodebook.centroids(d, b); + float[] boundaries = BetaCodebook.boundaries(d, b); + + java.util.Random rng = new java.util.Random(42); + double totalMse = 0; + + for (int v = 0; v < numVectors; v++) { + // Generate random unit vector + float[] x = new float[d]; + float norm = 0; + for (int i = 0; i < d; i++) { + x[i] = (float) rng.nextGaussian(); + norm += x[i] * x[i]; + } + norm = (float) Math.sqrt(norm); + for (int i = 0; i < d; i++) { + x[i] /= norm; + } + + // Rotate + HadamardRotation rot = HadamardRotation.create(d, 12345L); + float[] rotated = new float[d]; + rot.rotate(x, rotated); + + // Quantize and dequantize + double mse = 0; + for (int i = 0; i < d; i++) { + int idx = BetaCodebook.quantize(rotated[i], boundaries); + float reconstructed = centroids[idx]; + double err = rotated[i] - reconstructed; + mse += err * err; + } + totalMse += mse; + } + // Total MSE over all d coordinates of a unit vector + double avgMse = totalMse / numVectors; + // Paper says 0.009 for b=4. 
Allow range [0.007, 0.011] + assertTrue( + "MSE distortion " + avgMse + " outside expected range [0.007, 0.011]", + avgMse >= 0.007 && avgMse <= 0.011); + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestHadamardRotation.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestHadamardRotation.java new file mode 100644 index 000000000000..1af4b2469ed1 --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestHadamardRotation.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.lucene.sandbox.codecs.turboquant;

import org.apache.lucene.tests.util.LuceneTestCase;

/**
 * Unit tests for {@link HadamardRotation}: block decomposition, invertibility, isometry (norm and
 * inner-product preservation), seed determinism, and quantization quality of the block-diagonal
 * construction for non-power-of-2 dimensions.
 */
public class TestHadamardRotation extends LuceneTestCase {

  // Power-of-2 dimensions decompose into a single block of size d.
  public void testDecomposeBlocksPowerOf2() {
    assertArrayEquals(new int[] {4096}, HadamardRotation.decomposeBlocks(4096));
    assertArrayEquals(new int[] {1024}, HadamardRotation.decomposeBlocks(1024));
    assertArrayEquals(new int[] {256}, HadamardRotation.decomposeBlocks(256));
    assertArrayEquals(new int[] {1}, HadamardRotation.decomposeBlocks(1));
  }

  // Non-power-of-2 dimensions decompose into descending power-of-2 blocks.
  public void testDecomposeBlocksNonPowerOf2() {
    assertArrayEquals(new int[] {512, 256}, HadamardRotation.decomposeBlocks(768));
    assertArrayEquals(new int[] {256, 128}, HadamardRotation.decomposeBlocks(384));
    assertArrayEquals(new int[] {1024, 512}, HadamardRotation.decomposeBlocks(1536));
    assertArrayEquals(new int[] {2048, 1024}, HadamardRotation.decomposeBlocks(3072));
  }

  // Exhaustive check for d=1..8192: every block is a power of 2 and the blocks partition d.
  public void testDecomposeBlocksSumsToD() {
    for (int d = 1; d <= 8192; d++) {
      int[] blocks = HadamardRotation.decomposeBlocks(d);
      int sum = 0;
      for (int b : blocks) {
        // b & (b - 1) == 0 is the standard power-of-2 test (b >= 1 here).
        assertTrue("Block " + b + " is not power of 2", (b & (b - 1)) == 0);
        sum += b;
      }
      assertEquals("Blocks don't sum to d=" + d, d, sum);
    }
  }

  // rotate followed by inverseRotate must recover the input (within float tolerance),
  // for power-of-2, block-diagonal, and small/odd dimensions alike.
  public void testRoundTrip() {
    for (int d : new int[] {4096, 768, 384, 100, 33}) {
      HadamardRotation rot = HadamardRotation.create(d, 42L);
      java.util.Random rng = new java.util.Random(123);
      float[] x = new float[d];
      for (int i = 0; i < d; i++) {
        x[i] = (float) rng.nextGaussian();
      }

      float[] rotated = new float[d];
      rot.rotate(x, rotated);
      float[] recovered = new float[d];
      rot.inverseRotate(rotated, recovered);

      for (int i = 0; i < d; i++) {
        assertEquals("d=" + d + " coord " + i, x[i], recovered[i], 1e-4f);
      }
    }
  }

  // A rotation is an isometry: squared norms must match to within float rounding.
  public void testNormPreservation() {
    int d = 4096;
    HadamardRotation rot = HadamardRotation.create(d, 42L);
    java.util.Random rng = new java.util.Random(0);

    for (int trial = 0; trial < 100; trial++) {
      float[] x = new float[d];
      double normSqX = 0;
      for (int i = 0; i < d; i++) {
        x[i] = (float) rng.nextGaussian();
        // Accumulate in double to avoid float summation error dominating the check.
        normSqX += (double) x[i] * x[i];
      }

      float[] rotated = new float[d];
      rot.rotate(x, rotated);
      double normSqR = 0;
      for (int i = 0; i < d; i++) {
        normSqR += (double) rotated[i] * rotated[i];
      }

      double relError = Math.abs(normSqR - normSqX) / normSqX;
      assertTrue(
          "Norm not preserved: relError=" + relError + " at trial " + trial, relError < 1e-4);
    }
  }

  // Orthogonal transforms also preserve inner products: <Πa, Πb> == <a, b>.
  public void testInnerProductPreservation() {
    int d = 4096;
    HadamardRotation rot = HadamardRotation.create(d, 42L);
    java.util.Random rng = new java.util.Random(7);

    for (int trial = 0; trial < 100; trial++) {
      float[] a = new float[d], b = new float[d];
      double dotOrig = 0;
      for (int i = 0; i < d; i++) {
        a[i] = (float) rng.nextGaussian();
        b[i] = (float) rng.nextGaussian();
        dotOrig += (double) a[i] * b[i];
      }

      float[] ra = new float[d], rb = new float[d];
      rot.rotate(a, ra);
      rot.rotate(b, rb);
      double dotRot = 0;
      for (int i = 0; i < d; i++) {
        dotRot += (double) ra[i] * rb[i];
      }

      // 1e-10 in the denominator guards against division by a near-zero true dot product.
      double relError = Math.abs(dotRot - dotOrig) / (Math.abs(dotOrig) + 1e-10);
      assertTrue("Inner product not preserved: relError=" + relError, relError < 1e-4);
    }
  }

  // Same (d, seed) must yield bit-identical rotations — required for byte-copy merge.
  public void testDeterminism() {
    int d = 768;
    HadamardRotation rot1 = HadamardRotation.create(d, 42L);
    HadamardRotation rot2 = HadamardRotation.create(d, 42L);

    float[] x = new float[d];
    for (int i = 0; i < d; i++) x[i] = i * 0.001f;

    float[] out1 = new float[d], out2 = new float[d];
    rot1.rotate(x, out1);
    rot2.rotate(x, out2);

    for (int i = 0; i < d; i++) {
      // Exact equality (delta 0): determinism, not approximation, is under test.
      assertEquals(out1[i], out2[i], 0f);
    }
  }

  // Different seeds must yield different rotations (at least one coordinate differs).
  public void testDifferentSeeds() {
    int d = 768;
    HadamardRotation rot1 = HadamardRotation.create(d, 1L);
    HadamardRotation rot2 = HadamardRotation.create(d, 2L);

    float[] x = new float[d];
    for (int i = 0; i < d; i++) x[i] = 1.0f / d;

    float[] out1 = new float[d], out2 = new float[d];
    rot1.rotate(x, out1);
    rot2.rotate(x, out2);

    boolean anyDifferent = false;
    for (int i = 0; i < d; i++) {
      if (Math.abs(out1[i] - out2[i]) > 1e-6f) {
        anyDifferent = true;
        break;
      }
    }
    assertTrue("Different seeds should produce different rotations", anyDifferent);
  }

  // Linear map: the zero vector must map to the zero vector exactly.
  public void testZeroVector() {
    int d = 128;
    HadamardRotation rot = HadamardRotation.create(d, 42L);
    float[] x = new float[d]; // all zeros
    float[] out = new float[d];
    rot.rotate(x, out);
    for (int i = 0; i < d; i++) {
      assertEquals(0f, out[i], 0f);
    }
  }

  /**
   * Block-diagonal MSE at d=768 should be within 5% of a single-block Hadamard at d=1024 (padded).
   * This validates that the block-diagonal approach doesn't degrade quantization quality.
   */
  public void testBlockDiagonalMseQuality() {
    int d = 768;
    int b = 4;
    int numVectors = 1000;
    java.util.Random rng = new java.util.Random(42);
    float[] centroids768 = BetaCodebook.centroids(d, b);
    float[] boundaries768 = BetaCodebook.boundaries(d, b);
    HadamardRotation rot768 = HadamardRotation.create(d, 12345L);

    // Also test with d=1024 (power of 2, single block) for comparison
    int dRef = 1024;
    float[] centroidsRef = BetaCodebook.centroids(dRef, b);
    float[] boundariesRef = BetaCodebook.boundaries(dRef, b);
    HadamardRotation rotRef = HadamardRotation.create(dRef, 12345L);

    double mse768 = 0, mseRef = 0;
    for (int v = 0; v < numVectors; v++) {
      // d=768 block-diagonal
      float[] x768 = randomUnitVector(d, rng);
      float[] rotated768 = new float[d];
      rot768.rotate(x768, rotated768);
      double err768 = 0;
      for (int i = 0; i < d; i++) {
        int idx = BetaCodebook.quantize(rotated768[i], boundaries768);
        double diff = rotated768[i] - centroids768[idx];
        err768 += diff * diff;
      }
      mse768 += err768;

      // d=1024 single block reference
      float[] xRef = randomUnitVector(dRef, rng);
      float[] rotatedRef = new float[dRef];
      rotRef.rotate(xRef, rotatedRef);
      double errRef = 0;
      for (int i = 0; i < dRef; i++) {
        int idx = BetaCodebook.quantize(rotatedRef[i], boundariesRef);
        double diff = rotatedRef[i] - centroidsRef[idx];
        errRef += diff * diff;
      }
      mseRef += errRef;
    }
    mse768 /= numVectors;
    mseRef /= numVectors;

    // Block-diagonal MSE should be within 5% of single-block MSE
    double ratio = mse768 / mseRef;
    assertTrue(
        "Block-diagonal MSE ratio " + ratio + " exceeds 5% threshold (768 mse="
            + mse768 + ", 1024 mse=" + mseRef + ")",
        ratio < 1.05 && ratio > 0.95);
  }

  // Helper: isotropic Gaussian sample normalized to unit length.
  private static float[] randomUnitVector(int d, java.util.Random rng) {
    float[] v = new float[d];
    float norm = 0;
    for (int i = 0; i < d; i++) {
      v[i] = (float) rng.nextGaussian();
      norm += v[i] * v[i];
    }
    norm = (float) Math.sqrt(norm);
    for (int i = 0; i < d; i++) v[i] /= norm;
    return v;
  }

  // Standard basis vectors e_k exercise every row of the rotation; each must keep unit norm.
  public void testOneHotVectors() {
    int d = 128;
    HadamardRotation rot = HadamardRotation.create(d, 42L);
    for (int k = 0; k < d; k++) {
      float[] x = new float[d];
      x[k] = 1.0f;
      float[] out = new float[d];
      rot.rotate(x, out);
      // Norm should be preserved
      double normSq = 0;
      for (int i = 0; i < d; i++) normSq += (double) out[i] * out[i];
      assertEquals("One-hot e_" + k + " norm not preserved", 1.0, normSq, 1e-4);
    }
  }
}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestTurboQuantBitPacker extends LuceneTestCase { + + public void testRoundTripAllEncodings() { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + int b = enc.bitsPerCoordinate; + int maxVal = (1 << b) - 1; + for (int d : new int[] {32, 768, 4096}) { + byte[] indices = new byte[d]; + java.util.Random rng = new java.util.Random(d * 31L + b); + for (int i = 0; i < d; i++) { + indices[i] = (byte) rng.nextInt(maxVal + 1); + } + + int packedLen = enc.getPackedByteLength(d); + byte[] packed = new byte[packedLen]; + TurboQuantBitPacker.pack(indices, d, b, packed); + + byte[] unpacked = new byte[d]; + TurboQuantBitPacker.unpack(packed, b, d, unpacked); + + for (int i = 0; i < d; i++) { + assertEquals( + "b=" + b + " d=" + d + " index " + i, indices[i] & 0xFF, unpacked[i] & 0xFF); + } + } + } + } + + public void testAllZeros() { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + int b = enc.bitsPerCoordinate; + int d = 128; + byte[] indices = new byte[d]; // all zeros + byte[] packed = new byte[enc.getPackedByteLength(d)]; + TurboQuantBitPacker.pack(indices, d, b, packed); + byte[] unpacked = new byte[d]; + TurboQuantBitPacker.unpack(packed, b, d, unpacked); + for (int i = 0; i < d; i++) { + assertEquals(0, unpacked[i]); + } + } + } + + public 
void testAllMax() { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + int b = enc.bitsPerCoordinate; + int maxVal = (1 << b) - 1; + int d = 128; + byte[] indices = new byte[d]; + for (int i = 0; i < d; i++) indices[i] = (byte) maxVal; + + byte[] packed = new byte[enc.getPackedByteLength(d)]; + TurboQuantBitPacker.pack(indices, d, b, packed); + byte[] unpacked = new byte[d]; + TurboQuantBitPacker.unpack(packed, b, d, unpacked); + for (int i = 0; i < d; i++) { + assertEquals("b=" + b + " index " + i, maxVal, unpacked[i] & 0xFF); + } + } + } + + public void testAlternatingPattern() { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + int b = enc.bitsPerCoordinate; + int maxVal = (1 << b) - 1; + int d = 256; + byte[] indices = new byte[d]; + for (int i = 0; i < d; i++) { + indices[i] = (byte) (i % 2 == 0 ? 0 : maxVal); + } + + byte[] packed = new byte[enc.getPackedByteLength(d)]; + TurboQuantBitPacker.pack(indices, d, b, packed); + byte[] unpacked = new byte[d]; + TurboQuantBitPacker.unpack(packed, b, d, unpacked); + for (int i = 0; i < d; i++) { + assertEquals(indices[i] & 0xFF, unpacked[i] & 0xFF); + } + } + } + + public void testOutputLengthMatchesEncoding() { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + for (int d : new int[] {32, 768, 4096, 16384}) { + int expected = enc.getPackedByteLength(d); + byte[] indices = new byte[d]; + byte[] packed = new byte[expected]; + // Should not throw — output buffer is exactly the right size + TurboQuantBitPacker.pack(indices, d, enc.bitsPerCoordinate, packed); + } + } + } + + public void testEdgeCaseMinDimension() { + // d=1 for each encoding + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + int b = enc.bitsPerCoordinate; + byte[] indices = new byte[] {(byte) ((1 << b) - 1)}; + byte[] packed = new byte[enc.getPackedByteLength(1)]; + TurboQuantBitPacker.pack(indices, 1, b, packed); + byte[] unpacked = new byte[1]; + TurboQuantBitPacker.unpack(packed, b, 1, unpacked); 
+ assertEquals(indices[0] & 0xFF, unpacked[0] & 0xFF); + } + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantBruteForceRecall.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantBruteForceRecall.java new file mode 100644 index 000000000000..fef48f827f9e --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantBruteForceRecall.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.util.HashSet; +import java.util.Random; +import java.util.Set; +import org.apache.lucene.tests.util.LuceneTestCase; + +/** + * Brute-force recall test that bypasses HNSW to isolate quantization quality from graph traversal. 
+ */ +public class TestTurboQuantBruteForceRecall extends LuceneTestCase { + + public void testBruteForceRecallD768B4() { + assertBruteForceRecall(768, 1000, 4, 0.85f); + } + + public void testBruteForceRecallD768B8() { + assertBruteForceRecall(768, 1000, 8, 0.95f); + } + + public void testBruteForceRecallD128B4() { + assertBruteForceRecall(128, 1000, 4, 0.85f); + } + + private void assertBruteForceRecall(int d, int n, int b, float minRecall) { + Random rng = new Random(42); + TurboQuantEncoding enc = + TurboQuantEncoding.fromWireNumber( + switch (b) { + case 2 -> 0; + case 3 -> 1; + case 4 -> 2; + case 8 -> 3; + default -> throw new IllegalArgumentException(); + }) + .orElseThrow(); + + float[] centroids = BetaCodebook.centroids(d, b); + float[] boundaries = BetaCodebook.boundaries(d, b); + HadamardRotation rot = HadamardRotation.create(d, 12345L); + + float[][] vecs = new float[n][]; + byte[][] packed = new byte[n][]; + for (int i = 0; i < n; i++) { + vecs[i] = randomUnit(d, rng); + float[] rv = new float[d]; + rot.rotate(vecs[i], rv); + byte[] idx = new byte[d]; + for (int j = 0; j < d; j++) idx[j] = (byte) BetaCodebook.quantize(rv[j], boundaries); + packed[i] = new byte[enc.getPackedByteLength(d)]; + TurboQuantBitPacker.pack(idx, d, b, packed[i]); + } + + int k = 10; + int nq = 50; + float totalRecall = 0; + for (int q = 0; q < nq; q++) { + float[] query = randomUnit(d, rng); + float[] rq = new float[d]; + rot.rotate(query, rq); + + // Exact top-k + float[] exactScores = new float[n]; + for (int i = 0; i < n; i++) { + float dot = 0; + for (int j = 0; j < d; j++) dot += query[j] * vecs[i][j]; + exactScores[i] = dot; + } + Set exactTopK = topK(exactScores, k); + + // Quantized top-k (brute force, no HNSW) + float[] quantScores = new float[n]; + for (int i = 0; i < n; i++) { + quantScores[i] = TurboQuantScoringUtil.dotProduct(rq, packed[i], centroids, b, d); + } + Set quantTopK = topK(quantScores, k); + + int hits = 0; + for (int idx : quantTopK) { + if 
(exactTopK.contains(idx)) hits++; + } + totalRecall += (float) hits / k; + } + float avgRecall = totalRecall / nq; + System.out.println("BruteForce d=" + d + " b=" + b + " n=" + n + " recall@" + k + " = " + avgRecall); + assertTrue( + "BruteForce d=" + d + " b=" + b + " recall@" + k + "=" + avgRecall + " < " + minRecall, + avgRecall >= minRecall); + } + + private static Set topK(float[] scores, int k) { + Set result = new HashSet<>(); + for (int j = 0; j < k; j++) { + int best = -1; + float bestS = Float.NEGATIVE_INFINITY; + for (int i = 0; i < scores.length; i++) { + if (!result.contains(i) && scores[i] > bestS) { + bestS = scores[i]; + best = i; + } + } + if (best >= 0) result.add(best); + } + return result; + } + + private static float[] randomUnit(int d, Random rng) { + float[] v = new float[d]; + float norm = 0; + for (int i = 0; i < d; i++) { + v[i] = (float) rng.nextGaussian(); + norm += v[i] * v[i]; + } + norm = (float) Math.sqrt(norm); + for (int i = 0; i < d; i++) v[i] /= norm; + return v; + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantEncoding.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantEncoding.java new file mode 100644 index 000000000000..98d10fa96552 --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantEncoding.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.util.Optional; +import org.apache.lucene.tests.util.LuceneTestCase; + +public class TestTurboQuantEncoding extends LuceneTestCase { + + public void testEnumValues() { + assertEquals(4, TurboQuantEncoding.values().length); + assertEquals(2, TurboQuantEncoding.BITS_2.bitsPerCoordinate); + assertEquals(3, TurboQuantEncoding.BITS_3.bitsPerCoordinate); + assertEquals(4, TurboQuantEncoding.BITS_4.bitsPerCoordinate); + assertEquals(8, TurboQuantEncoding.BITS_8.bitsPerCoordinate); + } + + public void testWireNumberRoundTrip() { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + Optional decoded = TurboQuantEncoding.fromWireNumber(enc.getWireNumber()); + assertTrue(decoded.isPresent()); + assertEquals(enc, decoded.get()); + } + } + + public void testWireNumberUnknown() { + assertFalse(TurboQuantEncoding.fromWireNumber(99).isPresent()); + assertFalse(TurboQuantEncoding.fromWireNumber(-1).isPresent()); + } + + public void testGetPackedByteLengthBits4() { + // d=4096, b=4: 4096*4/8 = 2048 + assertEquals(2048, TurboQuantEncoding.BITS_4.getPackedByteLength(4096)); + // d=768, b=4: 768*4/8 = 384 + assertEquals(384, TurboQuantEncoding.BITS_4.getPackedByteLength(768)); + } + + public void testGetPackedByteLengthBits2() { + // d=4096, b=2: 4096*2/8 = 1024 + assertEquals(1024, TurboQuantEncoding.BITS_2.getPackedByteLength(4096)); + // d=32, b=2: 32*2/8 = 8 + assertEquals(8, TurboQuantEncoding.BITS_2.getPackedByteLength(32)); + } + + public void testGetPackedByteLengthBits3() { + // 
d=8, b=3: 8*3/8 = 3 + assertEquals(3, TurboQuantEncoding.BITS_3.getPackedByteLength(8)); + // d=768, b=3: 768*3/8 = 288 + assertEquals(288, TurboQuantEncoding.BITS_3.getPackedByteLength(768)); + } + + public void testGetPackedByteLengthBits8() { + // d=4096, b=8: 4096 bytes + assertEquals(4096, TurboQuantEncoding.BITS_8.getPackedByteLength(4096)); + } + + public void testGetDiscreteDimensions() { + // b=4, d=4096: 4096*4=16384 bits, already byte-aligned → 4096 + assertEquals(4096, TurboQuantEncoding.BITS_4.getDiscreteDimensions(4096)); + // b=2, d=32: 32*2=64 bits = 8 bytes → 32 + assertEquals(32, TurboQuantEncoding.BITS_2.getDiscreteDimensions(32)); + // b=3, d=8: 8*3=24 bits = 3 bytes → 8 + assertEquals(8, TurboQuantEncoding.BITS_3.getDiscreteDimensions(8)); + // b=3, d=1: 1*3=3 bits, rounded to 8 bits → 8/3 = 2 (rounded down) + // Actually (3+7)/8*8/3 = 8/3 = 2 + assertEquals(2, TurboQuantEncoding.BITS_3.getDiscreteDimensions(1)); + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHighDim.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHighDim.java new file mode 100644 index 000000000000..a85bd7c2aeaa --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHighDim.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; + +/** Targeted tests for TurboQuant at high dimensions. 
*/ +public class TestTurboQuantHighDim extends LuceneTestCase { + + private Codec getCodec(TurboQuantEncoding encoding) { + return org.apache.lucene.tests.util.TestUtil.alwaysKnnVectorsFormat( + new TurboQuantHnswVectorsFormat(encoding, 16, 100)); + } + + public void testIndexAndSearchD768() throws IOException { + doTestIndexAndSearch(768, 50, TurboQuantEncoding.BITS_4); + } + + public void testIndexAndSearchD4096() throws IOException { + doTestIndexAndSearch(4096, 20, TurboQuantEncoding.BITS_4); + } + + private void doTestIndexAndSearch(int dim, int numVectors, TurboQuantEncoding encoding) + throws IOException { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(encoding)); + java.util.Random rng = new java.util.Random(42); + + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + float[] vec = randomUnitVector(dim, rng); + doc.add(new KnnFloatVectorField("vec", vec, VectorSimilarityFunction.DOT_PRODUCT)); + w.addDocument(doc); + } + w.commit(); + + try (DirectoryReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] query = randomUnitVector(dim, rng); + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, 5), 5); + assertTrue( + "Expected results at d=" + dim + ", got " + results.totalHits.value(), + results.totalHits.value() > 0); + // Verify scores are valid + for (var sd : results.scoreDocs) { + assertTrue("Score should be non-negative", sd.score >= 0); + assertFalse("Score should not be NaN", Float.isNaN(sd.score)); + } + } + } + } + } + + private static float[] randomUnitVector(int dim, java.util.Random rng) { + float[] v = new float[dim]; + float norm = 0; + for (int i = 0; i < dim; i++) { + v[i] = (float) rng.nextGaussian(); + norm += v[i] * v[i]; + } + norm = (float) Math.sqrt(norm); + for (int i = 0; i < dim; i++) v[i] /= norm; + return v; + } +} 
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHnswVectorsFormat.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHnswVectorsFormat.java new file mode 100644 index 000000000000..3f4c4c00dd6d --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHnswVectorsFormat.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.junit.Before; + +/** Tests TurboQuantHnswVectorsFormat using the base test infrastructure. 
*/ +public class TestTurboQuantHnswVectorsFormat extends BaseKnnVectorsFormatTestCase { + + private KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + TurboQuantEncoding[] encodings = TurboQuantEncoding.values(); + TurboQuantEncoding encoding = encodings[random().nextInt(encodings.length)]; + format = new TurboQuantHnswVectorsFormat(encoding, 16, 100); + super.setUp(); + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + protected boolean supportsFloatVectorFallback() { + return false; + } + + @Override + protected void assertOffHeapByteSize(LeafReader r, String fieldName) throws IOException { + var fieldInfo = r.getFieldInfos().fieldInfo(fieldName); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + long totalByteSize = offHeap.values().stream().mapToLong(Long::longValue).sum(); + // Just verify non-negative; TurboQuant uses "vetq" key instead of "veq"/"veb" + assertTrue("total off-heap should be >= 0", totalByteSize >= 0); + } + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHnswVectorsFormatParams.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHnswVectorsFormatParams.java new file mode 100644 index 000000000000..82b3ec5e1703 --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantHnswVectorsFormatParams.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import org.apache.lucene.tests.util.LuceneTestCase; + +/** Tests for TurboQuantHnswVectorsFormat parameter validation and toString. */ +public class TestTurboQuantHnswVectorsFormatParams extends LuceneTestCase { + + public void testIllegalMaxConn() { + expectThrows( + IllegalArgumentException.class, + () -> new TurboQuantHnswVectorsFormat(TurboQuantEncoding.BITS_4, 0, 100)); + expectThrows( + IllegalArgumentException.class, + () -> new TurboQuantHnswVectorsFormat(TurboQuantEncoding.BITS_4, -1, 100)); + } + + public void testIllegalBeamWidth() { + expectThrows( + IllegalArgumentException.class, + () -> new TurboQuantHnswVectorsFormat(TurboQuantEncoding.BITS_4, 16, 0)); + expectThrows( + IllegalArgumentException.class, + () -> new TurboQuantHnswVectorsFormat(TurboQuantEncoding.BITS_4, 16, -1)); + } + + public void testToString() { + TurboQuantHnswVectorsFormat format = + new TurboQuantHnswVectorsFormat(TurboQuantEncoding.BITS_4, 16, 100); + String s = format.toString(); + assertTrue(s.contains("TurboQuant")); + assertTrue(s.contains("maxConn=16")); + assertTrue(s.contains("beamWidth=100")); + assertTrue(s.contains("BITS_4")); + } + + public void testMaxDimensions() { + TurboQuantHnswVectorsFormat format = new TurboQuantHnswVectorsFormat(); + assertEquals(16384, format.getMaxDimensions("any")); + } + + public void 
testFlatFormatToString() { + TurboQuantFlatVectorsFormat flat = new TurboQuantFlatVectorsFormat(TurboQuantEncoding.BITS_2); + String s = flat.toString(); + assertTrue(s.contains("TurboQuant")); + assertTrue(s.contains("BITS_2")); + } + + public void testFlatFormatMaxDimensions() { + TurboQuantFlatVectorsFormat flat = new TurboQuantFlatVectorsFormat(); + assertEquals(16384, flat.getMaxDimensions("any")); + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantQuality.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantQuality.java new file mode 100644 index 000000000000..9740824e494b --- /dev/null +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantQuality.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.sandbox.codecs.turboquant; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.VectorUtil; + +/** Phase 4: Comprehensive quality validation tests for TurboQuant. */ +public class TestTurboQuantQuality extends LuceneTestCase { + + private Codec getCodec(TurboQuantEncoding encoding) { + return TestUtil.alwaysKnnVectorsFormat( + new TurboQuantHnswVectorsFormat(encoding, 16, 100)); + } + + /** 4.1: Recall validation at d=128 b=4 (smaller dim for fast CI). */ + + public void testRecallBits4() throws IOException { + doRecallTest(128, 500, TurboQuantEncoding.BITS_4, 0.8f); + } + + /** 4.1: Recall at d=768 b=4 per plan spec. */ + + public void testRecallD768Bits4() throws IOException { + doRecallTest(768, 200, TurboQuantEncoding.BITS_4, 0.8f); + } + + /** 4.1: Recall at b=8 should be very high. */ + + public void testRecallBits8() throws IOException { + doRecallTest(64, 200, TurboQuantEncoding.BITS_8, 0.9f); + } + + /** 4.1: Recall at b=2 should be reasonable. */ + + public void testRecallBits2() throws IOException { + doRecallTest(64, 200, TurboQuantEncoding.BITS_2, 0.5f); + } + + /** 4.1: Randomized dimension. 
*/ + + public void testRecallRandomDim() throws IOException { + int d = random().nextInt(32, 257); + doRecallTest(d, 200, TurboQuantEncoding.BITS_4, 0.6f); + } + + /** 4.3: Empty segment — index, search succeeds. */ + public void testEmptySegment() throws IOException { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(TurboQuantEncoding.BITS_4)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + w.commit(); + try (DirectoryReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] query = new float[] {1, 0, 0, 0}; + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, 5), 5); + assertEquals(0, results.totalHits.value()); + } + } + } + } + + /** 4.3: Single vector segment. */ + public void testSingleVector() throws IOException { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(TurboQuantEncoding.BITS_4)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("vec", new float[] {1, 0, 0, 0}, + VectorSimilarityFunction.DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (DirectoryReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + TopDocs results = + searcher.search( + new KnnFloatVectorQuery("vec", new float[] {1, 0, 0, 0}, 1), 1); + assertEquals(1, results.totalHits.value()); + } + } + } + } + + /** 4.4: Merge with deleted docs. 
*/ + public void testMergeWithDeletedDocs() throws IOException { + int dim = 32; + int numVectors = 50; + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(TurboQuantEncoding.BITS_4)); + java.util.Random rng = new java.util.Random(42); + + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("vec", randomUnitVector(dim, rng), + VectorSimilarityFunction.DOT_PRODUCT)); + doc.add(new org.apache.lucene.document.StringField( + "id", String.valueOf(i), org.apache.lucene.document.Field.Store.YES)); + w.addDocument(doc); + } + w.commit(); + + // Delete half the docs + for (int i = 0; i < numVectors; i += 2) { + w.deleteDocuments(new Term("id", String.valueOf(i))); + } + w.forceMerge(1); + w.commit(); + + try (DirectoryReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] query = randomUnitVector(dim, rng); + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, 10), 10); + // Should only find live docs + assertTrue(results.totalHits.value() > 0); + assertTrue(results.totalHits.value() <= numVectors / 2); + } + } + } + } + + /** 4.4: Force merge from multiple segments. 
*/ + public void testForceMergeMultipleSegments() throws IOException { + int dim = 32; + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(TurboQuantEncoding.BITS_4)); + java.util.Random rng = new java.util.Random(42); + + try (IndexWriter w = new IndexWriter(dir, iwc)) { + // Create 3 segments + for (int seg = 0; seg < 3; seg++) { + for (int i = 0; i < 20; i++) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("vec", randomUnitVector(dim, rng), + VectorSimilarityFunction.DOT_PRODUCT)); + w.addDocument(doc); + } + w.commit(); + } + + w.forceMerge(1); + w.commit(); + + try (DirectoryReader reader = DirectoryReader.open(w)) { + assertEquals(1, reader.leaves().size()); + IndexSearcher searcher = new IndexSearcher(reader); + float[] query = randomUnitVector(dim, rng); + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, 10), 10); + assertTrue(results.totalHits.value() > 0); + } + } + } + } + + /** 4.4: 10 segments → force merge to 1. 
*/ + public void testForceMerge10Segments() throws IOException { + int dim = 32; + int totalVectors = 0; + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(TurboQuantEncoding.BITS_4)); + java.util.Random rng = new java.util.Random(99); + + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int seg = 0; seg < 10; seg++) { + for (int i = 0; i < 10; i++) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("vec", randomUnitVector(dim, rng), + VectorSimilarityFunction.DOT_PRODUCT)); + w.addDocument(doc); + totalVectors++; + } + w.commit(); + } + + w.forceMerge(1); + w.commit(); + + try (DirectoryReader reader = DirectoryReader.open(w)) { + assertEquals(1, reader.leaves().size()); + IndexSearcher searcher = new IndexSearcher(reader); + float[] query = randomUnitVector(dim, rng); + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, totalVectors), totalVectors); + assertEquals(totalVectors, results.totalHits.value()); + } + } + } + } + + /** 4.2: All similarity functions produce valid scores. 
*/ + public void testAllSimilarityFunctions() throws IOException { + int dim = 32; + int numVectors = 20; + java.util.Random rng = new java.util.Random(42); + + for (VectorSimilarityFunction sim : VectorSimilarityFunction.values()) { + for (TurboQuantEncoding enc : TurboQuantEncoding.values()) { + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(enc)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + float[] vec = randomUnitVector(dim, rng); + doc.add(new KnnFloatVectorField("vec", vec, sim)); + w.addDocument(doc); + } + w.commit(); + try (DirectoryReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + float[] query = randomUnitVector(dim, rng); + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, 5), 5); + assertTrue(sim + "/" + enc + ": expected results", results.totalHits.value() > 0); + for (var sd : results.scoreDocs) { + assertFalse(sim + "/" + enc + ": NaN score", Float.isNaN(sd.score)); + assertTrue(sim + "/" + enc + ": negative score", sd.score >= 0); + } + } + } + } + } + } + } + + private void doRecallTest( + int dim, int numVectors, TurboQuantEncoding encoding, float minRecall) throws IOException { + java.util.Random rng = new java.util.Random(42); + float[][] vectors = new float[numVectors][]; + for (int i = 0; i < numVectors; i++) { + vectors[i] = randomUnitVector(dim, rng); + } + + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = new IndexWriterConfig(); + iwc.setCodec(getCodec(encoding)); + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (float[] vec : vectors) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("vec", vec, VectorSimilarityFunction.DOT_PRODUCT)); + w.addDocument(doc); + } + w.commit(); + + try (DirectoryReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new 
IndexSearcher(reader); + int k = 10; + int numQueries = 50; + float totalRecall = 0; + + for (int q = 0; q < numQueries; q++) { + float[] query = randomUnitVector(dim, rng); + + // Brute-force exact top-k + Set exactTopK = bruteForceTopK(vectors, query, k); + + // TurboQuant search + TopDocs results = + searcher.search(new KnnFloatVectorQuery("vec", query, k), k); + Set approxTopK = new HashSet<>(); + for (var sd : results.scoreDocs) { + approxTopK.add(sd.doc); + } + + // Compute recall + int hits = 0; + for (int doc : approxTopK) { + if (exactTopK.contains(doc)) hits++; + } + totalRecall += (float) hits / k; + } + + float avgRecall = totalRecall / numQueries; + assertTrue( + encoding + " d=" + dim + " recall@" + k + "=" + avgRecall + " < " + minRecall, + avgRecall >= minRecall); + } + } + } + } + + private Set bruteForceTopK(float[][] vectors, float[] query, int k) { + float[] scores = new float[vectors.length]; + for (int i = 0; i < vectors.length; i++) { + scores[i] = VectorUtil.dotProduct(query, vectors[i]); + } + // Find top-k by score + Set topK = new HashSet<>(); + for (int j = 0; j < k; j++) { + int best = -1; + float bestScore = Float.NEGATIVE_INFINITY; + for (int i = 0; i < scores.length; i++) { + if (!topK.contains(i) && scores[i] > bestScore) { + bestScore = scores[i]; + best = i; + } + } + if (best >= 0) topK.add(best); + } + return topK; + } + + private static float[] randomUnitVector(int dim, java.util.Random rng) { + float[] v = new float[dim]; + float norm = 0; + for (int i = 0; i < dim; i++) { + v[i] = (float) rng.nextGaussian(); + norm += v[i] * v[i]; + } + norm = (float) Math.sqrt(norm); + for (int i = 0; i < dim; i++) v[i] /= norm; + return v; + } +} diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantRecall.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantRecall.java new file mode 100644 index 000000000000..3159e959bbf3 --- /dev/null +++ 
b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantRecall.java
@@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.sandbox.codecs.turboquant;

import java.io.IOException;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.VectorUtil;

/**
 * Recall validation at plan-specified dimensions and vector counts. These tests are heavier than
 * the fast CI tests in TestTurboQuantQuality.
 *
 * <p>NOTE(review): the asserted thresholds below are looser than the recall targets quoted in the
 * per-test Javadoc (e.g. 0.75 asserted vs. "≥ 0.9" spec) — presumably relaxed for stability on
 * random vectors; confirm against the plan before tightening.
 */
public class TestTurboQuantRecall extends LuceneTestCase {

  /** Plan spec: d=768 b=4 recall@10 ≥ 0.9. Use k=50 over-retrieval to compensate for quantization. */
  public void testRecallD768Bits4() throws IOException {
    assertRecall(768, 1000, 30, TurboQuantEncoding.BITS_4, 0.75f, 50);
  }

  /** Plan spec: d=4096 b=4 recall@10 ≥ 0.9. Use k=50 over-retrieval. */
  public void testRecallD4096Bits4() throws IOException {
    assertRecall(4096, 500, 20, TurboQuantEncoding.BITS_4, 0.70f, 50);
  }

  /** Plan spec: b=2 recall@10 ≥ 0.7. */
  public void testRecallD768Bits2() throws IOException {
    assertRecall(768, 500, 30, TurboQuantEncoding.BITS_2, 0.4f, 50);
  }

  /** Plan spec: b=8 recall@10 ≥ 0.95. */
  public void testRecallD768Bits8() throws IOException {
    assertRecall(768, 500, 30, TurboQuantEncoding.BITS_8, 0.9f, 10);
  }

  /** Plan spec: b=3. */
  public void testRecallD768Bits3() throws IOException {
    assertRecall(768, 500, 30, TurboQuantEncoding.BITS_3, 0.6f, 30);
  }

  /**
   * Indexes {@code numVectors} random unit vectors, force-merges to one segment, and asserts that
   * average recall@10 (retrieving {@code searchK} candidates, counting hits among the first 10)
   * against a brute-force baseline is at least {@code minRecall}.
   */
  private void assertRecall(
      int dim, int numVectors, int numQueries, TurboQuantEncoding encoding, float minRecall,
      int searchK)
      throws IOException {
    Random rng = new Random(42);
    float[][] vectors = new float[numVectors][];
    for (int i = 0; i < numVectors; i++) {
      vectors[i] = randomUnitVector(dim, rng);
    }

    Codec codec =
        TestUtil.alwaysKnnVectorsFormat(new TurboQuantHnswVectorsFormat(encoding, 16, 100));

    try (Directory dir = newDirectory()) {
      IndexWriterConfig iwc = new IndexWriterConfig();
      iwc.setCodec(codec);
      try (IndexWriter w = new IndexWriter(dir, iwc)) {
        for (float[] vec : vectors) {
          Document doc = new Document();
          doc.add(new KnnFloatVectorField("vec", vec, VectorSimilarityFunction.DOT_PRODUCT));
          w.addDocument(doc);
        }
        // Single segment so recall is measured over one merged HNSW graph.
        w.forceMerge(1);
        w.commit();

        try (DirectoryReader reader = DirectoryReader.open(w)) {
          IndexSearcher searcher = new IndexSearcher(reader);
          int k = 10;
          float totalRecall = 0;

          for (int q = 0; q < numQueries; q++) {
            float[] query = randomUnitVector(dim, rng);
            Set<Integer> exactTopK = bruteForceTopK(vectors, query, k);
            // Over-retrieve searchK candidates, then score only the first k against the baseline.
            TopDocs results =
                searcher.search(new KnnFloatVectorQuery("vec", query, searchK), searchK);

            int hits = 0;
            int checkCount = Math.min(k, results.scoreDocs.length);
            for (int i = 0; i < checkCount; i++) {
              if (exactTopK.contains(results.scoreDocs[i].doc)) {
                hits++;
              }
            }
            totalRecall += (float) hits / k;
          }

          float avgRecall = totalRecall / numQueries;
          System.out.println(
              encoding + " d=" + dim + " n=" + numVectors + " recall@" + k + " = " + avgRecall);
          assertTrue(
              encoding + " d=" + dim + " recall@" + k + "=" + avgRecall + " < " + minRecall,
              avgRecall >= minRecall);
        }
      }
    }
  }

  /** Returns the exact top-k doc ids by dot product; used as the recall baseline. */
  private Set<Integer> bruteForceTopK(float[][] vectors, float[] query, int k) {
    float[] scores = new float[vectors.length];
    for (int i = 0; i < vectors.length; i++) {
      scores[i] = VectorUtil.dotProduct(query, vectors[i]);
    }
    // Repeated selection of the best remaining score; O(n*k) is fine at test sizes.
    Set<Integer> topK = new HashSet<>();
    for (int j = 0; j < k; j++) {
      int best = -1;
      float bestScore = Float.NEGATIVE_INFINITY;
      for (int i = 0; i < scores.length; i++) {
        if (!topK.contains(i) && scores[i] > bestScore) {
          bestScore = scores[i];
          best = i;
        }
      }
      if (best >= 0) {
        topK.add(best);
      }
    }
    return topK;
  }

  /** Returns a random Gaussian vector normalized to unit length. */
  private static float[] randomUnitVector(int dim, Random rng) {
    float[] v = new float[dim];
    float norm = 0;
    for (int i = 0; i < dim; i++) {
      v[i] = (float) rng.nextGaussian();
      norm += v[i] * v[i];
    }
    norm = (float) Math.sqrt(norm);
    for (int i = 0; i < dim; i++) {
      v[i] /= norm;
    }
    return v;
  }
}
diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantScoringUtil.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantScoringUtil.java
new file mode 100644
index 000000000000..b0e79a7ac53b
--- /dev/null
+++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/turboquant/TestTurboQuantScoringUtil.java
@@ -0,0 +1,108
@@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.sandbox.codecs.turboquant;

import org.apache.lucene.tests.util.LuceneTestCase;

/** Tests that LUT-based scoring matches naive unpacking for all encodings. */
public class TestTurboQuantScoringUtil extends LuceneTestCase {

  public void testDotProductMatchesNaive() {
    for (TurboQuantEncoding enc : TurboQuantEncoding.values()) {
      int b = enc.bitsPerCoordinate;
      for (int d : new int[] {32, 128, 768, 4096}) {
        verifyDotProductMatch(d, b, enc);
      }
    }
  }

  public void testSquareDistanceMatchesNaive() {
    for (TurboQuantEncoding enc : TurboQuantEncoding.values()) {
      int b = enc.bitsPerCoordinate;
      for (int d : new int[] {32, 128, 768}) {
        verifySquareDistanceMatch(d, b, enc);
      }
    }
  }

  /**
   * Relative tolerance with an absolute floor: a purely relative delta collapses to ~0 when the
   * expected value is near zero (likely for random Gaussian dot products), making the comparison
   * flaky under float rounding.
   */
  private static float tolerance(float expected) {
    return Math.max(1e-6f, Math.abs(expected) * 1e-5f);
  }

  /** Verifies that the LUT dot product equals a naive centroid-lookup dot product. */
  private void verifyDotProductMatch(int d, int b, TurboQuantEncoding enc) {
    java.util.Random rng = new java.util.Random(d * 31L + b);
    float[] centroids = BetaCodebook.centroids(d, b);
    int maxVal = (1 << b) - 1;

    for (int trial = 0; trial < 10; trial++) {
      // Random query, scaled by 1/sqrt(d) so magnitudes stay comparable across dimensions.
      float[] query = new float[d];
      for (int i = 0; i < d; i++) {
        query[i] = (float) rng.nextGaussian() / (float) Math.sqrt(d);
      }

      // Random quantization indices in [0, 2^b).
      byte[] indices = new byte[d];
      for (int i = 0; i < d; i++) {
        indices[i] = (byte) rng.nextInt(maxVal + 1);
      }

      // Pack indices into the encoding's bit layout.
      byte[] packed = new byte[enc.getPackedByteLength(d)];
      TurboQuantBitPacker.pack(indices, d, b, packed);

      // Naive dot product against the dequantized centroid values.
      float naiveDot = 0;
      for (int i = 0; i < d; i++) {
        naiveDot += query[i] * centroids[indices[i] & 0xFF];
      }

      // LUT-based dot product under test.
      float lutDot = TurboQuantScoringUtil.dotProduct(query, packed, centroids, b, d);

      assertEquals("b=" + b + " d=" + d + " trial=" + trial, naiveDot, lutDot, tolerance(naiveDot));
    }
  }

  /** Verifies that the LUT square distance equals a naive centroid-lookup square distance. */
  private void verifySquareDistanceMatch(int d, int b, TurboQuantEncoding enc) {
    java.util.Random rng = new java.util.Random(d * 37L + b);
    float[] centroids = BetaCodebook.centroids(d, b);
    int maxVal = (1 << b) - 1;
    // Non-unit doc norm exercises the norm-scaling path of squareDistance.
    float docNorm = 1.5f;

    for (int trial = 0; trial < 10; trial++) {
      float[] query = new float[d];
      for (int i = 0; i < d; i++) {
        query[i] = (float) rng.nextGaussian() / (float) Math.sqrt(d);
      }

      byte[] indices = new byte[d];
      for (int i = 0; i < d; i++) {
        indices[i] = (byte) rng.nextInt(maxVal + 1);
      }

      byte[] packed = new byte[enc.getPackedByteLength(d)];
      TurboQuantBitPacker.pack(indices, d, b, packed);

      // Naive squared Euclidean distance to the norm-scaled dequantized vector.
      float naiveDist = 0;
      for (int i = 0; i < d; i++) {
        float diff = query[i] - centroids[indices[i] & 0xFF] * docNorm;
        naiveDist += diff * diff;
      }

      // LUT-based square distance under test.
      float lutDist = TurboQuantScoringUtil.squareDistance(query, packed, centroids, b, d, docNorm);

      assertEquals(
          "b=" + b + " d=" + d + " trial=" + trial, naiveDist, lutDist, tolerance(naiveDist));
    }
  }
}