Skip to content

Commit e043a63

Browse files
jshook, sam-herman
authored and committed
support OnHeapGraphReconstruction
Get incremental graph generation working; add serialization for NeighborsCache; add ord-mapping logic from PQ vectors to RAVV; add ord mapping from RAVV to graph creation. Signed-off-by: Samuel Herman <sherman8915@gmail.com>
1 parent ddabdb8 commit e043a63

9 files changed

Lines changed: 633 additions & 10 deletions

File tree

jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) {
351351
}
352352
}
353353

354-
private static class NeighborIterator implements NodesIterator {
354+
public static class NeighborIterator implements NodesIterator {
355355
private final NodeArray neighbors;
356356
private int i;
357357

@@ -374,5 +374,9 @@ public boolean hasNext() {
374374
public int nextInt() {
375375
return neighbors.getNode(i++);
376376
}
377+
378+
/**
* Combines this iterator's backing {@code neighbors} array with {@code other} by delegating to
* {@code NodeArray.merge(neighbors, other)}.
* <p>
* The whole backing array is passed to the merge, regardless of how far this iterator has
* already advanced ({@code i} is not consulted).
*
* @param other the node array to merge with this iterator's neighbors
* @return whatever {@code NodeArray.merge(neighbors, other)} returns
*/
public NodeArray merge(NodeArray other) {
379+
return NodeArray.merge(neighbors, other);
380+
}
377381
}
378382
}

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import io.github.jbellis.jvector.disk.RandomAccessReader;
2121
import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel;
2222
import io.github.jbellis.jvector.graph.SearchResult.NodeScore;
23+
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
24+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
2325
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
2426
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
2527
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
@@ -30,6 +32,7 @@
3032
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
3133
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
3234
import io.github.jbellis.jvector.vector.types.VectorFloat;
35+
import org.agrona.collections.IntArrayList;
3336
import org.slf4j.Logger;
3437
import org.slf4j.LoggerFactory;
3538

@@ -338,6 +341,50 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
338341
this.rng = new Random(0);
339342
}
340343

344+
/**
345+
* Creates a builder on top of an existing {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}. This is useful when a graph has just been loaded from disk:
346+
* it is copied into a mutable {@link OnHeapGraphIndex} (see {@code OnHeapGraphIndex.convertToHeap}) so that this builder can keep mutating it without rebuilding the graph from scratch.
347+
*
348+
* @param buildScoreProvider the provider responsible for calculating build scores.
349+
* @param onDiskGraphIndex the on-disk representation of the graph index to be processed and converted.
350+
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
351+
* organized by levels and nodes.
352+
* @param beamWidth the width of the beam used during the graph building process.
353+
* @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit.
354+
* @param alpha the alpha value stored on this builder and passed, together with {@code neighborOverflow}, to {@code OnHeapGraphIndex.convertToHeap}.
355+
* @param addHierarchy whether to add hierarchical structures while building the graph.
356+
* @param refineFinalGraph whether to perform a refinement step on the final graph structure.
357+
* @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building.
358+
* @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building.
359+
*
360+
* @throws IOException if an I/O error occurs during the graph loading or conversion process.
361+
*/
362+
public GraphIndexBuilder(BuildScoreProvider buildScoreProvider, OnDiskGraphIndex onDiskGraphIndex, NeighborsScoreCache perLevelNeighborsScoreCache, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) throws IOException {
363+
this.scoreProvider = buildScoreProvider;
364+
this.neighborOverflow = neighborOverflow;
365+
// the vector dimension is taken from the on-disk index rather than passed in
this.dimension = onDiskGraphIndex.getDimension();
366+
this.alpha = alpha;
367+
this.addHierarchy = addHierarchy;
368+
this.refineFinalGraph = refineFinalGraph;
369+
this.beamWidth = beamWidth;
370+
this.simdExecutor = simdExecutor;
371+
this.parallelExecutor = parallelExecutor;
372+
373+
// build the mutable on-heap graph from the on-disk graph and the cached neighbor scores
this.graph = OnHeapGraphIndex.convertToHeap(onDiskGraphIndex, perLevelNeighborsScoreCache, buildScoreProvider, neighborOverflow, alpha);
374+
375+
// every searcher handed out here runs over the converted graph with pruning switched off
this.searchers = ExplicitThreadLocal.withInitial(() -> {
376+
var gs = new GraphSearcher(graph);
377+
gs.usePruning(false);
378+
return gs;
379+
});
380+
381+
// in scratch, we store candidates in reverse order: worse candidates are first
382+
this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
383+
this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1)));
384+
385+
// fixed seed (0)
this.rng = new Random(0);
386+
}
387+
341388
// used by Cassandra when it fine-tunes the PQ codebook
342389
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
343390
var newBuilder = new GraphIndexBuilder(newProvider,
@@ -750,7 +797,7 @@ public synchronized long removeDeletedNodes() {
750797
return memorySize;
751798
}
752799

753-
private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray concurrent) {
800+
private void updateNeighbors(int layer, int nodeId, NodeArray natural, NodeArray concurrent) {
754801
// if either natural or concurrent is empty, skip the merge
755802
NodeArray toMerge;
756803
if (concurrent.size() == 0) {
@@ -761,7 +808,7 @@ private void updateNeighbors(int level, int nodeId, NodeArray natural, NodeArray
761808
toMerge = NodeArray.merge(natural, concurrent);
762809
}
763810
// toMerge may be approximate-scored, but insertDiverse will compute exact scores for the diverse ones
764-
graph.addEdges(level, nodeId, toMerge, neighborOverflow);
811+
graph.addEdges(layer, nodeId, toMerge, neighborOverflow);
765812
}
766813

767814
private static NodeArray toScratchCandidates(NodeScore[] candidates, NodeArray scratch) {
@@ -876,6 +923,7 @@ private void loadV4(RandomAccessReader in) throws IOException {
876923
graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode));
877924
}
878925

926+
879927
@Deprecated
880928
private void loadV3(RandomAccessReader in, int size) throws IOException {
881929
if (graph.size() != 0) {

jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
package io.github.jbellis.jvector.graph;
2626

2727
import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors;
28+
import io.github.jbellis.jvector.graph.disk.NeighborsScoreCache;
29+
import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex;
2830
import io.github.jbellis.jvector.graph.diversity.DiversityProvider;
2931
import io.github.jbellis.jvector.util.Accountable;
3032
import io.github.jbellis.jvector.util.BitSet;
@@ -33,14 +35,22 @@
3335
import io.github.jbellis.jvector.util.RamUsageEstimator;
3436
import io.github.jbellis.jvector.util.SparseIntMap;
3537
import io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
38+
import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider;
39+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
40+
import io.github.jbellis.jvector.util.*;
41+
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
42+
import io.github.jbellis.jvector.vector.types.VectorFloat;
3643
import org.agrona.collections.IntArrayList;
3744

3845
import java.io.DataOutput;
3946
import java.io.IOException;
4047
import java.io.UncheckedIOException;
4148
import java.util.ArrayList;
4249
import java.util.List;
50+
import java.util.Map;
4351
import java.util.NoSuchElementException;
52+
import java.util.concurrent.ConcurrentHashMap;
53+
import java.util.concurrent.ConcurrentMap;
4454
import java.util.concurrent.atomic.AtomicInteger;
4555
import java.util.concurrent.atomic.AtomicIntegerArray;
4656
import java.util.concurrent.atomic.AtomicReference;
@@ -441,7 +451,7 @@ public boolean hasNext() {
441451
}
442452
}
443453

444-
private class FrozenView implements View {
454+
public class FrozenView implements View {
445455
@Override
446456
public NodesIterator getNeighborsIterator(int level, int node) {
447457
return OnHeapGraphIndex.this.getNeighborsIterator(level, node);
@@ -598,4 +608,68 @@ private void ensureCapacity(int node) {
598608
}
599609
}
600610
}
611+
612+
/**
613+
* Converts an OnDiskGraphIndex to an OnHeapGraphIndex. Per-level degrees, node IDs and the entry node are read from the
614+
* disk index; each node's neighbor list (with scores) is copied from {@code perLevelNeighborsScoreCache} instead of being read from disk.
615+
*
616+
* @param diskIndex the disk-based index to be converted
617+
* @param perLevelNeighborsScoreCache the cache containing pre-computed neighbor scores,
618+
* organized by levels and nodes.
619+
* @param bsp the build score provider used to construct the {@link VamanaDiversityProvider} of the new on-heap graph
620+
* @param overflowRatio usually 1.2f
621+
* @param alpha usually 1.2f
622+
* @return an OnHeapGraphIndex that is equivalent to the provided OnDiskGraphIndex but operates in heap memory
623+
* @throws IOException if an I/O error occurs during the conversion process
* @throws IllegalStateException if the cache has no entry for a level of the disk index, or if the number of cached
* nodes in a level differs from the number of nodes the disk index reports for that level
624+
*/
625+
public static OnHeapGraphIndex convertToHeap(OnDiskGraphIndex diskIndex,
626+
NeighborsScoreCache perLevelNeighborsScoreCache,
627+
BuildScoreProvider bsp,
628+
float overflowRatio,
629+
float alpha) throws IOException {
630+
631+
// Create a new OnHeapGraphIndex with the appropriate configuration
// (one max-degree entry per level, levels 0..maxLevel inclusive)
632+
List<Integer> maxDegrees = new ArrayList<>();
633+
for (int level = 0; level <= diskIndex.getMaxLevel(); level++) {
634+
maxDegrees.add(diskIndex.getDegree(level));
635+
}
636+
637+
OnHeapGraphIndex heapIndex = new OnHeapGraphIndex(
638+
maxDegrees,
639+
overflowRatio, // overflow ratio
640+
new VamanaDiversityProvider(bsp, alpha) // diversity provider built from bsp and alpha
641+
);
642+
643+
// Copy all nodes and their connections from disk to heap
644+
try (var view = diskIndex.getView()) {
645+
// Copy nodes level by level
646+
for (int level = 0; level <= diskIndex.getMaxLevel(); level++) {
647+
final NodesIterator nodesIterator = diskIndex.getNodes(level);
648+
final Map<Integer, NodeArray> levelNeighborsScoreCache = perLevelNeighborsScoreCache.getNeighborsScoresInLevel(level);
649+
// fail fast when the cache does not cover this level at all
if (levelNeighborsScoreCache == null) {
650+
throw new IllegalStateException("No neighbors score cache found for level " + level);
651+
}
652+
// only the node counts are compared here, not the node IDs themselves
if (nodesIterator.size() != levelNeighborsScoreCache.size()) {
653+
throw new IllegalStateException("Neighbors score cache size mismatch for level " + level +
654+
". Expected (currently in index): " + nodesIterator.size() + ", but got (in cache): " + levelNeighborsScoreCache.size());
655+
}
656+
657+
while (nodesIterator.hasNext()) {
658+
int nodeId = nodesIterator.next();
659+
660+
// Copy neighbors
// NOTE(review): get(nodeId) is dereferenced unchecked; a node ID missing from the cache would raise a NullPointerException here
661+
final NodeArray neighbors = levelNeighborsScoreCache.get(nodeId).copy();
662+
663+
// Add the node with its neighbors
664+
heapIndex.addNode(level, nodeId, neighbors);
665+
heapIndex.markComplete(new NodeAtLevel(level, nodeId));
666+
}
667+
}
668+
669+
// Set the entry point
670+
heapIndex.updateEntryNode(view.entryNode());
671+
}
672+
673+
return heapIndex;
674+
}
601675
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.graph.disk;
18+
19+
import io.github.jbellis.jvector.disk.IndexWriter;
20+
import io.github.jbellis.jvector.disk.RandomAccessReader;
21+
import io.github.jbellis.jvector.graph.*;
22+
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
23+
24+
import java.io.IOException;
25+
import java.util.HashMap;
26+
import java.util.Map;
27+
28+
/**
29+
* Cache containing pre-computed neighbor scores, organized by levels and nodes.
30+
* <p>
31+
* This cache bridges the gap between {@link OnDiskGraphIndex} and {@link OnHeapGraphIndex}:
32+
* <ul>
33+
* <li>{@link OnDiskGraphIndex} stores only neighbor IDs (not scores) for space efficiency</li>
34+
* <li>{@link OnHeapGraphIndex} requires neighbor scores for pruning operations</li>
35+
* </ul>
36+
* <p>
37+
* When converting from disk to heap representation, this cache avoids expensive score
38+
* recomputation by providing pre-calculated neighbor scores for all graph levels.
39+
* <p>
40+
* This is particularly useful when merging new nodes into an existing graph.
41+
*
42+
* @see OnHeapGraphIndex#convertToHeap(OnDiskGraphIndex, NeighborsScoreCache, BuildScoreProvider, float, float)
43+
* @see GraphIndexBuilder#buildAndMergeNewNodes(OnDiskGraphIndex, NeighborsScoreCache, RandomAccessVectorValues, BuildScoreProvider, int, int[], int, float, float, boolean)
44+
*/
45+
public class NeighborsScoreCache {
46+
// level -> (node ID -> that node's neighbors with their scores)
private final Map<Integer, Map<Integer, NodeArray>> perLevelNeighborsScoreCache;
47+
48+
/**
* Builds the cache from an in-memory graph by walking every level from 0 to {@code graphIndex.getMaxLevel()}
* and storing one {@link NodeArray} per node returned by {@code graphIndex.getNodes(level)}.
*
* @param graphIndex the on-heap graph whose neighbor lists are captured
* @throws IOException declared by this constructor; the frozen view is opened and closed in a try-with-resources
* @throws ClassCastException if the view's neighbor iterator is not a {@link ConcurrentNeighborMap.NeighborIterator}
*/
public NeighborsScoreCache(OnHeapGraphIndex graphIndex) throws IOException {
49+
try (OnHeapGraphIndex.FrozenView view = graphIndex.getFrozenView()) {
50+
final Map<Integer, Map<Integer, NodeArray>> perLevelNeighborsScoreCache = new HashMap<>(graphIndex.getMaxLevel() + 1);
51+
for (int level = 0; level <= graphIndex.getMaxLevel(); level++) {
52+
final Map<Integer, NodeArray> levelNeighborsScores = new HashMap<>(graphIndex.size(level) + 1);
53+
final NodesIterator nodesIterator = graphIndex.getNodes(level);
54+
while (nodesIterator.hasNext()) {
55+
final int nodeId = nodesIterator.nextInt();
56+
57+
// the cast is needed to reach merge(), which is not part of NodesIterator
ConcurrentNeighborMap.NeighborIterator neighborIterator = (ConcurrentNeighborMap.NeighborIterator) view.getNeighborsIterator(level, nodeId);
58+
// merging with a freshly constructed NodeArray; presumably this yields a standalone copy of the neighbors - confirm against NodeArray.merge
final NodeArray neighbours = neighborIterator.merge(new NodeArray(neighborIterator.size()));
59+
levelNeighborsScores.put(nodeId, neighbours);
60+
}
61+
62+
perLevelNeighborsScoreCache.put(level, levelNeighborsScores);
63+
}
64+
65+
this.perLevelNeighborsScoreCache = perLevelNeighborsScoreCache;
66+
}
67+
}
68+
69+
/**
* Reads a cache in the layout produced by {@link #write(IndexWriter)}: an int level count, then for each
* level an int level number and an int node count, then for each node an int node ID and an int neighbor
* count, then for each neighbor an int node ID followed by a float score.
*
* @param in the reader positioned at the start of a serialized cache
* @throws IOException if reading from {@code in} fails
*/
public NeighborsScoreCache(RandomAccessReader in) throws IOException {
70+
final int numberOfLevels = in.readInt();
71+
perLevelNeighborsScoreCache = new HashMap<>(numberOfLevels);
72+
for (int i = 0; i < numberOfLevels; i++) {
73+
// the level number is stored explicitly, so levels may appear in any order
final int level = in.readInt();
74+
final int numberOfNodesInLevel = in.readInt();
75+
final Map<Integer, NodeArray> levelNeighborsScores = new HashMap<>(numberOfNodesInLevel);
76+
for (int j = 0; j < numberOfNodesInLevel; j++) {
77+
final int nodeId = in.readInt();
78+
final int numberOfNeighbors = in.readInt();
79+
final NodeArray nodeArray = new NodeArray(numberOfNeighbors);
80+
for (int k = 0; k < numberOfNeighbors; k++) {
81+
final int neighborNodeId = in.readInt();
82+
final float neighborScore = in.readFloat();
83+
nodeArray.insertSorted(neighborNodeId, neighborScore);
84+
}
85+
levelNeighborsScores.put(nodeId, nodeArray);
86+
}
87+
perLevelNeighborsScoreCache.put(level, levelNeighborsScores);
88+
}
89+
}
90+
91+
/**
* Serializes the cache: an int level count, then for each level an int level number and an int node count,
* then for each node an int node ID and an int neighbor count, then for each neighbor an int node ID
* followed by a float score. Levels and nodes are written in the iteration order of the backing maps.
*
* @param out the writer to serialize into
* @throws IOException if writing to {@code out} fails
*/
public void write(IndexWriter out) throws IOException {
92+
out.writeInt(perLevelNeighborsScoreCache.size()); // write the number of levels
93+
for (Map.Entry<Integer ,Map<Integer, NodeArray>> levelNeighborsScores : perLevelNeighborsScoreCache.entrySet()) {
94+
final int level = levelNeighborsScores.getKey();
95+
out.writeInt(level);
96+
out.writeInt(levelNeighborsScores.getValue().size()); // write the number of nodes in the level
97+
// Write the neighborhoods for each node in the level
98+
for (Map.Entry<Integer, NodeArray> nodeArrayEntry : levelNeighborsScores.getValue().entrySet()) {
99+
final int nodeId = nodeArrayEntry.getKey();
100+
out.writeInt(nodeId);
101+
final NodeArray nodeArray = nodeArrayEntry.getValue();
102+
out.writeInt(nodeArray.size()); // write the number of neighbors for the node
103+
// Write the nodeArray(neighbors)
104+
for (int i = 0; i < nodeArray.size(); i++) {
105+
out.writeInt(nodeArray.getNode(i));
106+
out.writeFloat(nodeArray.getScore(i));
107+
}
108+
}
109+
}
110+
}
111+
112+
/**
* Returns the node-ID-to-neighbors map stored for {@code level}.
*
* @param level the graph level to look up
* @return the map held internally for that level (not a copy), or {@code null} when the cache has no entry for the level
*/
public Map<Integer, NodeArray> getNeighborsScoresInLevel(int level) {
113+
return perLevelNeighborsScoreCache.get(level);
114+
}
115+
116+
117+
}

jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import io.github.jbellis.jvector.vector.types.VectorFloat;
2626
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
2727

28+
import java.util.stream.IntStream;
29+
2830
/**
2931
* Encapsulates comparing node distances for GraphIndexBuilder.
3032
*/
@@ -83,8 +85,17 @@ public interface BuildScoreProvider {
8385

8486
/**
8587
* Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction.
88+
*
89+
* Convenience overload for the special case where the mapping between graph node IDs and ravv ordinals is the identity function:
* it builds the array {@code [0, 1, ..., ravv.size() - 1]} and delegates to
* {@code randomAccessScoreProvider(ravv, graphToRavvOrdMap, similarityFunction)}.
8690
*/
8791
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
92+
return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction);
93+
}
94+
95+
/**
96+
* Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction, fetching the vector of graph node {@code n} from ravv ordinal {@code graphToRavvOrdMap[n]}.
97+
*/
98+
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
8899
// We need two sources of vectors in order to perform diversity check comparisons without
89100
// colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
90101
var vectors = ravv.threadLocalSupplier();
@@ -113,22 +124,22 @@ public VectorFloat<?> approximateCentroid() {
113124
@Override
114125
public SearchScoreProvider searchProviderFor(VectorFloat<?> vector) {
115126
var vc = vectorsCopy.get();
116-
return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc);
127+
return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc);
117128
}
118129

119130
@Override
120131
public SearchScoreProvider searchProviderFor(int node1) {
121132
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
122-
var v = randomAccessVectorValues.getVector(node1);
133+
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
123134
return searchProviderFor(v);
124135
}
125136

126137
@Override
127138
public SearchScoreProvider diversityProviderFor(int node1) {
128139
RandomAccessVectorValues randomAccessVectorValues = vectors.get();
129-
var v = randomAccessVectorValues.getVector(node1);
140+
var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
130141
var vc = vectorsCopy.get();
131-
return DefaultSearchScoreProvider.exact(v, similarityFunction, vc);
142+
return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc);
132143
}
133144
};
134145
}

0 commit comments

Comments
 (0)