Skip to content

Commit 5f36f3c

Browse files
committed
Fix omega build and optimize regressions
1 parent: fceebba · commit: 5f36f3c

7 files changed

Lines changed: 72 additions & 75 deletions

File tree

src/core/algorithm/hnsw_rabitq/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ cc_library(
1515
NAME core_knn_hnsw_rabitq
1616
STATIC SHARED STRICT ALWAYS_LINK
1717
SRCS *.cc
18-
LIBS core_framework core_utility rabitqlib sparsehash
18+
LIBS core_framework core_utility core_knn_cluster rabitqlib sparsehash
1919
INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm
2020
VERSION "${PROXIMA_ZVEC_VERSION}"
2121
)

src/core/mixed_reducer/mixed_streamer_reducer.cc

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include "mixed_streamer_reducer.h"
15+
#include <algorithm>
1516
#include <ailego/pattern/defer.h>
1617
#include <utility/sparse_utility.h>
1718
#include <zvec/ailego/utility/file_helper.h>
@@ -141,15 +142,16 @@ int MixedStreamerReducer::reduce(const IndexFilter &filter) {
141142
ailego::ElapsedTime timer;
142143

143144

144-
std::vector<int> add_results(num_of_add_threads_, -1);
145+
const size_t add_thread_count = enable_pk_rewrite_ ? 1 : num_of_add_threads_;
146+
std::vector<int> add_results(add_thread_count, -1);
145147
auto add_group = thread_pool_->make_group();
146148

147149
std::vector<int> read_results(streamers_.size(), -1);
148150
// TODO: use id instead of key
149151
uint32_t id_offset = 0, next_id = 0;
150152

151153
if (is_sparse_) {
152-
for (size_t i = 0; i < num_of_add_threads_; i++) {
154+
for (size_t i = 0; i < add_thread_count; i++) {
153155
add_group->submit(ailego::Closure::New(
154156
this, &MixedStreamerReducer::add_sparse_vec, &add_results[i]));
155157
}
@@ -162,7 +164,7 @@ int MixedStreamerReducer::reduce(const IndexFilter &filter) {
162164

163165
sparse_mt_list_.done();
164166
} else {
165-
for (size_t i = 0; i < num_of_add_threads_; i++) {
167+
for (size_t i = 0; i < add_thread_count; i++) {
166168
add_group->submit(ailego::Closure::New(
167169
this, &MixedStreamerReducer::add_vec, &add_results[i]));
168170
// add_vec(&add_results[i]);
@@ -304,6 +306,7 @@ int MixedStreamerReducer::read_vec(size_t source_streamer_index,
304306

305307
IndexProvider::Pointer provider = streamer->create_provider();
306308
IndexProvider::Iterator::Pointer iterator = provider->create_iterator();
309+
std::vector<std::pair<uint32_t, std::vector<uint8_t>>> pending_items;
307310

308311
while (iterator->is_valid()) {
309312
if (stop_flag_ != nullptr && stop_flag_->load(std::memory_order_relaxed)) {
@@ -332,13 +335,19 @@ int MixedStreamerReducer::read_vec(size_t source_streamer_index,
332335
memcpy(bytes.data(), iterator->data(), bytes.size());
333336
}
334337

335-
// TODO: use id instead of key
336-
if (!mt_list_.produce(VectorItem((*next_id)++, std::move(bytes)))) {
337-
LOG_ERROR("Produce vector to queue failed. key[%lu]",
338-
(size_t)iterator->key());
338+
pending_items.emplace_back(iterator->key() + id_offset, std::move(bytes));
339+
iterator->next();
340+
}
341+
342+
std::sort(pending_items.begin(), pending_items.end(),
343+
[](const auto &lhs, const auto &rhs) {
344+
return lhs.first < rhs.first;
345+
});
346+
for (auto &item : pending_items) {
347+
if (!mt_list_.produce(VectorItem((*next_id)++, std::move(item.second)))) {
348+
LOG_ERROR("Produce vector to queue failed. key[%u]", item.first);
339349
return IndexError_Runtime;
340350
}
341-
iterator->next();
342351
}
343352
return 0;
344353
}
@@ -508,6 +517,7 @@ int MixedStreamerReducer::read_sparse_vec(size_t source_streamer_index,
508517
streamer->create_sparse_provider();
509518
IndexStreamer::SparseProvider::Iterator::Pointer iterator =
510519
provider->create_iterator();
520+
std::vector<SparseVectorItem> pending_items;
511521

512522
while (iterator->is_valid()) {
513523
if (stop_flag_ != nullptr && stop_flag_->load(std::memory_order_relaxed)) {
@@ -547,15 +557,24 @@ int MixedStreamerReducer::read_sparse_vec(size_t source_streamer_index,
547557
memcpy(sparse_indices.data(), iterator->sparse_indices(),
548558
sparse_indices.size() * sizeof(uint32_t));
549559

550-
// TODO: use id instead of key
551-
if (!sparse_mt_list_.produce(SparseVectorItem((*next_id)++,
552-
std::move(sparse_indices),
553-
std::move(sparse_values)))) {
560+
pending_items.emplace_back(iterator->key() + id_offset,
561+
std::move(sparse_indices),
562+
std::move(sparse_values));
563+
iterator->next();
564+
}
565+
566+
std::sort(pending_items.begin(), pending_items.end(),
567+
[](const SparseVectorItem &lhs, const SparseVectorItem &rhs) {
568+
return lhs.pkey_ < rhs.pkey_;
569+
});
570+
for (auto &item : pending_items) {
571+
if (!sparse_mt_list_.produce(SparseVectorItem(
572+
(*next_id)++, std::move(item.sparse_indices_),
573+
std::move(item.sparse_values_)))) {
554574
LOG_ERROR("Produce vector to queue failed. key[%lu]",
555-
(size_t)iterator->key());
575+
static_cast<size_t>(item.pkey_));
556576
return IndexError_Runtime;
557577
}
558-
iterator->next();
559578
}
560579
return 0;
561580
}

src/db/collection.cc

Lines changed: 4 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -834,60 +834,10 @@ Status CollectionImpl::Optimize(const OptimizeOptions &options) {
834834
return Status::OK();
835835
}
836836

837-
// Step 1: Build vector indexes if not ready
838-
// This ensures indexes are built even for single segments that won't be
839-
// compacted
840-
std::vector<SegmentTask::Ptr> index_build_tasks;
841-
for (auto &segment : persist_segments) {
842-
if (!segment->all_vector_index_ready()) {
843-
// Build all vector indexes for this segment
844-
index_build_tasks.push_back(SegmentTask::CreateCreateVectorIndexTask(
845-
CreateVectorIndexTask{segment, "", nullptr, options.concurrency_}));
846-
}
847-
}
848-
849-
if (!index_build_tasks.empty()) {
850-
LOG_INFO("Building vector indexes for %zu segments",
851-
index_build_tasks.size());
852-
auto s = execute_tasks(index_build_tasks);
853-
CHECK_RETURN_STATUS(s);
854-
855-
// Update segment metadata
856-
std::lock_guard write_lock(write_mtx_);
857-
Version new_version = version_manager_->get_current_version();
858-
859-
for (auto &task : index_build_tasks) {
860-
auto task_info = task->task_info();
861-
if (std::holds_alternative<CreateVectorIndexTask>(task_info)) {
862-
auto create_index_task = std::get<CreateVectorIndexTask>(task_info);
863-
s = new_version.update_persisted_segment_meta(
864-
create_index_task.output_segment_meta_);
865-
CHECK_RETURN_STATUS(s);
866-
}
867-
}
868-
869-
s = version_manager_->apply(new_version);
870-
CHECK_RETURN_STATUS(s);
871-
s = version_manager_->flush();
872-
CHECK_RETURN_STATUS(s);
873-
874-
// Reload indexes in segments
875-
for (auto &task : index_build_tasks) {
876-
auto task_info = task->task_info();
877-
if (std::holds_alternative<CreateVectorIndexTask>(task_info)) {
878-
auto create_index_task = std::get<CreateVectorIndexTask>(task_info);
879-
s = create_index_task.input_segment_->reload_vector_index(
880-
*schema_, create_index_task.output_segment_meta_,
881-
create_index_task.output_vector_indexers_,
882-
create_index_task.output_quant_vector_indexers_);
883-
CHECK_RETURN_STATUS(s);
884-
}
885-
}
886-
887-
LOG_INFO("Completed building vector indexes");
888-
}
889-
890-
// Step 2: build segment compact task
837+
// Build optimize tasks once so compacted segments are merged directly from
838+
// their current per-segment sources. Pre-building filtered vector indexes for
839+
// every persisted segment would shift source row ids before compaction and
840+
// break alignment with scalar row-id filters.
891841
auto delete_store_clone = delete_store_->clone();
892842
auto tasks =
893843
build_compact_task(schema_, persist_segments, options.concurrency_,

src/db/index/CMakeLists.txt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,19 @@ include(${PROJECT_ROOT_DIR}/cmake/option.cmake)
33

44
cc_library(
55
NAME zvec_index STATIC STRICT
6-
SRCS *.cc segment/*.cc column/vector_column/*.cc column/inverted_column/*.cc storage/*.cc storage/wal/*.cc common/*.cc
6+
SRCS *.cc
7+
segment/*.cc
8+
column/vector_column/*.cc
9+
column/inverted_column/*.cc
10+
storage/*.cc
11+
storage/wal/*.cc
12+
common/*.cc
13+
../training/*.cc
714
LIBS zvec_common
815
zvec_proto
916
rocksdb
1017
core_interface
18+
omega
1119
Arrow::arrow_static
1220
Arrow::arrow_compute
1321
Arrow::arrow_dataset

tests/core/interface/omega_training_session_test.cc

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,31 @@ TEST(OmegaTrainingSessionTest, ConsumeArtifactsAggregatesRecordsAndGtCmps) {
3434
first.training_query_id = 0;
3535
first.total_cmps = 13;
3636
first.gt_cmps_per_rank = {3, 7, 11};
37-
first.records.push_back(
38-
TrainingRecord{0, 1, 3, 0.1f, 0.2f, std::vector<float>(7, 1.0f), 1});
37+
TrainingRecord first_record;
38+
first_record.query_id = 0;
39+
first_record.hops_visited = 1;
40+
first_record.cmps_visited = 3;
41+
first_record.dist_1st = 0.1f;
42+
first_record.dist_start = 0.2f;
43+
first_record.traversal_window_stats = {1.0f, 1.0f, 1.0f, 1.0f,
44+
1.0f, 1.0f, 1.0f};
45+
first_record.label = 1;
46+
first.records.push_back(first_record);
3947

4048
QueryTrainingArtifacts second;
4149
second.training_query_id = 2;
4250
second.total_cmps = 21;
4351
second.gt_cmps_per_rank = {5, 9, 15};
44-
second.records.push_back(
45-
TrainingRecord{2, 4, 8, 0.3f, 0.4f, std::vector<float>(7, 2.0f), 0});
52+
TrainingRecord second_record;
53+
second_record.query_id = 2;
54+
second_record.hops_visited = 4;
55+
second_record.cmps_visited = 8;
56+
second_record.dist_1st = 0.3f;
57+
second_record.dist_start = 0.4f;
58+
second_record.traversal_window_stats = {2.0f, 2.0f, 2.0f, 2.0f,
59+
2.0f, 2.0f, 2.0f};
60+
second_record.label = 0;
61+
second.records.push_back(second_record);
4662

4763
session.CollectQueryArtifacts(std::move(first));
4864
session.CollectQueryArtifacts(std::move(second));

tests/db/sqlengine/mock_segment.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,10 @@ class MockSegment : public Segment {
504504
return Status::OK();
505505
}
506506

507+
Status retrain_omega_model() override {
508+
return Status::OK();
509+
}
510+
507511
Status destroy() override {
508512
return Status::OK();
509513
}

0 commit comments

Comments (0)