Skip to content

Commit 476804e

Browse files
committed
Add concurrent insert in performance case
1. Fix concurrent insert memory and process cleanup 2. Add configurable load concurrency for performance cases 3. Make CLI Ctrl+C work by polling has_running() instead of blocking on concurrent.futures.wait(), which swallows SIGINT. 4. Remove perf-case insert from SerialInsertRunner 5. Ignore S608 lint rule and fix formatting Signed-off-by: yangxuan <xuan.yang@zilliz.com>
1 parent 99c3115 commit 476804e

21 files changed

Lines changed: 719 additions & 152 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ lint.ignore = [
125125
"INP001", # TODO
126126
"TID252", # TODO
127127
"N801", "N802", "N815",
128-
"S101", "S108", "S603", "S311",
128+
"S101", "S108", "S603", "S311", "S608",
129129
"PLR2004",
130130
"RUF017",
131131
"C416",

tests/test_concurrent_runner.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Tests for ConcurrentInsertRunner against a running Milvus instance.
2+
3+
Includes:
4+
- Correctness tests (threading & async backends)
5+
- Parameterized benchmark: serial vs concurrent across (batch_size, workers) matrix
6+
7+
NUM_PER_BATCH is set via os.environ before each run. Since runners execute
8+
task() in a spawn subprocess that re-imports config, the env var takes effect.
9+
10+
Requires:
11+
- Milvus running at localhost:19530
12+
- Network access to download OpenAI 50K dataset
13+
14+
Usage:
15+
pytest tests/test_concurrent_runner.py -v -s # correctness tests only
16+
python tests/test_concurrent_runner.py # full benchmark matrix
17+
"""
18+
19+
# ruff: noqa: T201
20+
21+
from __future__ import annotations
22+
23+
import logging
24+
import os
25+
import time
26+
27+
from vectordb_bench.backend.clients import DB
28+
from vectordb_bench.backend.clients.milvus.config import FLATConfig
29+
from vectordb_bench.backend.dataset import Dataset, DatasetSource
30+
from vectordb_bench.backend.runner.concurrent_runner import ConcurrentInsertRunner, ExecutorBackend
31+
from vectordb_bench.backend.runner.serial_runner import SerialInsertRunner
32+
33+
log = logging.getLogger("vectordb_bench")
34+
log.setLevel(logging.INFO)
35+
36+
DATASET_SIZE = 50_000
37+
38+
39+
# ── Shared helpers ──────────────────────────────────────────────────────
40+
41+
42+
def get_milvus_db(collection_name: str):
    """Build a fresh Milvus client bound to *collection_name*, dropping any existing data."""
    connection = {"uri": "http://localhost:19530", "user": "", "password": ""}
    flat_cfg = FLATConfig(metric_type="COSINE")
    return DB.Milvus.init_cls(
        dim=1536,
        db_config=connection,
        db_case_config=flat_cfg,
        collection_name=collection_name,
        drop_old=True,
    )
50+
51+
52+
def prepare_dataset():
    """Return the OpenAI 50K dataset manager, downloading it from AliyunOSS if needed."""
    manager = Dataset.OPENAI.manager(DATASET_SIZE)
    manager.prepare(DatasetSource.AliyunOSS)
    return manager
56+
57+
58+
def set_batch_size(batch_size: int) -> None:
    """Export NUM_PER_BATCH so spawned runner subprocesses pick it up on config re-import."""
    os.environ.update(NUM_PER_BATCH=str(batch_size))
60+
61+
62+
def timed_run(runner: SerialInsertRunner | ConcurrentInsertRunner) -> tuple[int, float]:
    """Execute ``runner.run()`` and return ``(inserted_count, wall_seconds)``."""
    t0 = time.perf_counter()
    inserted = runner.run()
    elapsed = time.perf_counter() - t0
    return inserted, elapsed
66+
67+
68+
# ── Correctness tests (pytest) ──────────────────────────────────────────
69+
70+
71+
def test_concurrent_insert_threading():
    """Threaded concurrent insert must load exactly DATASET_SIZE rows."""
    runner = ConcurrentInsertRunner(
        db=get_milvus_db("test_conc_threading"),
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=4,
        backend=ExecutorBackend.THREADING,
    )
    inserted = runner.run()
    assert inserted == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {inserted}"
83+
84+
85+
def test_concurrent_insert_async():
    """Async-backend concurrent insert must load exactly DATASET_SIZE rows."""
    runner = ConcurrentInsertRunner(
        db=get_milvus_db("test_conc_async"),
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=4,
        backend=ExecutorBackend.ASYNC,
    )
    inserted = runner.run()
    assert inserted == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {inserted}"
97+
98+
99+
# ── Parameterized benchmark ────────────────────────────────────────────
100+
101+
102+
def run_serial(batch_size: int) -> tuple[int, float]:
    """Time a serial insert of the full dataset at the given batch size."""
    set_batch_size(batch_size)
    serial = SerialInsertRunner(
        db=get_milvus_db(f"bench_serial_b{batch_size}"),
        dataset=prepare_dataset(),
        normalize=False,
    )
    return timed_run(serial)
110+
111+
112+
def run_concurrent(batch_size: int, workers: int) -> tuple[int, float]:
    """Time a threaded concurrent insert of the full dataset with *workers* threads."""
    set_batch_size(batch_size)
    concurrent = ConcurrentInsertRunner(
        db=get_milvus_db(f"bench_conc_b{batch_size}_w{workers}"),
        dataset=prepare_dataset(),
        normalize=False,
        max_workers=workers,
        backend=ExecutorBackend.THREADING,
    )
    return timed_run(concurrent)
122+
123+
124+
def bench_matrix():
    """Print a serial-vs-concurrent timing table across (batch_size, workers) combos.

    Columns: batch size, number of batches, serial duration, one concurrent
    duration per worker count, and one speedup ratio (serial / concurrent)
    per worker count.
    """
    batch_sizes = [100, 500, 1000, 5000]
    worker_counts = [1, 2, 4, 8]

    # Header row: fixed columns, then one concurrent and one speedup column per worker count.
    header_cells = [f"{'Batch':>6}", f"{'#Bat':>5}", f"{'serial':>8}"]
    header_cells += [f"{f'conc({w}w)':>10}" for w in worker_counts]
    header_cells += [f"{f'speedup({w}w)':>12}" for w in worker_counts]
    print("\n" + " ".join(header_cells))
    print("-" * (22 + 10 * len(worker_counts) + 12 * len(worker_counts)))

    for bs in batch_sizes:
        n_batches = DATASET_SIZE // bs
        _, dur_s = run_serial(bs)
        conc_durs = [run_concurrent(bs, w)[1] for w in worker_counts]

        row_cells = [f"{bs:>6}", f"{n_batches:>5}", f"{dur_s:>7.2f}s"]
        row_cells += [f"{d:>9.2f}s" for d in conc_durs]
        row_cells += [f"{dur_s / d:>11.2f}x" for d in conc_durs]
        print(" ".join(row_cells))

    # restore default
    set_batch_size(100)
156+
157+
158+
if __name__ == "__main__":
    # Running the module directly executes the full benchmark matrix
    # (requires Milvus at localhost:19530 and dataset download access).
    bench_matrix()

vectordb_bench/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class config:
2020
DATASET_SOURCE = env.str("DATASET_SOURCE", "S3") # Options "S3" or "AliyunOSS"
2121
DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
2222
NUM_PER_BATCH = env.int("NUM_PER_BATCH", 100)
23+
LOAD_CONCURRENCY = env.int("LOAD_CONCURRENCY", 0) # 0 = cpu_count
2324
TIME_PER_BATCH = 1 # 1s. for streaming insertion.
2425
MAX_INSERT_RETRY = 5
2526
MAX_SEARCH_RETRY = 5

vectordb_bench/backend/clients/alisql/alisql.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,13 @@ def init(self):
107107
self.cursor.execute(f"SET SESSION vidx_hnsw_ef_search = {search_param['ef_search']}")
108108
self.cursor.execute("COMMIT")
109109

110-
self.insert_sql = (
111-
f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)' # noqa: S608
112-
)
110+
self.insert_sql = f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)'
113111
self.select_sql = (
114-
f'SELECT id FROM {self.db_config["database"]}.{self.table_name} ' # noqa: S608
112+
f'SELECT id FROM {self.db_config["database"]}.{self.table_name} '
115113
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
116114
)
117115
self.select_sql_with_filter = (
118-
f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s ' # noqa: S608
116+
f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s '
119117
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
120118
)
121119

vectordb_bench/backend/clients/api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ class VectorDB(ABC):
140140
supported_filter_types: list[FilterOp] = [FilterOp.NonFilter]
141141
name: str = ""
142142

143+
# Whether the client can share a single connection across threads.
144+
# If False, concurrent runners will deep-copy the instance and call
145+
# init() per thread instead of sharing the parent connection.
146+
thread_safe: bool = True
147+
143148
@classmethod
144149
def filter_supported(cls, filters: Filter) -> bool:
145150
"""Ensure that the filters are supported before testing filtering cases."""

vectordb_bench/backend/clients/doris/doris.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414

1515
class Doris(VectorDB):
16+
thread_safe: bool = False
17+
1618
def __init__(
1719
self,
1820
dim: int,

vectordb_bench/backend/clients/mariadb/mariadb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,13 @@ def init(self):
108108
self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
109109
self.cursor.execute("COMMIT")
110110

111-
self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" # noqa: S608
111+
self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
112112
self.select_sql = (
113-
f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608
113+
f"SELECT id FROM {self.db_name}.{self.table_name}"
114114
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
115115
)
116116
self.select_sql_with_filter = (
117-
f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d " # noqa: S608
117+
f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "
118118
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
119119
)
120120

vectordb_bench/backend/clients/oceanbase/oceanbase.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def insert_embeddings(
186186
batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
187187
values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
188188
self._cursor.execute(
189-
f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}" # noqa: S608
189+
f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"
190190
)
191191
insert_count += len(batch)
192192
except mysql.Error:
@@ -217,7 +217,7 @@ def search_embedding(
217217
packed = struct.pack(f"<{len(query)}f", *query)
218218
hex_vec = packed.hex()
219219
query_str = (
220-
f"SELECT id FROM {self.table_name} " # noqa: S608
220+
f"SELECT id FROM {self.table_name} "
221221
f"{self.expr} ORDER BY "
222222
f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
223223
f"APPROXIMATE LIMIT {k}"

vectordb_bench/backend/clients/pgvector/pgvector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
class PgVector(VectorDB):
2222
"""Use psycopg instructions"""
2323

24+
thread_safe: bool = False
2425
supported_filter_types: list[FilterOp] = [
2526
FilterOp.NonFilter,
2627
FilterOp.NumGE,

vectordb_bench/backend/clients/tidb/tidb.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _optimize_check_tiflash_replica_progress(self):
119119
cursor.execute(f"""
120120
SELECT PROGRESS FROM information_schema.tiflash_replica
121121
WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
122-
""") # noqa: S608
122+
""")
123123
result = cursor.fetchone()
124124
return result[0]
125125
except Exception as e:
@@ -131,7 +131,7 @@ def _optimize_wait_tiflash_catch_up(self):
131131
with self._get_connection() as (conn, cursor):
132132
cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
133133
conn.commit()
134-
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") # noqa: S608
134+
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
135135
result = cursor.fetchone()
136136
return result[0]
137137
except Exception as e:
@@ -155,7 +155,7 @@ def _optimize_get_tiflash_index_pending_rows(self):
155155
SELECT SUM(ROWS_STABLE_NOT_INDEXED)
156156
FROM information_schema.tiflash_indexes
157157
WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
158-
""") # noqa: S608
158+
""")
159159
result = cursor.fetchone()
160160
return result[0]
161161
except Exception as e:
@@ -172,7 +172,7 @@ def _insert_embeddings_serial(
172172
try:
173173
with self._get_connection() as (conn, cursor):
174174
buf = io.StringIO()
175-
buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") # noqa: S608
175+
buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
176176
for i in range(offset, offset + size):
177177
if i > offset:
178178
buf.write(",")
@@ -220,6 +220,6 @@ def search_embedding(
220220
self.cursor.execute(f"""
221221
SELECT id FROM {self.table_name}
222222
ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k};
223-
""") # noqa: S608
223+
""")
224224
result = self.cursor.fetchall()
225225
return [int(i[0]) for i in result]

0 commit comments

Comments
 (0)