Skip to content

Commit 33206e5

Browse files
committed
Add PQ rerank path
1 parent d7613c4 commit 33206e5

13 files changed

Lines changed: 659 additions & 28 deletions

File tree

README.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ await writeVectors({
6666
normalize: true, // L2-normalize on write; lets search skip sqrt for cosine
6767
binary: true, // also write 1-bit-per-dim sign column for binary+rerank search
6868
clusters: 128, // k-means clusters for phase-1 pruning (implies binary: true)
69+
pq: true, // optional product-quantized codes for approximate scoring before rerank
6970
vectors: myEmbedder(), // any sync or async iterable of { id, vector }
7071
})
7172
```
@@ -155,6 +156,7 @@ const results = await searchVectors({
155156
source: 'https://example.com/vectors.parquet', // URL, local file path, or an open AsyncBuffer
156157
query: queryVec, // Float32Array of length `dimension`
157158
topK: 10,
159+
algorithm: 'auto', // 'auto' | 'exact' | 'binary' | 'pq'
158160
rerankFactor: 10, // candidate pool = topK * rerankFactor (default 10). Set to 0 to force exact full scan.
159161
probe: 0.25, // fraction of clusters to scan in phase 1 (default 0.25). Set to 1 to scan all clusters; pass an integer > 1 for an absolute count.
160162
})
@@ -166,11 +168,11 @@ const results = await searchVectors({
166168

167169
### How it works
168170

169-
Three columns: `id` (STRING), `vector` (`FIXED_LEN_BYTE_ARRAY(4 × dim)`, raw float32 bytes, `UNCOMPRESSED`), and — when `binary: true``vector_bin` (`FIXED_LEN_BYTE_ARRAY(dim/8)`, 1 bit per dim).
171+
Core columns: `id` (STRING), `vector` (`FIXED_LEN_BYTE_ARRAY(4 × dim)`, raw float32 bytes, `UNCOMPRESSED`), and optional ANN columns: `vector_bin` (`FIXED_LEN_BYTE_ARRAY(dim/8)`, 1 bit per dim) when `binary: true`, and `vector_pq` (`FIXED_LEN_BYTE_ARRAY(pqSegments)`) when `pq: true`.
170172

171173
**Exact search path** (no binary column, or `rerankFactor: 0`): single pass over the float32 column via `parquetRead({ onChunk })`. Each row-group's decoded `Uint8Array[]` shares a backing buffer, so we view it as one aligned `Float32Array` and stride by `dim` — zero per-row allocations.
172174

173-
**Binary + cluster + rerank path** (default when `binary: true`):
175+
**Binary + cluster + rerank path** (default when `binary: true` and no PQ column is present):
174176

175177
1. **Build-time clustering** (when `clusters > 0`): k-means on the 1-bit codes using Hamming distance and bit-majority voting. Cluster ids are then renumbered via a greedy nearest-neighbor walk so that adjacent ids = similar centroids — this makes the top-N nearest clusters at query time tend to land in fewer contiguous row ranges. Rows are sorted by the new cluster id. Centroids and per-cluster row counts go into KV metadata.
176178
2. **Phase 1 — cluster pruning**: rank clusters by Hamming(query, centroid), pick the top `probe` fraction, and Hamming-scan only those clusters' row ranges. With 32 KB pages and `useOffsetIndex`, hyparquet fetches only the pages covering each cluster's rows.
@@ -179,6 +181,8 @@ Three columns: `id` (STRING), `vector` (`FIXED_LEN_BYTE_ARRAY(4 × dim)`, raw fl
179181

180182
A `cachedAsyncBuffer` deduplicates footer / offset-index byte ranges across all the parallel `parquetRead` calls.
181183

184+
**PQ + rerank path** (`algorithm: 'pq'`, or `auto` when a file has PQ but no binary column): scan compact `vector_pq` codes over the selected cluster ranges, approximate-score candidates with lookup tables built from the query and stored PQ codebooks, then fetch full float32 vectors only for the candidate pool and exact-rerank as above. When `clusters > 0`, PQ uses the same contiguous cluster row ranges as the binary path.
185+
182186
For pre-normalized vectors with `metric: 'cosine'`, the search normalizes the query once and scores via dot product to skip the per-candidate sqrt loop.
183187

184188
### File layout
@@ -188,6 +192,7 @@ For pre-normalized vectors with `metric: 'cosine'`, the search normalizes the qu
188192
| `id` | `STRING` (UTF8) | variable | always |
189193
| `vector` | `FIXED_LEN_BYTE_ARRAY(4 × dim)` | `4 × dim` | always |
190194
| `vector_bin` | `FIXED_LEN_BYTE_ARRAY(dim/8)` | `dim/8` | when `binary: true` |
195+
| `vector_pq` | `FIXED_LEN_BYTE_ARRAY(pqSegments)` | `pqSegments` | when `pq: true` |
191196

192197
Key-value metadata:
193198

@@ -198,10 +203,14 @@ Key-value metadata:
198203
| `hypvector.metric` | `cosine` \| `dot` \| `euclidean` |
199204
| `hypvector.normalized` | `true` if vectors were L2-normalized on write |
200205
| `hypvector.binary` | `true` if the `vector_bin` column is present |
206+
| `hypvector.pq` | `true` if the `vector_pq` column is present |
201207
| `hypvector.count` | number of vectors |
202208
| `hypvector.clusters` | number of k-means clusters (0 if not clustered) |
203209
| `hypvector.centroids` | base64-encoded centroid binary codes (`clusters × dim/8` bytes); present when `clusters > 0` |
204210
| `hypvector.clusterCounts` | base64-encoded `Uint32Array` of per-cluster row counts; present when `clusters > 0` |
211+
| `hypvector.pq.segments` | number of PQ sub-vectors / bytes per code; present when `pq: true` |
212+
| `hypvector.pq.centroids` | centroids per PQ sub-vector; present when `pq: true` |
213+
| `hypvector.pq.codebooks` | base64-encoded `Float32Array` codebooks (`pq.centroids × dim` floats); present when `pq: true` |
205214

206215
### CLI
207216

bin/inspect.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ export async function inspect({ path }) {
2222
console.log(`Metric: ${meta.metric}`)
2323
console.log(`Normalized: ${meta.normalized}`)
2424
console.log(`Binary column: ${meta.hasBinary}`)
25+
console.log(`PQ column: ${meta.hasPq}`)
26+
if (meta.hasPq) {
27+
console.log(`PQ segments: ${meta.pqSegments}`)
28+
console.log(`PQ centroids: ${meta.pqCentroids}`)
29+
}
2530
console.log(`Row groups: ${metadata.row_groups.length.toLocaleString()}`)
2631
console.log(`Raw float32 size: ${rawSize.toLocaleString()} bytes`)
2732
console.log(`Overhead: ${(ratio * 100).toFixed(1)}% of raw`)

scripts/ablation.js

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
* Variants:
77
* A) base vector + id only (search must use exact full scan)
88
* B) +binary adds vector_bin column (binary phase 1 + per-cand phase 2 reads)
9-
* C) +cluster B plus k-means clustering + cluster_id col + centroids/counts KV
10-
* D) +int8 C plus vector_i8 column (int8 cascade between phases 1 and 2)
9+
* C) +cluster B plus k-means clustering + centroids/counts KV
10+
* D) +PQ C plus vector_pq column + PQ codebooks
1111
*
1212
* Page size is held at 32 KB for B-D so we isolate the feature contribution
1313
* from the page-size knob.
@@ -41,6 +41,7 @@ const variants = [
4141
{ name: 'A_base', label: 'A) base (vec only)', opts: { binary: false } },
4242
{ name: 'B_binary', label: 'B) +binary', opts: { binary: true } },
4343
{ name: 'C_cluster', label: 'C) +cluster', opts: { binary: true, clusters: 128 } },
44+
{ name: 'D_pq', label: 'D) +cluster+PQ', opts: { binary: true, clusters: 128, pq: true }, search: { algorithm: 'pq' } },
4445
]
4546

4647
for (const v of variants) {
@@ -130,6 +131,7 @@ for (const v of variants) {
130131
const opts = {}
131132
// For base file, rerankFactor=0 forces exact path. For others, default rerank/probe.
132133
if (v.name === 'A_base') opts.rerankFactor = 0
134+
Object.assign(opts, v.search)
133135
const r = await bench(v.path, opts)
134136
let hits = 0, total = 0
135137
for (let q = 0; q < ref.tops.length; q += 1) {

src/constants.js

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ export const defaultVectorColumn = 'vector'
1313
// Default name of the binary (sign-bit) rerank column
1414
export const defaultBinaryColumn = 'vector_bin'
1515

16+
// Default name of the product-quantized vector code column
17+
export const defaultPqColumn = 'vector_pq'
18+
1619
// Default name of the id column
1720
export const defaultIdColumn = 'id'
1821

@@ -29,3 +32,10 @@ export const defaultClusterIterations = 6
2932
// Default fraction of clusters scanned in phase 1 at query time when the
3033
// file has cluster metadata. Lower = faster but lower recall.
3134
export const defaultClusterProbeFraction = 0.25
35+
36+
// Default product quantization settings. The initial PQ path stores one
37+
// code byte per segment, with values in [0, defaultPqCentroids).
38+
export const defaultPqSegments = 32
39+
export const defaultPqCentroids = 16
40+
export const defaultPqIterations = 8
41+
export const defaultPqSampleSize = 4096

src/index.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ export type {
1212
HypVectorMetadata,
1313
PrefetchBinaryOptions,
1414
ReadVectorsOptions,
15+
SearchAlgorithm,
1516
SearchResult,
1617
SearchVectorsOptions,
1718
VectorRecord,

src/pq.js

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/**
2+
* Product quantization helpers.
3+
*
4+
* Codebooks are stored segment-major. For segment s with bounds
5+
* [bounds[s], bounds[s + 1]), the codebook block starts at
6+
* `centroids * bounds[s]` and contains `centroids * segmentDim` float32s.
7+
*/
8+
9+
/**
10+
* @import { DistanceMetric, HypVectorMetadata } from './types.js'
11+
*/
12+
13+
/**
14+
* Build product-quantized codes for a set of vectors.
15+
*
16+
* @param {object} options
17+
* @param {Float32Array[]} options.vectors
18+
* @param {number} options.dimension
19+
* @param {number} options.segments
20+
* @param {number} options.centroids
21+
* @param {number} options.iterations
22+
* @param {number} options.sampleSize
23+
* @param {number} options.seed
24+
* @returns {{ codes: Uint8Array[], codebooks: Float32Array, segments: number, centroids: number }}
25+
*/
26+
export function buildPq({ vectors, dimension, segments, centroids, iterations, sampleSize, seed }) {
27+
if (!Number.isInteger(segments) || segments <= 0) {
28+
throw new Error(`pqSegments must be a positive integer, got ${segments}`)
29+
}
30+
if (!Number.isInteger(centroids) || centroids <= 1 || centroids > 256) {
31+
throw new Error(`pqCentroids must be an integer in [2, 256], got ${centroids}`)
32+
}
33+
const effectiveSegments = Math.min(segments, dimension)
34+
const bounds = pqSegmentBounds(dimension, effectiveSegments)
35+
const sample = sampleIndices(vectors.length, sampleSize)
36+
const codebooks = new Float32Array(centroids * dimension)
37+
38+
for (let s = 0; s < effectiveSegments; s += 1) {
39+
trainSegment({
40+
vectors,
41+
sample,
42+
start: bounds[s],
43+
end: bounds[s + 1],
44+
centroids,
45+
iterations,
46+
seed: seed + s * 1009,
47+
out: codebooks,
48+
})
49+
}
50+
51+
const codes = new Array(vectors.length)
52+
for (let i = 0; i < vectors.length; i += 1) {
53+
codes[i] = encodePqVector(vectors[i], codebooks, dimension, effectiveSegments, centroids)
54+
}
55+
56+
return { codes, codebooks, segments: effectiveSegments, centroids }
57+
}
58+
59+
/**
60+
* Return segment boundaries that cover [0, dimension).
61+
*
62+
* @param {number} dimension
63+
* @param {number} segments
64+
* @returns {Uint32Array}
65+
*/
66+
export function pqSegmentBounds(dimension, segments) {
67+
const bounds = new Uint32Array(segments + 1)
68+
for (let s = 0; s <= segments; s += 1) {
69+
bounds[s] = Math.floor(s * dimension / segments)
70+
}
71+
return bounds
72+
}
73+
74+
/**
75+
* Encode one vector against trained PQ codebooks.
76+
*
77+
* @param {Float32Array} vector
78+
* @param {Float32Array} codebooks
79+
* @param {number} dimension
80+
* @param {number} segments
81+
* @param {number} centroids
82+
* @returns {Uint8Array}
83+
*/
84+
export function encodePqVector(vector, codebooks, dimension, segments, centroids) {
85+
const bounds = pqSegmentBounds(dimension, segments)
86+
const code = new Uint8Array(segments)
87+
for (let s = 0; s < segments; s += 1) {
88+
const start = bounds[s]
89+
const end = bounds[s + 1]
90+
code[s] = nearestCentroid(vector, codebooks, start, end, centroids)
91+
}
92+
return code
93+
}
94+
95+
/**
96+
* Build per-segment lookup tables for approximate PQ scoring.
97+
*
98+
* For euclidean search the table stores squared L2 contributions and lower
99+
* values are better. For dot/cosine search it stores dot-product
100+
* contributions and higher values are better.
101+
*
102+
* @param {Float32Array} query
103+
* @param {HypVectorMetadata} meta
104+
* @param {DistanceMetric} metric
105+
* @returns {{ table: Float32Array, approxMetric: DistanceMetric }}
106+
*/
107+
export function buildPqTables(query, meta, metric) {
108+
if (!meta.hasPq || !meta.pqCodebooks || !meta.pqSegments || !meta.pqCentroids) {
109+
throw new Error('PQ metadata is missing')
110+
}
111+
const table = new Float32Array(meta.pqSegments * meta.pqCentroids)
112+
const bounds = pqSegmentBounds(meta.dimension, meta.pqSegments)
113+
for (let s = 0; s < meta.pqSegments; s += 1) {
114+
const start = bounds[s]
115+
const end = bounds[s + 1]
116+
const dim = end - start
117+
const block = meta.pqCentroids * start
118+
for (let c = 0; c < meta.pqCentroids; c += 1) {
119+
const centroid = block + c * dim
120+
let score = 0
121+
if (metric === 'euclidean') {
122+
for (let d = 0; d < dim; d += 1) {
123+
const delta = query[start + d] - meta.pqCodebooks[centroid + d]
124+
score += delta * delta
125+
}
126+
} else {
127+
for (let d = 0; d < dim; d += 1) {
128+
score += query[start + d] * meta.pqCodebooks[centroid + d]
129+
}
130+
}
131+
table[s * meta.pqCentroids + c] = score
132+
}
133+
}
134+
return { table, approxMetric: metric === 'euclidean' ? 'euclidean' : 'dot' }
135+
}
136+
137+
/**
138+
* Train one subspace codebook with k-means over a deterministic sample.
139+
*
140+
* @param {object} options
141+
* @param {Float32Array[]} options.vectors
142+
* @param {Int32Array} options.sample
143+
* @param {number} options.start
144+
* @param {number} options.end
145+
* @param {number} options.centroids
146+
* @param {number} options.iterations
147+
* @param {number} options.seed
148+
* @param {Float32Array} options.out
149+
*/
150+
function trainSegment({ vectors, sample, start, end, centroids, iterations, seed, out }) {
151+
const dim = end - start
152+
const block = centroids * start
153+
const sampleCount = sample.length
154+
if (sampleCount === 0) return
155+
156+
for (let c = 0; c < centroids; c += 1) {
157+
const src = vectors[sample[Math.floor(c * sampleCount / centroids)]]
158+
out.set(src.subarray(start, end), block + c * dim)
159+
}
160+
161+
for (let iter = 0; iter < iterations; iter += 1) {
162+
const counts = new Int32Array(centroids)
163+
const sums = new Float32Array(centroids * dim)
164+
165+
for (let i = 0; i < sampleCount; i += 1) {
166+
const vector = vectors[sample[i]]
167+
const best = nearestCentroid(vector, out, start, end, centroids)
168+
counts[best] += 1
169+
const sumOff = best * dim
170+
for (let d = 0; d < dim; d += 1) sums[sumOff + d] += vector[start + d]
171+
}
172+
173+
for (let c = 0; c < centroids; c += 1) {
174+
const dst = block + c * dim
175+
if (counts[c] === 0) {
176+
const src = vectors[sample[reseedIndex(seed, iter, c, sampleCount)]]
177+
out.set(src.subarray(start, end), dst)
178+
continue
179+
}
180+
const inv = 1 / counts[c]
181+
const sumOff = c * dim
182+
for (let d = 0; d < dim; d += 1) out[dst + d] = sums[sumOff + d] * inv
183+
}
184+
}
185+
}
186+
187+
/**
188+
* Find the nearest centroid for one segment under squared L2.
189+
*
190+
* @param {Float32Array} vector
191+
* @param {Float32Array} codebooks
192+
* @param {number} start
193+
* @param {number} end
194+
* @param {number} centroids
195+
* @returns {number}
196+
*/
197+
function nearestCentroid(vector, codebooks, start, end, centroids) {
198+
const dim = end - start
199+
const block = centroids * start
200+
let best = 0
201+
let bestDist = Infinity
202+
for (let c = 0; c < centroids; c += 1) {
203+
const off = block + c * dim
204+
let dist = 0
205+
for (let d = 0; d < dim; d += 1) {
206+
const delta = vector[start + d] - codebooks[off + d]
207+
dist += delta * delta
208+
if (dist >= bestDist) break
209+
}
210+
if (dist < bestDist) {
211+
bestDist = dist
212+
best = c
213+
}
214+
}
215+
return best
216+
}
217+
218+
/**
219+
* Deterministic evenly-spaced sample indices.
220+
*
221+
* @param {number} count
222+
* @param {number} sampleSize
223+
* @returns {Int32Array}
224+
*/
225+
function sampleIndices(count, sampleSize) {
226+
const n = Math.min(count, Math.max(1, sampleSize))
227+
const out = new Int32Array(n)
228+
for (let i = 0; i < n; i += 1) out[i] = Math.floor(i * count / n)
229+
return out
230+
}
231+
232+
/**
233+
* @param {number} seed
234+
* @param {number} iter
235+
* @param {number} centroid
236+
* @param {number} sampleCount
237+
* @returns {number}
238+
*/
239+
function reseedIndex(seed, iter, centroid, sampleCount) {
240+
let s = (seed ^ Math.imul(iter + 1, 2654435761) ^ Math.imul(centroid + 1, 2246822519)) >>> 0
241+
s = Math.imul(s, 1664525) + 1013904223 >>> 0
242+
return s % sampleCount
243+
}

0 commit comments

Comments
 (0)