Skip to content

Commit c7efb29

Browse files
authored
Merge branch 'main' into abhijat/feat/implement-chunk-tag-load
2 parents 03f82eb + 2a31378 commit c7efb29

17 files changed

Lines changed: 1049 additions & 418 deletions

src/core/search/hnsw_index.cc

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -132,35 +132,13 @@ struct HnswlibAdapter {
132132
HnswIndexMetadata GetMetadata() const {
133133
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
134134
HnswIndexMetadata metadata;
135-
metadata.max_elements = world_.max_elements_;
136-
metadata.cur_element_count = world_.cur_element_count.load();
137-
metadata.maxlevel = world_.maxlevel_;
138135
metadata.enterpoint_node = world_.enterpoint_node_;
139136
return metadata;
140137
}
141138

142-
void SetMetadata(const HnswIndexMetadata& metadata) {
143-
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
144-
absl::WriterMutexLock resize_lock(&resize_mutex_);
145-
146-
// SetMetadata is only called during deserialization before the index is used.
147-
// Assert the index is empty to ensure no concurrent operations are possible.
148-
DCHECK_EQ(world_.cur_element_count.load(), 0u)
149-
<< "SetMetadata should only be called on an empty index during deserialization";
150-
151-
// Runtime check for release builds to prevent silent corruption
152-
if (world_.cur_element_count.load() != 0) {
153-
LOG(ERROR) << "SetMetadata called on non-empty HNSW index with "
154-
<< world_.cur_element_count.load() << " elements, ignoring";
155-
return;
156-
}
157-
158-
// Pre-allocate capacity based on expected element count, but don't set cur_element_count.
159-
// cur_element_count will be set by RestoreFromNodes when the actual nodes are restored.
160-
if (world_.max_elements_ < metadata.cur_element_count) {
161-
world_.resizeIndex(metadata.cur_element_count);
162-
}
163-
// Note: Don't set cur_element_count here - RestoreFromNodes will set it after restoring nodes.
139+
int GetMaxLevel() const {
140+
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
141+
return world_.maxlevel_;
164142
}
165143

166144
size_t GetNodeCount() const {
@@ -280,31 +258,41 @@ struct HnswlibAdapter {
280258
}
281259

282260
public:
283-
// Restore HNSW graph structure from serialized nodes with metadata
284-
void RestoreFromNodes(const std::vector<HnswNodeData>& nodes, const HnswIndexMetadata& metadata) {
261+
// Restore HNSW graph structure from serialized nodes with metadata.
262+
// Returns false if the input is inconsistent (e.g. entry point not in node set) —
263+
// caller should fall back to rebuilding the index from the keyspace.
264+
bool RestoreFromNodes(const std::vector<HnswNodeData>& nodes, const HnswIndexMetadata& metadata) {
285265
MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);
286266
absl::WriterMutexLock resize_lock(&resize_mutex_);
287267

288268
if (nodes.empty()) {
289-
return;
269+
return true;
290270
}
291271

292272
// RestoreFromNodes is only called during deserialization on a freshly created index.
293273
// Assert the index is empty to prevent memory leaks from double-allocation of linkLists_.
294274
DCHECK_EQ(world_.cur_element_count.load(), 0u)
295275
<< "RestoreFromNodes should only be called on an empty index during deserialization";
296276

297-
// Ensure we have enough capacity.
298-
// Metadata may have been captured before the snapshot read-lock, so
299-
// cur_element_count can be smaller than actual node internal_ids when
300-
// concurrent writes happen. Compute the real requirement from nodes.
277+
// hnswlib pairs enterpoint_node_ with maxlevel_; node levels are immutable after
278+
// creation, so the entry point's level in the serialized set equals the live
279+
// maxlevel at metadata capture. max(node.level) would risk OOB reads when a
280+
// concurrent Add raised maxlevel between capture and node serialization.
301281
size_t max_internal_id = 0;
282+
int entrypoint_level = -1;
302283
for (const auto& node : nodes) {
303284
max_internal_id = std::max<size_t>(max_internal_id, node.internal_id);
285+
if (node.internal_id == metadata.enterpoint_node)
286+
entrypoint_level = node.level;
304287
}
305-
size_t required_capacity = std::max(metadata.cur_element_count, max_internal_id + 1);
306-
if (world_.max_elements_ < required_capacity) {
307-
world_.resizeIndex(required_capacity);
288+
if (entrypoint_level < 0) {
289+
LOG(ERROR) << "HNSW restore: entry point internal_id=" << metadata.enterpoint_node
290+
<< " not present in serialized node set (" << nodes.size()
291+
<< " nodes); skipping restore — index will be rebuilt from the keyspace";
292+
return false;
293+
}
294+
if (world_.max_elements_ < max_internal_id + 1) {
295+
world_.resizeIndex(max_internal_id + 1);
308296
}
309297

310298
// Restore each node - directly set up memory and fields
@@ -378,12 +366,13 @@ struct HnswlibAdapter {
378366
}
379367

380368
// Set the metadata for the graph
381-
world_.maxlevel_ = metadata.maxlevel;
369+
world_.maxlevel_ = entrypoint_level;
382370
world_.enterpoint_node_ = metadata.enterpoint_node;
383371

384372
VLOG(1) << "Restored HNSW index with " << restored_count
385-
<< " nodes, maxlevel=" << metadata.maxlevel
373+
<< " nodes, maxlevel=" << entrypoint_level
386374
<< ", enterpoint=" << metadata.enterpoint_node;
375+
return true;
387376
}
388377

389378
// Update vector data for an existing node (used after RestoreFromNodes).
@@ -502,8 +491,8 @@ HnswIndexMetadata HnswVectorIndex::GetMetadata() const {
502491
return adapter_->GetMetadata();
503492
}
504493

505-
void HnswVectorIndex::SetMetadata(const HnswIndexMetadata& metadata) {
506-
adapter_->SetMetadata(metadata);
494+
int HnswVectorIndex::GetMaxLevel() const {
495+
return adapter_->GetMaxLevel();
507496
}
508497

509498
size_t HnswVectorIndex::GetNodeCount() const {
@@ -514,9 +503,9 @@ std::vector<HnswNodeData> HnswVectorIndex::GetNodesRange(size_t start, size_t en
514503
return adapter_->GetNodesRange(start, end);
515504
}
516505

517-
void HnswVectorIndex::RestoreFromNodes(const std::vector<HnswNodeData>& nodes,
506+
bool HnswVectorIndex::RestoreFromNodes(const std::vector<HnswNodeData>& nodes,
518507
const HnswIndexMetadata& metadata) {
519-
adapter_->RestoreFromNodes(nodes, metadata);
508+
return adapter_->RestoreFromNodes(nodes, metadata);
520509
}
521510

522511
bool HnswVectorIndex::UpdateVectorData(GlobalDocId id, const DocumentAccessor& doc,

src/core/search/hnsw_index.h

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,12 @@
1111

1212
namespace dfly::search {
1313

14-
// Metadata structure for HNSW index serialization
15-
// Contains the key parameters needed to restore the index state
14+
// Wire format for HNSW index AUX. Only the entry point is persisted: capacity is
15+
// derived from max(internal_id)+1 in the node set and maxlevel from the entry-point
16+
// node's level (hnswlib pairs enterpoint_node_ with maxlevel_, and node levels are
17+
// immutable after creation).
1618
struct HnswIndexMetadata {
17-
size_t max_elements = 0; // Maximum number of elements the index can hold
18-
// Note: cur_element_count may be smaller than actual node count during concurrent writes,
19-
// so we compute the real requirement from nodes during restoration.
20-
// TODO: consider removing it from metadata and rely entirely on node data for restoration.
21-
size_t cur_element_count = 0; // Current number of elements in the index
22-
int maxlevel = -1; // Maximum level of the graph
23-
size_t enterpoint_node = 0; // Entry point node for the graph
19+
size_t enterpoint_node = 0;
2420
};
2521

2622
// Node data structure for HNSW serialization
@@ -75,8 +71,9 @@ class HnswVectorIndex {
7571
// Get metadata for serialization
7672
HnswIndexMetadata GetMetadata() const;
7773

78-
// Set metadata (used during restoration)
79-
void SetMetadata(const HnswIndexMetadata& metadata);
74+
// Current graph maxlevel_. Exposed for introspection and tests that need to
75+
// verify invariants preserved by RestoreFromNodes (entry point must sit at maxlevel).
76+
int GetMaxLevel() const;
8077

8178
// Get total number of nodes in the index
8279
size_t GetNodeCount() const;
@@ -85,10 +82,12 @@ class HnswVectorIndex {
8582
// Returns vector of node data for serialization
8683
std::vector<HnswNodeData> GetNodesRange(size_t start, size_t end) const;
8784

88-
// Restore graph structure from serialized nodes with metadata
89-
// This restores the HNSW graph links but NOT the vector data
90-
// Vector data must be populated separately via UpdateVectorData
91-
void RestoreFromNodes(const std::vector<HnswNodeData>& nodes, const HnswIndexMetadata& metadata);
85+
// Restore graph structure from serialized nodes with metadata.
86+
// Restores links only; vector data must be populated separately via UpdateVectorData.
87+
// Returns false if the metadata is inconsistent with the node set (e.g. the entry
88+
// point is missing from the serialized nodes) — caller should then leave the index
89+
// empty and let the higher-level rebuild path repopulate it from the keyspace.
90+
bool RestoreFromNodes(const std::vector<HnswNodeData>& nodes, const HnswIndexMetadata& metadata);
9291

9392
// Update vector data for an existing node (used after RestoreFromNodes)
9493
// This populates the vector data for a node that already has graph links

src/core/search/search_test.cc

Lines changed: 101 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,30 @@ TEST_F(KnnTest, AutoResize) {
10721072
EXPECT_EQ(indices.GetAllDocs().size(), 100);
10731073
}
10741074

1075+
// Seeds the given HNSW index with `n` deterministic random vectors of dim `dim` using
1076+
// the given RNG seed. Returns the owning MockedDocuments so the caller can pass them
1077+
// back to UpdateVectorData after a restore. Used by the serialization/restore tests.
1078+
inline vector<MockedDocument> SeedHnswIndex(HnswVectorIndex& index, size_t n, size_t dim,
1079+
uint32_t rng_seed) {
1080+
vector<MockedDocument> docs(n);
1081+
std::mt19937 rng(rng_seed);
1082+
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
1083+
for (size_t i = 0; i < n; i++) {
1084+
vector<float> coords(dim);
1085+
for (size_t d = 0; d < dim; d++)
1086+
coords[d] = dist(rng);
1087+
docs[i] = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}};
1088+
index.Add(i, docs[i], "vec");
1089+
}
1090+
return docs;
1091+
}
1092+
1093+
// Snapshots all nodes from the index under its read lock.
1094+
inline vector<HnswNodeData> SnapshotHnswNodes(const HnswVectorIndex& index) {
1095+
auto lock = index.GetReadLock();
1096+
return index.GetNodesRange(0, index.GetNodeCount());
1097+
}
1098+
10751099
// Parameterized HNSW serialization round-trip test.
10761100
// Parameters: {num_elements, dim, similarity}
10771101
struct HnswSerParam {
@@ -1108,27 +1132,12 @@ TEST_P(HnswSerializationTest, RoundTrip) {
11081132
params.hnsw_ef_construction = 200;
11091133

11101134
HnswVectorIndex original(params, /*copy_vector=*/true);
1135+
vector<MockedDocument> docs = SeedHnswIndex(original, num_elements, dim, /*rng_seed=*/42);
11111136

1112-
std::mt19937 rng(42);
1113-
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
1114-
vector<MockedDocument> docs(num_elements);
1115-
for (size_t i = 0; i < num_elements; i++) {
1116-
vector<float> coords(dim);
1117-
for (size_t d = 0; d < dim; d++)
1118-
coords[d] = dist(rng);
1119-
docs[i] = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}};
1120-
original.Add(i, docs[i], "vec");
1121-
}
1122-
1123-
// Serialize
11241137
auto metadata = original.GetMetadata();
1125-
ASSERT_EQ(metadata.cur_element_count, num_elements);
1138+
ASSERT_EQ(original.GetNodeCount(), num_elements);
11261139

1127-
std::vector<HnswNodeData> nodes;
1128-
{
1129-
auto lock = original.GetReadLock();
1130-
nodes = original.GetNodesRange(0, metadata.cur_element_count);
1131-
}
1140+
std::vector<HnswNodeData> nodes = SnapshotHnswNodes(original);
11321141
ASSERT_EQ(nodes.size(), num_elements);
11331142

11341143
// Verify node data integrity
@@ -1139,8 +1148,7 @@ TEST_P(HnswSerializationTest, RoundTrip) {
11391148

11401149
// Deserialize into a fresh index
11411150
HnswVectorIndex restored(params, /*copy_vector=*/true);
1142-
restored.SetMetadata(metadata);
1143-
restored.RestoreFromNodes(nodes, metadata);
1151+
ASSERT_TRUE(restored.RestoreFromNodes(nodes, metadata));
11441152

11451153
// Before UpdateVectorData, all nodes must be marked deleted.
11461154
// KNN should safely return empty results (no crash from nullptr dereference).
@@ -1153,17 +1161,16 @@ TEST_P(HnswSerializationTest, RoundTrip) {
11531161
for (size_t i = 0; i < num_elements; i++)
11541162
restored.UpdateVectorData(i, docs[i], "vec");
11551163

1156-
// Metadata must match
11571164
auto rm = restored.GetMetadata();
1158-
EXPECT_EQ(rm.cur_element_count, metadata.cur_element_count);
1159-
EXPECT_EQ(rm.maxlevel, metadata.maxlevel);
1165+
EXPECT_EQ(restored.GetNodeCount(), num_elements);
11601166
EXPECT_EQ(rm.enterpoint_node, metadata.enterpoint_node);
1167+
EXPECT_EQ(restored.GetMaxLevel(), original.GetMaxLevel());
11611168

11621169
// Graph links must be identical
11631170
std::vector<HnswNodeData> restored_nodes;
11641171
{
11651172
auto lock = restored.GetReadLock();
1166-
restored_nodes = restored.GetNodesRange(0, rm.cur_element_count);
1173+
restored_nodes = restored.GetNodesRange(0, restored.GetNodeCount());
11671174
}
11681175
ASSERT_EQ(restored_nodes.size(), nodes.size());
11691176
for (size_t i = 0; i < nodes.size(); i++) {
@@ -1209,6 +1216,76 @@ TEST_P(HnswSerializationTest, RoundTrip) {
12091216
}
12101217
}
12111218

1219+
// Regression for the save-side race where an Add raises maxlevel between metadata
1220+
// capture and node serialization (see RestoreFromNodes for the rationale). Simulated
1221+
// by forging metadata with a low-level entry point against a multi-level node set;
1222+
// expects maxlevel_ to clamp to the entry point's level rather than max(node.level).
1223+
TEST(HnswRestoreInvariant, MaxLevelClampedToEntryPointLevel) {
1224+
constexpr size_t kDim = 8;
1225+
constexpr size_t kN = 100;
1226+
1227+
InitTLSearchMR(PMR_NS::get_default_resource());
1228+
absl::Cleanup cleanup = [] { InitTLSearchMR(nullptr); };
1229+
1230+
SchemaField::VectorParams params;
1231+
params.use_hnsw = true;
1232+
params.dim = kDim;
1233+
params.sim = VectorSimilarity::L2;
1234+
params.capacity = kN;
1235+
params.hnsw_m = 16;
1236+
params.hnsw_ef_construction = 200;
1237+
1238+
HnswVectorIndex original(params, /*copy_vector=*/true);
1239+
SeedHnswIndex(original, kN, kDim, /*rng_seed=*/42);
1240+
std::vector<HnswNodeData> nodes = SnapshotHnswNodes(original);
1241+
1242+
int global_max_level = -1;
1243+
std::optional<uint32_t> low_level_internal_id;
1244+
for (const auto& n : nodes) {
1245+
global_max_level = std::max(global_max_level, n.level);
1246+
if (!low_level_internal_id && n.level == 0)
1247+
low_level_internal_id = n.internal_id;
1248+
}
1249+
ASSERT_GT(global_max_level, 0) << "test setup: need a multi-level graph";
1250+
ASSERT_TRUE(low_level_internal_id.has_value()) << "test setup: need a level-0 node";
1251+
1252+
HnswIndexMetadata forged_metadata{.enterpoint_node = *low_level_internal_id};
1253+
1254+
HnswVectorIndex restored(params, /*copy_vector=*/true);
1255+
ASSERT_TRUE(restored.RestoreFromNodes(nodes, forged_metadata));
1256+
1257+
EXPECT_EQ(restored.GetMaxLevel(), 0)
1258+
<< "maxlevel_ must equal entry-point level; got " << restored.GetMaxLevel()
1259+
<< " while node set max level=" << global_max_level;
1260+
}
1261+
1262+
// Malformed/mismatched metadata (entry point not in serialized node set) must
1263+
// fail restoration gracefully — returning false — instead of SIGABRT'ing via
1264+
// CHECK. Callers then rebuild the index from the keyspace.
1265+
TEST(HnswRestoreInvariant, MissingEntrypointFailsGracefully) {
1266+
constexpr size_t kDim = 4;
1267+
constexpr size_t kN = 10;
1268+
1269+
InitTLSearchMR(PMR_NS::get_default_resource());
1270+
absl::Cleanup cleanup = [] { InitTLSearchMR(nullptr); };
1271+
1272+
SchemaField::VectorParams params;
1273+
params.use_hnsw = true;
1274+
params.dim = kDim;
1275+
params.sim = VectorSimilarity::L2;
1276+
params.capacity = kN;
1277+
params.hnsw_m = 16;
1278+
params.hnsw_ef_construction = 200;
1279+
1280+
HnswVectorIndex original(params, /*copy_vector=*/true);
1281+
SeedHnswIndex(original, kN, kDim, /*rng_seed=*/7);
1282+
std::vector<HnswNodeData> nodes = SnapshotHnswNodes(original);
1283+
1284+
HnswIndexMetadata bad_metadata{.enterpoint_node = 999999}; // well past any real id
1285+
HnswVectorIndex restored(params, /*copy_vector=*/true);
1286+
EXPECT_FALSE(restored.RestoreFromNodes(nodes, bad_metadata));
1287+
}
1288+
12121289
// Regression: in borrowed mode (copy_vector=false), Remove marks the node deleted
12131290
// but hnswlib still traverses it and dereferences its data pointer. If the external
12141291
// data is freed (as happens after DEL), the pointer dangles. The fix in DoRemove

0 commit comments

Comments
 (0)