Skip to content

Commit fcdc5f3

Browse files
committed
implemented parallel build
1 parent db869c7 commit fcdc5f3

1 file changed

Lines changed: 186 additions & 66 deletions

File tree

sql/vector_mhnsw.cc

Lines changed: 186 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <scope.h>
2424
#include <my_atomic_wrapper.h>
2525
#include "bloom_filters.h"
26+
#include <thread>
27+
#include <atomic>
2628

2729
// distance can be a little bit < 0 because of fast math
2830
static constexpr float NEAREST = -1.0f;
@@ -480,7 +482,7 @@ class FVectorNode
480482
class MHNSW_Share : public Sql_alloc
481483
{
482484
mysql_mutex_t cache_lock; // for node_cache and stats
483-
mysql_mutex_t node_lock[8];
485+
mysql_mutex_t node_lock[32];
484486

485487
void cache_internal(FVectorNode *node)
486488
{
@@ -666,6 +668,14 @@ class MHNSW_Share : public Sql_alloc
666668
stats.subdist.add(addend.subdist);
667669
mysql_mutex_unlock(&cache_lock);
668670
}
671+
672+
void update_start_parallel(FVectorNode *node)
673+
{
674+
mysql_mutex_lock(&cache_lock);
675+
if (!start || node->max_layer > start->max_layer)
676+
start= node;
677+
mysql_mutex_unlock(&cache_lock);
678+
}
669679
};
670680

671681
/*
@@ -1046,7 +1056,10 @@ int FVectorNode::load_from_record(TABLE *graph)
10461056
void FVectorNode::push_neighbor(size_t layer, FVectorNode *other)
10471057
{
10481058
DBUG_ASSERT(neighbors[layer].num < ctx->max_neighbors(layer));
1049-
neighbors[layer].links[neighbors[layer].num++]= other;
1059+
size_t cur_num= neighbors[layer].num;
1060+
neighbors[layer].links[cur_num]= other;
1061+
std::atomic_thread_fence(std::memory_order_release);
1062+
neighbors[layer].num= cur_num + 1;
10501063
}
10511064

10521065
size_t FVectorNode::tref_len() const { return ctx->tref_len; }
@@ -1070,8 +1083,10 @@ struct MHNSW_param
10701083
Stats acc;
10711084
dgt_mode mode;
10721085
double max_est_size;
1073-
MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer)
1074-
: ctx(ctx), graph(graph), layer(layer)
1086+
MEM_ROOT *mem_root;
1087+
MHNSW_param(MHNSW_Share *ctx, TABLE *graph, int layer, MEM_ROOT *mem_root_arg= nullptr)
1088+
: ctx(ctx), graph(graph), layer(layer),
1089+
mem_root(mem_root_arg ? mem_root_arg : (graph ? graph->in_use->mem_root : nullptr))
10751090
{
10761091
Stats stats;
10771092
ctx->read_stats(&stats);
@@ -1159,7 +1174,7 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target,
11591174
if (pq.init(max_ef, false, Visited::cmp))
11601175
return my_errno= HA_ERR_OUT_OF_MEM;
11611176

1162-
MEM_ROOT * const root= p->graph->in_use->mem_root;
1177+
MEM_ROOT * const root= p->mem_root;
11631178
auto discarded= (Visited**)my_safe_alloca(sizeof(Visited**)*max_neighbor_connections);
11641179
size_t discarded_num= 0;
11651180
Neighborhood &neighbors= target->neighbors[p->layer];
@@ -1173,29 +1188,38 @@ static int select_neighbors(MHNSW_param *p, FVectorNode *target,
11731188
}
11741189
if (extra_candidate)
11751190
pq.push(new (root) Visited(extra_candidate, extra_candidate->distance_to(target->vec)));
1176-
11771191
DBUG_ASSERT(pq.elements());
1178-
neighbors.num= 0;
11791192

1180-
while (pq.elements() && neighbors.num < max_neighbor_connections)
1193+
size_t temp_num = 0;
1194+
FVectorNode **temp_links = (FVectorNode**)my_safe_alloca(sizeof(FVectorNode*) * max_neighbor_connections);
1195+
1196+
while (pq.elements() && temp_num < max_neighbor_connections)
11811197
{
11821198
Visited *vec= pq.pop();
11831199
FVectorNode * const node= vec->node;
11841200
const float target_dista= std::max(32*FLT_EPSILON, vec->distance_to_target);
11851201
bool discard= false;
1186-
for (size_t i=0; i < neighbors.num; i++)
1187-
if ((discard= node->distance_greater_than(neighbors.links[i]->vec,
1202+
for (size_t i=0; i < temp_num; i++)
1203+
if ((discard= node->distance_greater_than(temp_links[i]->vec,
11881204
target_dista, p->mode, &p->acc) < target_dista))
11891205
break;
11901206
if (!discard)
1191-
target->push_neighbor(p->layer, node);
1192-
else if (discarded_num + neighbors.num < max_neighbor_connections)
1207+
temp_links[temp_num++]= node;
1208+
else if (discarded_num + temp_num < max_neighbor_connections)
11931209
discarded[discarded_num++]= vec;
11941210
}
11951211

1196-
for (size_t i=0; i < discarded_num && neighbors.num < max_neighbor_connections; i++)
1197-
target->push_neighbor(p->layer, discarded[i]->node);
1212+
for (size_t i= 0; i < discarded_num && temp_num < max_neighbor_connections; i++)
1213+
temp_links[temp_num++]= discarded[i]->node;
11981214

1215+
// Publish the new neighbors atomically
1216+
for (size_t i= 0; i < temp_num; i++)
1217+
neighbors.links[i]= temp_links[i];
1218+
1219+
std::atomic_thread_fence(std::memory_order_release);
1220+
neighbors.num= temp_num;
1221+
1222+
my_safe_afree(temp_links, sizeof(FVectorNode*) * max_neighbor_connections);
11991223
my_safe_afree(discarded, sizeof(Visited**)*max_neighbor_connections);
12001224
return 0;
12011225
}
@@ -1256,21 +1280,29 @@ int FVectorNode::save(TABLE *graph)
12561280
static int update_second_degree_neighbors(MHNSW_param *p, FVectorNode *node)
12571281
{
12581282
const uint max_neighbors= p->ctx->max_neighbors(p->layer);
1259-
// it seems that one could update nodes in the gref order
1260-
// to avoid InnoDB deadlocks, but it produces no noticeable effect
1261-
for (size_t i=0; i < node->neighbors[p->layer].num; i++)
1283+
const bool bulk= p->ctx->bulk_active;
1284+
1285+
for (size_t i= 0; i < node->neighbors[p->layer].num; i++)
12621286
{
12631287
FVectorNode *neigh= node->neighbors[p->layer].links[i];
1288+
uint ticket= 0;
1289+
if (bulk)
1290+
ticket= p->ctx->lock_node(neigh);
1291+
12641292
Neighborhood &neighneighbors= neigh->neighbors[p->layer];
1293+
int err= 0;
12651294
if (neighneighbors.num < max_neighbors)
12661295
neigh->push_neighbor(p->layer, node);
12671296
else
1268-
if (int err= select_neighbors(p, neigh, neighneighbors, node,
1269-
max_neighbors))
1270-
return err;
1271-
if (!p->ctx->bulk_active)
1272-
if (int err= neigh->save(p->graph))
1273-
return err;
1297+
err= select_neighbors(p, neigh, neighneighbors, node, max_neighbors);
1298+
1299+
if (bulk)
1300+
p->ctx->unlock_node(ticket);
1301+
else if (!err)
1302+
err= neigh->save(p->graph);
1303+
1304+
if (err)
1305+
return err;
12741306
}
12751307
return 0;
12761308
}
@@ -1293,7 +1325,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
12931325
{
12941326
DBUG_ASSERT(inout->num > 0);
12951327

1296-
MEM_ROOT * const root= p->graph->in_use->mem_root;
1328+
MEM_ROOT * const root= p->mem_root;
12971329
Queue<Visited> candidates, best;
12981330
bool skip_deleted;
12991331
uint ef= result_size;
@@ -1343,6 +1375,7 @@ static int search_layer(MHNSW_param *p, const FVector *target, float threshold,
13431375
visited.flush();
13441376

13451377
Neighborhood &neighbors= cur.node->neighbors[p->layer];
1378+
std::atomic_thread_fence(std::memory_order_acquire);
13461379
FVectorNode **links= neighbors.links, **end= links + neighbors.num;
13471380
for (; links < end; links+= 8)
13481381
{
@@ -1513,6 +1546,78 @@ struct MHNSW_Bulk_context : public Sql_alloc {
15131546
uint8_t current_max_layer;
15141547
};
15151548

1549+
1550+
1551+
struct BulkBuildThreadArg
1552+
{
1553+
MHNSW_Bulk_context *bulk;
1554+
uint start_idx;
1555+
uint end_idx;
1556+
int error;
1557+
};
1558+
1559+
1560+
static void *bulk_build_thread(void *param)
1561+
{
1562+
my_thread_init();
1563+
SCOPE_EXIT([]() { my_thread_end(); });
1564+
1565+
BulkBuildThreadArg *arg= (BulkBuildThreadArg*) param;
1566+
MHNSW_Bulk_context *bulk= arg->bulk;
1567+
MHNSW_Share *ctx= bulk->ctx;
1568+
1569+
MEM_ROOT thread_root;
1570+
init_alloc_root(PSI_INSTRUMENT_MEM, &thread_root, 256*1024, 0, MYF(0));
1571+
SCOPE_EXIT([&thread_root]() { free_root(&thread_root, MYF(0)); });
1572+
1573+
for (uint i = arg->start_idx; i < arg->end_idx; i++)
1574+
{
1575+
FVectorNode *target= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**);
1576+
const uint8_t max_layer= ctx->start->max_layer;
1577+
uint8_t target_layer= target->max_layer;
1578+
1579+
MHNSW_param p(ctx, nullptr, max_layer, &thread_root);
1580+
p.acc.graph_size= 1;
1581+
1582+
const size_t max_found= ctx->max_neighbors(0);
1583+
Neighborhood candidates;
1584+
candidates.init((FVectorNode**)alloc_root(&thread_root, sizeof(FVectorNode*) * (max_found + 8)), max_found);
1585+
candidates.links[candidates.num++]= ctx->start;
1586+
1587+
for (; p.layer > target_layer; p.layer--)
1588+
{
1589+
if ((arg->error= search_layer(&p, target->vec, NEAREST, 1, &candidates, false)))
1590+
return nullptr;
1591+
}
1592+
1593+
for (; p.layer >= 0; p.layer--)
1594+
{
1595+
uint max_neighbors= ctx->max_neighbors(p.layer);
1596+
if ((arg->error= search_layer(&p, target->vec, NEAREST, max_neighbors, &candidates, true)))
1597+
return nullptr;
1598+
if ((arg->error= select_neighbors(&p, target, candidates, 0, max_neighbors)))
1599+
return nullptr;
1600+
}
1601+
1602+
ctx->add_to_stats(p.acc);
1603+
1604+
for (p.layer= target_layer; p.layer >= 0; p.layer--)
1605+
{
1606+
if ((arg->error= update_second_degree_neighbors(&p, target)))
1607+
return nullptr;
1608+
}
1609+
1610+
if (target_layer > max_layer)
1611+
{
1612+
ctx->update_start_parallel(target);
1613+
}
1614+
1615+
free_root(&thread_root, MYF(MY_MARK_BLOCKS_FREE));
1616+
}
1617+
1618+
return nullptr;
1619+
}
1620+
15161621
int mhnsw_bulk_insert_begin(TABLE *table, KEY *keyinfo, ha_rows rows)
15171622
{
15181623
TABLE *graph= table->hlindex;
@@ -1551,6 +1656,18 @@ int mhnsw_bulk_insert_begin(TABLE *table, KEY *keyinfo, ha_rows rows)
15511656
return 0;
15521657
}
15531658

1659+
uint N= std::thread::hardware_concurrency();
1660+
if (N <= 1)
1661+
{
1662+
push_warning_printf(table->in_use, Sql_condition::WARN_LEVEL_NOTE,
1663+
ER_UNKNOWN_ERROR,
1664+
"MHNSW: Bulk insert disabled because available thread count (%u) is <= 1. "
1665+
"Falling back to normal insert.",
1666+
N);
1667+
ctx->release(table);
1668+
return 0;
1669+
}
1670+
15541671
MHNSW_Bulk_context *bulk= new (table->in_use->mem_root) MHNSW_Bulk_context();
15551672
if (!bulk)
15561673
{
@@ -1631,7 +1748,6 @@ int mhnsw_bulk_insert_end(TABLE *table, KEY *keyinfo)
16311748
if (!graph->context)
16321749
return 0;
16331750

1634-
THD *thd= table->in_use;
16351751
MHNSW_Bulk_context *bulk= (MHNSW_Bulk_context*)graph->context;
16361752

16371753
DBUG_ASSERT(graph);
@@ -1645,59 +1761,63 @@ int mhnsw_bulk_insert_end(TABLE *table, KEY *keyinfo)
16451761
table->hlindex->context= nullptr;
16461762
});
16471763

1648-
for (uint i= 0; i < bulk->nodes.elements; i++)
1649-
{
1650-
FVectorNode *target= *(FVectorNode**)dynamic_element(&bulk->nodes, i, FVectorNode**);
1764+
if (bulk->nodes.elements == 0)
1765+
return 0;
16511766

1652-
if (!ctx->start)
1653-
{
1654-
ctx->start= target;
1655-
continue;
1656-
}
1767+
FVectorNode *first_target= *(FVectorNode**)dynamic_element(&bulk->nodes, 0, FVectorNode**);
1768+
ctx->start= first_target;
16571769

1658-
MEM_ROOT_SAVEPOINT memroot_sv;
1659-
root_make_savepoint(thd->mem_root, &memroot_sv);
1660-
SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); });
1770+
uint N= std::thread::hardware_concurrency();
1771+
uint total_nodes= bulk->nodes.elements - 1;
1772+
uint workers= std::min(N, total_nodes);
16611773

1662-
const uint8_t max_layer= ctx->start->max_layer;
1663-
uint8_t target_layer= target->max_layer;
1774+
pthread_t *threads= (pthread_t*) my_malloc(PSI_INSTRUMENT_MEM, sizeof(pthread_t) * workers, MYF(MY_WME));
1775+
BulkBuildThreadArg *args= (BulkBuildThreadArg*) my_malloc(PSI_INSTRUMENT_MEM, sizeof(BulkBuildThreadArg) * workers, MYF(MY_WME));
1776+
SCOPE_EXIT([threads, args]() {
1777+
my_free(threads);
1778+
my_free(args);
1779+
});
1780+
if (!threads || !args)
1781+
{
1782+
return HA_ERR_OUT_OF_MEM;
1783+
}
16641784

1665-
MHNSW_param p(ctx, graph, max_layer);
1666-
p.acc.graph_size= 1;
1785+
uint chunk_size = total_nodes / workers;
1786+
uint remainder = total_nodes % workers;
1787+
uint current_start = 1;
16671788

1668-
const size_t max_found= ctx->max_neighbors(0);
1669-
Neighborhood candidates;
1670-
candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
1671-
candidates.links[candidates.num++]= ctx->start;
1789+
uint workers_spawned= 0;
16721790

1673-
for (; p.layer > target_layer; p.layer--)
1674-
{
1675-
if (int err= search_layer(&p, target->vec, NEAREST, 1, &candidates, false))
1676-
return err;
1677-
}
1791+
for (uint i= 0; i < workers; i++)
1792+
{
1793+
uint count = chunk_size + (i == 0 ? remainder : 0);
1794+
args[i].bulk= bulk;
1795+
args[i].start_idx = current_start;
1796+
args[i].end_idx = current_start + count;
1797+
args[i].error= 0;
1798+
current_start += count;
16781799

1679-
for (; p.layer >= 0; p.layer--)
1800+
int err= mysql_thread_create(0, &threads[i], nullptr, bulk_build_thread, &args[i]);
1801+
if (err)
16801802
{
1681-
uint max_neighbors= ctx->max_neighbors(p.layer);
1682-
if (int err= search_layer(&p, target->vec, NEAREST, max_neighbors,
1683-
&candidates, true))
1684-
return err;
1685-
if (int err= select_neighbors(&p, target, candidates, 0, max_neighbors))
1686-
return err;
1803+
for (uint j= 0; j < workers_spawned; j++)
1804+
pthread_join(threads[j], nullptr);
1805+
return err;
16871806
}
1807+
workers_spawned++;
1808+
}
16881809

1689-
ctx->add_to_stats(p.acc);
1690-
1691-
if (target_layer > max_layer)
1692-
ctx->start= target;
1693-
1694-
for (p.layer= target_layer; p.layer >= 0; p.layer--)
1695-
{
1696-
if (int err= update_second_degree_neighbors(&p, target))
1697-
return err;
1698-
}
1810+
int final_err= 0;
1811+
for (uint i= 0; i < workers_spawned; i++)
1812+
{
1813+
pthread_join(threads[i], nullptr);
1814+
if (args[i].error && !final_err)
1815+
final_err= args[i].error;
16991816
}
17001817

1818+
if (final_err)
1819+
return final_err;
1820+
17011821
graph->file->ha_start_bulk_insert(bulk->nodes.elements, 0);
17021822
bool bulk_base_started= true;
17031823
SCOPE_EXIT([graph, &bulk_base_started](){

0 commit comments

Comments
 (0)