Skip to content

Commit 250462c

Browse files
committed
Add NVQ example
1 parent 46bd115 commit 250462c

3 files changed

Lines changed: 180 additions & 0 deletions

File tree

jvector-examples/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@
143143
<artifactId>jackson-databind</artifactId>
144144
<version>2.17.1</version>
145145
</dependency>
146+
<dependency>
147+
<groupId>me.tongfei</groupId>
148+
<artifactId>progressbar</artifactId>
149+
<version>0.10.2</version>
150+
</dependency>
146151
<dependency>
147152
<groupId>junit</groupId>
148153
<artifactId>junit</artifactId>
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.example.tutorial;
18+
19+
import java.io.IOException;
20+
import java.io.UncheckedIOException;
21+
import java.nio.file.Files;
22+
import java.nio.file.Path;
23+
import java.util.List;
24+
import java.util.Map;
25+
import java.util.stream.Collectors;
26+
import java.util.stream.IntStream;
27+
28+
import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
29+
import io.github.jbellis.jvector.example.benchmarks.datasets.DataSets;
30+
import io.github.jbellis.jvector.example.util.AccuracyMetrics;
31+
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
32+
import io.github.jbellis.jvector.graph.GraphSearcher;
33+
import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
34+
import io.github.jbellis.jvector.graph.SearchResult;
35+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
36+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter;
37+
import io.github.jbellis.jvector.graph.disk.OrdinalMapper;
38+
import io.github.jbellis.jvector.graph.disk.feature.Feature;
39+
import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
40+
import io.github.jbellis.jvector.graph.disk.feature.NVQ;
41+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
42+
import io.github.jbellis.jvector.graph.similarity.DefaultSearchScoreProvider;
43+
import io.github.jbellis.jvector.quantization.MutablePQVectors;
44+
import io.github.jbellis.jvector.quantization.NVQuantization;
45+
import io.github.jbellis.jvector.quantization.ProductQuantization;
46+
import io.github.jbellis.jvector.util.Bits;
47+
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
48+
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
49+
import me.tongfei.progressbar.ProgressBar;
50+
51+
// Demonstrates using Non-uniform Vector Quantization (NVQ) for reducing the footprint of the disk graph.
52+
public class NvqExample {
53+
public static void main(String[] args) throws IOException {
54+
// Load a preconfigured dataset
55+
var ds = DataSets.loadDataSet("ada002-100k").orElseThrow(() ->
56+
new RuntimeException("dataset not found"))
57+
.getDataSet();
58+
var dim = ds.getDimension();
59+
var vsf = ds.getSimilarityFunction();
60+
var base = ds.getBaseRavv();
61+
62+
var numSubVectors = 2;
63+
64+
// Setup NVQ parameters.
65+
// The base vectors RAVV instance is used only for computing the global mean
66+
var nvq = NVQuantization.compute(base, numSubVectors);
67+
// Use this method instead if you don't have all the vectors up-front but can estimate the mean
68+
// var nvq = NVQuantization.create(scaledGlobalMean, numSubVectors);
69+
70+
// Graph construction parameters
71+
var M = 32;
72+
var ef = 100;
73+
var nOv = 1.2f;
74+
var alpha = 1.2f;
75+
var addHierarchy = true;
76+
77+
var pqMFactor = 8;
78+
var pqM = (ds.getDimension() + pqMFactor - 1) / pqMFactor;
79+
var pqClusterCount = 256;
80+
var pqGloballyCenter = false;
81+
82+
// PQ is used for graph building and first-stage scoring during query
83+
var pq = ProductQuantization.compute(base, pqM, pqClusterCount, pqGloballyCenter);
84+
85+
// Empty PQVectors instance, will be updated as we stream in vectors
86+
var pqv = new MutablePQVectors(pq);
87+
var bsp = BuildScoreProvider.pqBuildScoreProvider(vsf, pqv);
88+
89+
var graphPath = Path.of("./local/tmp.jvgraph");
90+
Files.deleteIfExists(graphPath);
91+
92+
System.out.println("Building graph in streaming mode...");
93+
try (
94+
// Create the graph builder using PQ-based scoring
95+
var builder = new GraphIndexBuilder(bsp, dim, M, ef, nOv, alpha, addHierarchy);
96+
// Create the on-disk writer configured with NVQ feature
97+
// This allows us to write both the graph structure and NVQ-compressed vectors
98+
var writer = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), graphPath)
99+
.with(new NVQ(nvq))
100+
.withMapper(new OrdinalMapper.IdentityMapper(base.size() - 1))
101+
.build();
102+
var pb = new ProgressBar("Build graph", base.size());
103+
) {
104+
105+
PhysicalCoreExecutor.pool().submit(() -> {
106+
IntStream.range(0, base.size())
107+
.parallel()
108+
.forEach(ordinal -> {
109+
var vec = base.getVector(ordinal);
110+
111+
// Encode the PQ vector first, then add the graph node
112+
pqv.encodeAndSet(ordinal, vec);
113+
builder.addGraphNode(ordinal, base.getVector(ordinal));
114+
115+
// Encode and write NVQ vectors for later re-ranking
116+
var nvqVec = nvq.encode(vec);
117+
Map<FeatureId, Feature.State> featureMap = Map.of(
118+
FeatureId.NVQ_VECTORS, new NVQ.State(nvqVec)
119+
);
120+
try {
121+
writer.writeFeaturesInline(ordinal, featureMap);
122+
} catch (IOException e) {
123+
throw new UncheckedIOException(e);
124+
}
125+
pb.step();
126+
});
127+
}).join();
128+
pb.close();
129+
130+
// cleanup
131+
System.out.println("Cleanup...");
132+
builder.cleanup();
133+
writer.write(Map.of());
134+
}
135+
136+
// Search parameters
137+
var topK = 10;
138+
var rerankK = 100;
139+
140+
List<SearchResult> results;
141+
142+
System.out.println("Loading and searching the graph...");
143+
try (
144+
var rs = ReaderSupplierFactory.open(graphPath);
145+
var graph = OnDiskGraphIndex.load(rs);
146+
var searchers = ExplicitThreadLocal.withInitial(() -> new GraphSearcher(graph));
147+
) {
148+
results = ds.getQueryVectors()
149+
.parallelStream()
150+
.map(query -> {
151+
var searcher = searchers.get();
152+
var scoringView = (ImmutableGraphIndex.ScoringView) searcher.getView();
153+
154+
// Two-phase search with NVQ:
155+
// 1. Use PQ for fast approximate search to get rerankK candidates
156+
var asf = pqv.precomputedScoreFunctionFor(query, vsf);
157+
// 2. Use NVQ-compressed vectors from disk for accurate reranking to topK
158+
// The reranker automatically uses the NVQ vectors stored in the graph
159+
var reranker = scoringView.rerankerFor(query, vsf);
160+
var ssp = new DefaultSearchScoreProvider(asf, reranker);
161+
return searcher.search(ssp, topK, rerankK, 0.0f, 0.0f, Bits.ALL);
162+
})
163+
.collect(Collectors.toList());
164+
} catch (Exception e) {
165+
throw new RuntimeException(e);
166+
}
167+
168+
// Evaluate search accuracy
169+
var recall = AccuracyMetrics.recallFromSearchResults(ds.getGroundTruth(), results, topK, topK);
170+
System.out.println("Recall: " + recall);
171+
}
172+
}

jvector-examples/src/main/java/io/github/jbellis/jvector/example/tutorial/TutorialRunner.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ public static void main(String[] args) throws IOException {
3838
case "ltm":
3939
LargerThanMemory.main(forwardArgs);
4040
break;
41+
case "nvq":
42+
NvqExample.main(forwardArgs);
43+
break;
4144
default:
4245
throw new IllegalArgumentException("Unknown example" + args[0]);
4346
}

0 commit comments

Comments
 (0)