diff --git a/mysql-test/main/vector_bulk.combinations b/mysql-test/main/vector_bulk.combinations new file mode 100644 index 0000000000000..ce2a60320816a --- /dev/null +++ b/mysql-test/main/vector_bulk.combinations @@ -0,0 +1,9 @@ +[innodb] +innodb +default-storage-engine=innodb + +[myisam] +default-storage-engine=myisam + +[aria] +default-storage-engine=aria diff --git a/mysql-test/main/vector_bulk.result b/mysql-test/main/vector_bulk.result new file mode 100644 index 0000000000000..2cb6c88598e56 --- /dev/null +++ b/mysql-test/main/vector_bulk.result @@ -0,0 +1,44 @@ +# +# Test memory budget fallback during MHNSW bulk insert +# +# 1. Normal bulk insert works when memory budget is large enough +create table t1 (id int auto_increment primary key, v vector(3) not null); +insert into t1 (v) values (x'000000000000000000000000'), +(x'0000803f0000000000000000'), +(x'000000000000803f00000000'), +(x'00000000000000000000803f'); +alter table t1 add vector index (v); +show warnings; +Level Code Message +# Test search using the successfully bulk-built index +select id, vec_distance_euclidean(v, x'0000803f0000000000000000') d from t1 order by d limit 2; +id d +2 0 +1 1 +drop table t1; +# 2. Bulk insert falls back to normal insert when mhnsw_max_cache_size is small +set @old_cache_size= @@global.mhnsw_max_cache_size; +set global mhnsw_max_cache_size= 1048576; +set max_recursive_iterations= 300; +create table t1 (id int auto_increment primary key, v vector(3000) not null); +insert into t1 (v) +with recursive cte as ( +select 1 as n +union all +select n + 1 from cte where n < 115 +) +select repeat(x'00', 12000) from cte; +insert into t1 (id, v) values (999, concat(repeat(x'00', 11996), x'0000803f')); +# Adding index with small cache size should trigger fallback and show a warning +alter table t1 add vector index (v) m=200; +Warnings: +Note 1105 MHNSW: Bulk insert disabled because estimated memory usage (estimated_mem) exceeds mhnsw_max_cache_size (1048576). Falling back to normal insert. +show warnings; +Level Code Message +Note 1105 MHNSW: Bulk insert disabled because estimated memory usage (estimated_mem) exceeds mhnsw_max_cache_size (1048576). Falling back to normal insert. +# Test search using the fallback-built index to ensure it is healthy and correct +select id, vec_distance_euclidean(v, concat(repeat(x'00', 11996), x'0000803f')) d from t1 order by d limit 1; +id d +999 0 +drop table t1; +set global mhnsw_max_cache_size= @old_cache_size; diff --git a/mysql-test/main/vector_bulk.test b/mysql-test/main/vector_bulk.test new file mode 100644 index 0000000000000..7ecd1ca37374c --- /dev/null +++ b/mysql-test/main/vector_bulk.test @@ -0,0 +1,51 @@ +--echo # +--echo # Test memory budget fallback during MHNSW bulk insert +--echo # + +--echo # 1. Normal bulk insert works when memory budget is large enough +create table t1 (id int auto_increment primary key, v vector(3) not null); +insert into t1 (v) values (x'000000000000000000000000'), + (x'0000803f0000000000000000'), + (x'000000000000803f00000000'), + (x'00000000000000000000803f'); +alter table t1 add vector index (v); +show warnings; + +--echo # Test search using the successfully bulk-built index +--replace_regex /(\.\d{5})\d+/\1/ +select id, vec_distance_euclidean(v, x'0000803f0000000000000000') d from t1 order by d limit 2; + +drop table t1; + +--echo # 2. Bulk insert falls back to normal insert when mhnsw_max_cache_size is small +set @old_cache_size= @@global.mhnsw_max_cache_size; +set global mhnsw_max_cache_size= 1048576; +set max_recursive_iterations= 300; + +create table t1 (id int auto_increment primary key, v vector(3000) not null); + +# Insert 115 rows of 3000-dimensional vectors. +# Total size: 115 * ~9288 bytes/row (with M=200) = ~1.06 MB, exceeding 1MB. +insert into t1 (v) +with recursive cte as ( + select 1 as n + union all + select n + 1 from cte where n < 115 +) +select repeat(x'00', 12000) from cte; + +# Insert a unique search target vector with fixed ID 999 to guarantee deterministic results across all engines +insert into t1 (id, v) values (999, concat(repeat(x'00', 11996), x'0000803f')); + +--echo # Adding index with small cache size should trigger fallback and show a warning +--replace_regex /usage \(\d+\)/usage (estimated_mem)/ +alter table t1 add vector index (v) m=200; +--replace_regex /usage \(\d+\)/usage (estimated_mem)/ +show warnings; + +--echo # Test search using the fallback-built index to ensure it is healthy and correct +--replace_regex /(\.\d{5})\d+/\1/ +select id, vec_distance_euclidean(v, concat(repeat(x'00', 11996), x'0000803f')) d from t1 order by d limit 1; + +drop table t1; +set global mhnsw_max_cache_size= @old_cache_size; diff --git a/sql/sql_base.cc b/sql/sql_base.cc index 95abfa798bf65..7211c6fc43921 100644 --- a/sql/sql_base.cc +++ b/sql/sql_base.cc @@ -10145,11 +10145,15 @@ int TABLE::unlock_hlindexes() int TABLE::hlindexes_on_insert() { - DBUG_ASSERT(s->hlindexes() == (hlindex != NULL)); - if (hlindex && hlindex->in_use) - if (int err= mhnsw_insert(this, key_info + s->keys)) - return err; - return 0; + DBUG_ASSERT(s->hlindexes() == (hlindex != NULL)); + if (hlindex && hlindex->in_use) + { + if (hlindex->bulk_insert_active) + return mhnsw_bulk_insert_row(this, key_info + s->keys); + else + return mhnsw_insert(this, key_info + s->keys); + } + return 0; } int TABLE::hlindexes_on_update() @@ -10208,3 +10212,36 @@ int TABLE::hlindex_read_end() { return mhnsw_read_end(this); } + +int TABLE::hlindexes_bulk_insert_begin(ha_rows rows) +{ + if (s->hlindexes()) + { + if (!hlindex || !hlindex->in_use) + if (int err= open_hlindexes_for_write()) + return err; + + if (hlindex && hlindex->in_use) + { + int err= mhnsw_bulk_insert_begin(this, key_info + s->keys, rows); + if (err) + { + hlindex->bulk_insert_active= false; + return err; + } + if (hlindex->context) + hlindex->bulk_insert_active= true; + } + } + return 0; +} + +int TABLE::hlindexes_bulk_insert_end() +{ + if (hlindex && hlindex->in_use) + { + hlindex->bulk_insert_active= false; + return mhnsw_bulk_insert_end(this, key_info + s->keys); + } + return 0; +} diff --git a/sql/sql_table.cc b/sql/sql_table.cc index b07f16102a6bb..7791257b6ae13 100644 --- a/sql/sql_table.cc +++ b/sql/sql_table.cc @@ -12616,6 +12616,7 @@ copy_data_between_tables(THD *thd, TABLE *from, TABLE *to, bool make_unversioned= from->versioned() && !to->versioned(); bool keep_versioned= from->versioned() && to->versioned(); bool bulk_insert_started= 0; + bool hlindex_bulk_started= 0; Field *to_row_start= NULL, *to_row_end= NULL, *from_row_end= NULL; MYSQL_TIME query_start; DBUG_ENTER("copy_data_between_tables"); @@ -12662,11 +12663,20 @@ copy_data_between_tables(THD *thd, TABLE *from, TABLE *to, from->file->info(HA_STATUS_VARIABLE); to->file->extra(HA_EXTRA_PREPARE_FOR_ALTER_TABLE); - if (!to->s->long_unique_table && !to->s->hlindexes()) + + if (!to->s->long_unique_table) { - to->file->ha_start_bulk_insert(from->file->stats.records, - ignore ? 0 : HA_CREATE_UNIQUE_INDEX_BY_SORT); - bulk_insert_started= 1; + if (to->s->hlindexes()) + { + if (to->hlindexes_bulk_insert_begin(from->file->stats.records) == 0) + hlindex_bulk_started= 1; + } + if (!to->s->hlindexes() || hlindex_bulk_started) + { + to->file->ha_start_bulk_insert(from->file->stats.records, + ignore ? 0 : HA_CREATE_UNIQUE_INDEX_BY_SORT); + bulk_insert_started= 1; + } } mysql_stage_set_work_estimated(thd->m_stage_progress_psi, from->file->stats.records); List_iterator it(alter_info->create_list); @@ -12999,6 +13009,14 @@ copy_data_between_tables(THD *thd, TABLE *from, TABLE *to, } bulk_insert_started= 0; + if (hlindex_bulk_started && to->hlindexes_bulk_insert_end() && error <= 0) + { + if (!thd->is_error()) + to->file->print_error(my_errno, MYF(0)); + error= 1; + } + hlindex_bulk_started=0; + if (error <= 0 && !to->s->hlindexes()) { Abort_on_warning_instant_set save_abort_on_warning(thd, false); diff --git a/sql/table.h b/sql/table.h index 0713341840127..9d5a47cbe2b55 100644 --- a/sql/table.h +++ b/sql/table.h @@ -1632,6 +1632,7 @@ struct TABLE */ bool alias_name_used; /* true if table_name is alias */ bool get_fields_in_item_tree; /* Signal to fix_field */ + bool bulk_insert_active=false; /* mhnsw bulk_insert_started flag */ private: bool m_needs_reopen; bool created; /* For tmp tables. TRUE <=> tmp table was actually created.*/ @@ -1875,6 +1876,8 @@ struct TABLE int hlindexes_on_update(); int hlindexes_on_delete(const uchar *buf); int hlindexes_on_delete_all(bool truncate); + int hlindexes_bulk_insert_begin(ha_rows rows); + int hlindexes_bulk_insert_end(); int unlock_hlindexes(); void prepare_triggers_for_insert_stmt_or_event(); diff --git a/sql/vector_mhnsw.cc b/sql/vector_mhnsw.cc index c480c36c7e7ad..4889c42f074c5 100644 --- a/sql/vector_mhnsw.cc +++ b/sql/vector_mhnsw.cc @@ -23,6 +23,9 @@ #include #include #include "bloom_filters.h" +#include +#include +#include // distance can be a little bit < 0 because of fast math static constexpr float NEAREST = -1.0f; @@ -393,9 +396,11 @@ struct Neighborhood: public Sql_alloc { FVectorNode **links; size_t num; + Atomic_relaxed num_bulk; FVectorNode **init(FVectorNode **ptr, size_t n) { num= 0; + num_bulk.store(0, std::memory_order_relaxed); links= ptr; n= MY_ALIGN(n, 8); bzero(ptr, n*sizeof(*ptr)); @@ -480,7 +485,7 @@ class FVectorNode class MHNSW_Share : public Sql_alloc { mysql_mutex_t cache_lock; // for node_cache and stats - mysql_mutex_t node_lock[8]; + mysql_mutex_t node_lock[32]; // XXX how to choose what's the best value here? void cache_internal(FVectorNode *node) { @@ -510,7 +515,7 @@ class MHNSW_Share : public Sql_alloc const uint M; metric_type metric; bool use_subdist; - + bool bulk_active=false; MHNSW_Share(TABLE *t) : tref_len(t->file->ref_length), gref_len(t->hlindex->file->ref_length), M(static_cast(t->s->key_info[t->s->keys].option_struct->M)), @@ -666,6 +671,8 @@ class MHNSW_Share : public Sql_alloc stats.subdist.add(addend.subdist); mysql_mutex_unlock(&cache_lock); } + + }; /* @@ -1012,6 +1019,8 @@ int FVectorNode::load_from_record(TABLE *graph) FVector *vec_ptr= FVector::align_ptr(tref() + tref_len()); memcpy(vec_ptr->data(), v->ptr(), v->length()); vec_ptr->postprocess(ctx->use_subdist, ctx->vec_len); + if (ctx->metric == COSINE) + vec_ptr->abs2= 0.5f; longlong layer= graph->field[FIELD_LAYER]->val_int(); if (layer > 100) // 10e30 nodes at M=2, more at larger M's @@ -1044,7 +1053,11 @@ int FVectorNode::load_from_record(TABLE *graph) void FVectorNode::push_neighbor(size_t layer, FVectorNode *other) { DBUG_ASSERT(neighbors[layer].num < ctx->max_neighbors(layer)); - neighbors[layer].links[neighbors[layer].num++]= other; + size_t cur_num= neighbors[layer].num; + neighbors[layer].links[cur_num]= other; + neighbors[layer].num= cur_num + 1; + if (ctx->bulk_active) + neighbors[layer].num_bulk.store(cur_num + 1, std::memory_order_release); } size_t FVectorNode::tref_len() const { return ctx->tref_len; } @@ -1068,8 +1081,10 @@ struct MHNSW_param Stats acc; dgt_mode mode; double max_est_size; - MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer) - : ctx(ctx), graph(graph), layer(layer) + MEM_ROOT *mem_root; + MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer, MEM_ROOT *mem_root_arg= nullptr) + : ctx(ctx), graph(graph), layer(layer), + mem_root(mem_root_arg ? mem_root_arg : (graph ? graph->in_use->mem_root : nullptr)) { Stats stats; ctx->read_stats(&stats); @@ -1157,7 +1172,7 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target, if (pq.init(max_ef, false, Visited::cmp)) return my_errno= HA_ERR_OUT_OF_MEM; - MEM_ROOT * const root= p->graph->in_use->mem_root; + MEM_ROOT * const root= p->mem_root; auto discarded= (Visited**)my_safe_alloca(sizeof(Visited**)*max_neighbor_connections); size_t discarded_num= 0; Neighborhood &neighbors= target->neighbors[p->layer]; @@ -1171,29 +1186,39 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target, } if (extra_candidate) pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target->vec))); - DBUG_ASSERT(pq.elements()); - neighbors.num= 0; - while (pq.elements() && neighbors.num < max_neighbor_connections) + size_t temp_num = 0; + FVectorNode **temp_links = (FVectorNode**)my_safe_alloca(sizeof(FVectorNode*) * max_neighbor_connections); + + while (pq.elements() && temp_num < max_neighbor_connections) { Visited *vec= pq.pop(); FVectorNode * const node= vec->node; const float target_dista= std::max(32*FLT_EPSILON, vec->distance_to_target); bool discard= false; - for (size_t i=0; i < neighbors.num; i++) - if ((discard= node->distance_greater_than(neighbors.links[i]->vec, + for (size_t i=0; i < temp_num; i++) + if ((discard= node->distance_greater_than(temp_links[i]->vec, target_dista, p->mode, &p->acc) < target_dista)) break; if (!discard) - target->push_neighbor(p->layer, node); - else if (discarded_num + neighbors.num < max_neighbor_connections) + temp_links[temp_num++]= node; + else if (discarded_num + temp_num < max_neighbor_connections) discarded[discarded_num++]= vec; } - for (size_t i=0; i < discarded_num && neighbors.num < max_neighbor_connections; i++) - target->push_neighbor(p->layer, discarded[i]->node); + for (size_t i= 0; i < discarded_num && temp_num < max_neighbor_connections; i++) + temp_links[temp_num++]= discarded[i]->node; + + // Publish the new neighbors atomically + for (size_t i= 0; i < temp_num; i++) + neighbors.links[i]= temp_links[i]; + + neighbors.num= temp_num; + if (p->ctx->bulk_active) + neighbors.num_bulk.store(temp_num, std::memory_order_release); + my_safe_afree(temp_links, sizeof(FVectorNode*) * max_neighbor_connections); my_safe_afree(discarded, sizeof(Visited**)*max_neighbor_connections); return 0; } @@ -1254,19 +1279,28 @@ int FVectorNode::save(TABLE *graph) static int update_second_degree_neighbors(MHNSW_param *p, FVectorNode *node) { const uint max_neighbors= p->ctx->max_neighbors(p->layer); - // it seems that one could update nodes in the gref order - // to avoid InnoDB deadlocks, but it produces no noticeable effect - for (size_t i=0; i < node->neighbors[p->layer].num; i++) + const bool bulk= p->ctx->bulk_active; + + for (size_t i= 0; i < node->neighbors[p->layer].num; i++) { FVectorNode *neigh= node->neighbors[p->layer].links[i]; + uint ticket= 0; + if (bulk) + ticket= p->ctx->lock_node(neigh); + Neighborhood &neighneighbors= neigh->neighbors[p->layer]; + int err= 0; if (neighneighbors.num < max_neighbors) neigh->push_neighbor(p->layer, node); else - if (int err= select_neighbors(p, neigh, neighneighbors, node, - max_neighbors)) - return err; - if (int err= neigh->save(p->graph)) + err= select_neighbors(p, neigh, neighneighbors, node, max_neighbors); + + if (bulk) + p->ctx->unlock_node(ticket); + else if (!err) + err= neigh->save(p->graph); + + if (err) return err; } return 0; @@ -1290,7 +1324,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, { DBUG_ASSERT(inout->num > 0); - MEM_ROOT * const root= p->graph->in_use->mem_root; + MEM_ROOT * const root= p->mem_root; Queue candidates, best; bool skip_deleted; uint ef= result_size; @@ -1340,7 +1374,11 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold, visited.flush(); Neighborhood &neighbors= cur.node->neighbors[p->layer]; - FVectorNode **links= neighbors.links, **end= links + neighbors.num; + FVectorNode **links= neighbors.links; + size_t cur_num= p->ctx->bulk_active + ? neighbors.num_bulk.load(std::memory_order_acquire) + : neighbors.num; + FVectorNode **end= links + cur_num; for (; links < end; links+= 8) { uint8_t res= visited.seen(links); @@ -1504,6 +1542,337 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo) } +struct MHNSW_Bulk_context : public Sql_alloc { + MHNSW_Share *ctx; + DYNAMIC_ARRAY nodes; + size_t start_node_idx; + uint8_t current_max_layer; +}; + + + +struct BulkBuildThreadArg +{ + MHNSW_Bulk_context *bulk; + size_t start_idx; + size_t end_idx; + int error; +}; + + +static void *bulk_build_thread(void *param) +{ + my_thread_init(); + SCOPE_EXIT([]() { my_thread_end(); }); + + BulkBuildThreadArg *arg= (BulkBuildThreadArg*) param; + MHNSW_Bulk_context *bulk= arg->bulk; + MHNSW_Share *ctx= bulk->ctx; + // Sort this thread's chunk by descending layer + FVectorNode **chunk_start= dynamic_element(&bulk->nodes, arg->start_idx, FVectorNode**); + FVectorNode **chunk_end= dynamic_element(&bulk->nodes, arg->end_idx, FVectorNode**); + std::sort(chunk_start, chunk_end, [](const FVectorNode *a, const FVectorNode *b) { + return a->max_layer > b->max_layer; + }); + + MEM_ROOT thread_root; + init_alloc_root(PSI_INSTRUMENT_MEM, &thread_root, 256*1024, 0, MYF(0)); + SCOPE_EXIT([&thread_root]() { free_root(&thread_root, MYF(0)); }); + + for (size_t i = arg->start_idx; i < arg->end_idx; i++) + { + FVectorNode *target= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**); + const uint8_t max_layer= ctx->start->max_layer; + uint8_t target_layer= target->max_layer; + + MHNSW_param p(ctx, nullptr, max_layer, &thread_root); + p.acc.graph_size= 1; + + const size_t max_found= ctx->max_neighbors(0); + Neighborhood candidates; + candidates.init((FVectorNode**)alloc_root(&thread_root, sizeof(FVectorNode*) * (max_found + 8)), max_found); + candidates.links[candidates.num++]= ctx->start; + + for (; p.layer > target_layer; p.layer--) + { + if ((arg->error= search_layer(&p, target->vec, NEAREST, 1, &candidates, false))) + return nullptr; + } + + for (; p.layer >= 0; p.layer--) + { + uint max_neighbors= ctx->max_neighbors(p.layer); + if ((arg->error= search_layer(&p, target->vec, NEAREST, max_neighbors, &candidates, true))) + return nullptr; + if ((arg->error= select_neighbors(&p, target, candidates, 0, max_neighbors))) + return nullptr; + } + + ctx->add_to_stats(p.acc); + + for (p.layer= target_layer; p.layer >= 0; p.layer--) + { + if ((arg->error= update_second_degree_neighbors(&p, target))) + return nullptr; + } + + + free_root(&thread_root, MYF(MY_MARK_BLOCKS_FREE)); + } + + return nullptr; +} + +int mhnsw_bulk_insert_begin(TABLE *table, KEY *keyinfo, ha_rows rows) +{ + TABLE *graph= table->hlindex; + DBUG_ASSERT(graph); + DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR); + DBUG_ASSERT(keyinfo->usable_key_parts == 1); + + MHNSW_Share *ctx= nullptr; + int err= MHNSW_Share::acquire(&ctx, table, true); + if (err && err != HA_ERR_END_OF_FILE && err != HA_ERR_KEY_NOT_FOUND) + { + if (ctx) + ctx->release(table); + return err; + } + + if (ctx->vec_len == 0) + ctx->set_lengths(keyinfo->key_part->field->field_length); + + size_t node_alloc_size= sizeof(FVectorNode) + ctx->gref_len + ctx->tref_len + + FVector::alloc_size(ctx->vec_len); + size_t neighborhood_alloc_size= sizeof(Neighborhood) + + sizeof(FVectorNode*) * MY_ALIGN(ctx->M, 4) * 2; + + ulonglong estimated_mem= rows * (sizeof(FVectorNode*) + node_alloc_size + + neighborhood_alloc_size); + + if (estimated_mem > mhnsw_max_cache_size) + { + push_warning_printf(table->in_use, Sql_condition::WARN_LEVEL_NOTE, + ER_UNKNOWN_ERROR, + "MHNSW: Bulk insert disabled because estimated memory usage (%llu) " + "exceeds mhnsw_max_cache_size (%llu). Falling back to normal insert.", + (ulonglong)estimated_mem, (ulonglong)mhnsw_max_cache_size); + ctx->release(table); + return 0; + } + + uint N= std::thread::hardware_concurrency(); + if (N <= 1) + { + push_warning_printf(table->in_use, Sql_condition::WARN_LEVEL_NOTE, + ER_UNKNOWN_ERROR, + "MHNSW: Bulk insert disabled because available thread count (%u) is <= 1. " + "Falling back to normal insert.", + N); + ctx->release(table); + return 0; + } + if (rows < N * 100) + { + ctx->release(table); + return 0; + } + + MHNSW_Bulk_context *bulk= new (table->in_use->mem_root) MHNSW_Bulk_context(); + if (!bulk) + { + ctx->release(table); + return HA_ERR_OUT_OF_MEM; + } + + bulk->ctx= ctx; + + /*we add a 10% margin to avoid reallocations when rows is approximate (InnoDB)*/ + if (my_init_dynamic_array(PSI_INSTRUMENT_MEM, &bulk->nodes, sizeof(FVectorNode*), + rows + rows / 10, rows, MYF(0))) + { + ctx->release(table); + return HA_ERR_OUT_OF_MEM; + } + + bulk->ctx->bulk_active= 1; + DBUG_ASSERT(!bulk->ctx->start); + bulk->current_max_layer= 0; + bulk->start_node_idx= 0; + table->hlindex->context= bulk; + return 0; +} + +int mhnsw_bulk_insert_row(TABLE *table, KEY *keyinfo) +{ + TABLE *graph= table->hlindex; + MHNSW_Bulk_context *bulk= (MHNSW_Bulk_context*)graph->context; + MHNSW_Share *ctx= bulk->ctx; + MY_BITMAP *old_map= dbug_tmp_use_all_columns(table, &table->read_set); + SCOPE_EXIT([table, old_map]() { + dbug_tmp_restore_column_map(&table->read_set, old_map); + }); + + DBUG_ASSERT(graph); + DBUG_ASSERT(bulk); + DBUG_ASSERT(keyinfo->algorithm == HA_KEY_ALG_VECTOR); + DBUG_ASSERT(keyinfo->usable_key_parts == 1); + + Field *vec_field= keyinfo->key_part->field; + String buf, *res= vec_field->val_str(&buf); + + DBUG_ASSERT(vec_field->binary()); + DBUG_ASSERT(vec_field->cmp_type() == STRING_RESULT); + DBUG_ASSERT(res); // ER_INDEX_CANNOT_HAVE_NULL + DBUG_ASSERT(res->length() > 0 && res->length() % 4 == 0); + DBUG_ASSERT(table->file->ref_length <= graph->field[FIELD_TREF]->field_length); + + table->file->position(table->record[0]); + + if (ctx->byte_len == 0) + ctx->set_lengths(res->length()); + + if (ctx->byte_len != res->length()) + return my_errno= HA_ERR_CRASHED; + + const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M); + double log= -std::log(my_rnd(&table->in_use->rand)) * NORMALIZATION_FACTOR; + uint8_t max_layer= bulk->current_max_layer; + uint8_t target_layer= std::min(static_cast(std::floor(log)), max_layer + 1); + + if (bulk->nodes.elements == 0) + target_layer= 0; + + if (target_layer > bulk->current_max_layer) + { + bulk->current_max_layer= target_layer; + bulk->start_node_idx= bulk->nodes.elements; + } + + FVectorNode *node= new (ctx->alloc_node()) + FVectorNode(ctx, table->file->ref, target_layer, res->ptr()); + + if (insert_dynamic(&bulk->nodes, (uchar*)&node)) + return HA_ERR_OUT_OF_MEM; + + return 0; +} + +int mhnsw_bulk_insert_end(TABLE *table, KEY *keyinfo) +{ + TABLE *graph= table->hlindex; + if (!graph->context) + return 0; + + MHNSW_Bulk_context *bulk= (MHNSW_Bulk_context*)graph->context; + + DBUG_ASSERT(graph); + DBUG_ASSERT(bulk); + + MHNSW_Share *ctx= bulk->ctx; + SCOPE_EXIT([ctx, bulk, table](){ + delete_dynamic(&bulk->nodes); + ctx->bulk_active= 0; + ctx->release(table); + table->hlindex->context= nullptr; + }); + + if (bulk->nodes.elements == 0) + return 0; + + // Swap the start node (highest layer) to index 0 + if (bulk->start_node_idx != 0) + { + FVectorNode **arr= (FVectorNode**)bulk->nodes.buffer; + std::swap(arr[0], arr[bulk->start_node_idx]); + } + ctx->start= *dynamic_element(&bulk->nodes, 0, FVectorNode**); + + // XXX how many threads to use? + uint N= std::thread::hardware_concurrency(); + size_t total_nodes= bulk->nodes.elements - 1; + size_t workers= std::min(N, total_nodes); + + pthread_t *threads= (pthread_t*) my_malloc(PSI_INSTRUMENT_MEM, sizeof(pthread_t) * workers, MYF(MY_WME)); + BulkBuildThreadArg *args= (BulkBuildThreadArg*) my_malloc(PSI_INSTRUMENT_MEM, sizeof(BulkBuildThreadArg) * workers, MYF(MY_WME)); + SCOPE_EXIT([threads, args]() { + my_free(threads); + my_free(args); + }); + if (!threads || !args) + { + return HA_ERR_OUT_OF_MEM; + } + + size_t chunk_size = total_nodes / workers; + size_t remainder = total_nodes % workers; + size_t current_start = 1; + + size_t workers_spawned= 0; + + for (size_t i= 0; i < workers; i++) + { + size_t count = chunk_size + (i == 0 ? remainder : 0); + args[i].bulk= bulk; + args[i].start_idx = current_start; + args[i].end_idx = current_start + count; + args[i].error= 0; + current_start += count; + + int err= mysql_thread_create(0, &threads[i], nullptr, bulk_build_thread, &args[i]); + if (err) + { + for (size_t j= 0; j < workers_spawned; j++) + pthread_join(threads[j], nullptr); + return HA_ERR_OUT_OF_MEM; + } + workers_spawned++; + } + + int final_err= 0; + for (size_t i= 0; i < workers_spawned; i++) + { + pthread_join(threads[i], nullptr); + if (args[i].error && !final_err) + final_err= args[i].error; + } + + if (final_err) + return final_err; + + graph->file->ha_start_bulk_insert(bulk->nodes.elements, 0); + bool bulk_base_started= true; + SCOPE_EXIT([graph, &bulk_base_started](){ + if (bulk_base_started) + graph->file->ha_end_bulk_insert(); + }); + + for (size_t i= 0; i < bulk->nodes.elements; i++) + { + FVectorNode *node= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**); + if (int err= node->save(graph)) + return err; + } + + bulk_base_started= false; + if (int err= graph->file->ha_end_bulk_insert()) + return err; + + if (int err= graph->file->ha_rnd_init(0)) + return err; + SCOPE_EXIT([graph](){ graph->file->ha_rnd_end(); }); + + // fix neighbors grefs + for (size_t i= 0; i < bulk->nodes.elements; i++) + { + FVectorNode *node= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**); + if (int err= node->save(graph)) + return err; + } + + return 0; +} + struct Search_context: public Sql_alloc { Neighborhood found; diff --git a/sql/vector_mhnsw.h b/sql/vector_mhnsw.h index fbb61e14773f9..e6a8622e3e609 100644 --- a/sql/vector_mhnsw.h +++ b/sql/vector_mhnsw.h @@ -34,6 +34,9 @@ int mhnsw_invalidate(TABLE *table, const uchar *rec, KEY *keyinfo); int mhnsw_delete_all(TABLE *table, KEY *keyinfo, bool truncate); void mhnsw_free(TABLE_SHARE *share); Item_func_vec_distance::distance_kind mhnsw_uses_distance(const TABLE *table, KEY *keyinfo); +int mhnsw_bulk_insert_begin(TABLE *table, KEY *keyinfo, ha_rows rows); +int mhnsw_bulk_insert_end(TABLE *table, KEY *keyinfo); +int mhnsw_bulk_insert_row(TABLE *table, KEY *keyinfo); extern ha_create_table_option mhnsw_index_options[]; extern st_plugin_int *mhnsw_plugin;