2 changes: 1 addition & 1 deletion pyproject.toml
@@ -125,7 +125,7 @@ lint.ignore = [
"INP001", # TODO
"TID252", # TODO
"N801", "N802", "N815",
"S101", "S108", "S603", "S311",
"S101", "S108", "S603", "S311", "S608",
"PLR2004",
"RUF017",
"C416",
159 changes: 159 additions & 0 deletions tests/test_concurrent_runner.py
@@ -0,0 +1,159 @@
"""Tests for ConcurrentInsertRunner against a running Milvus instance.

Includes:
- Correctness tests (threading & async backends)
- Parameterized benchmark: serial vs concurrent across a (batch_size, workers) matrix

NUM_PER_BATCH is set via os.environ before each run. Since runners execute
task() in a spawn subprocess that re-imports config, the env var takes effect.

Requires:
- Milvus running at localhost:19530
- Network access to download OpenAI 50K dataset

Usage:
pytest tests/test_concurrent_runner.py -v -s # correctness tests only
python tests/test_concurrent_runner.py # full benchmark matrix
"""

# ruff: noqa: T201

from __future__ import annotations

import logging
import os
import time

from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.milvus.config import FLATConfig
from vectordb_bench.backend.dataset import Dataset, DatasetSource
from vectordb_bench.backend.runner.concurrent_runner import ConcurrentInsertRunner, ExecutorBackend
from vectordb_bench.backend.runner.serial_runner import SerialInsertRunner

log = logging.getLogger("vectordb_bench")
log.setLevel(logging.INFO)

DATASET_SIZE = 50_000


# ── Shared helpers ──────────────────────────────────────────────────────


def get_milvus_db(collection_name: str):
return DB.Milvus.init_cls(
dim=1536,
db_config={"uri": "http://localhost:19530", "user": "", "password": ""},
db_case_config=FLATConfig(metric_type="COSINE"),
collection_name=collection_name,
drop_old=True,
)


def prepare_dataset():
dataset = Dataset.OPENAI.manager(DATASET_SIZE)
dataset.prepare(DatasetSource.AliyunOSS)
return dataset


def set_batch_size(batch_size: int) -> None:
os.environ["NUM_PER_BATCH"] = str(batch_size)


def timed_run(runner: SerialInsertRunner | ConcurrentInsertRunner) -> tuple[int, float]:
start = time.perf_counter()
count = runner.run()
return count, time.perf_counter() - start


# ── Correctness tests (pytest) ──────────────────────────────────────────


def test_concurrent_insert_threading():
"""Test concurrent insert with threading backend."""
db = get_milvus_db("test_conc_threading")
runner = ConcurrentInsertRunner(
db=db,
dataset=prepare_dataset(),
normalize=False,
max_workers=4,
backend=ExecutorBackend.THREADING,
)
count = runner.run()
assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}"


def test_concurrent_insert_async():
"""Test concurrent insert with async backend."""
db = get_milvus_db("test_conc_async")
runner = ConcurrentInsertRunner(
db=db,
dataset=prepare_dataset(),
normalize=False,
max_workers=4,
backend=ExecutorBackend.ASYNC,
)
count = runner.run()
assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}"


# ── Parameterized benchmark ────────────────────────────────────────────


def run_serial(batch_size: int) -> tuple[int, float]:
set_batch_size(batch_size)
runner = SerialInsertRunner(
db=get_milvus_db(f"bench_serial_b{batch_size}"),
dataset=prepare_dataset(),
normalize=False,
)
return timed_run(runner)


def run_concurrent(batch_size: int, workers: int) -> tuple[int, float]:
set_batch_size(batch_size)
runner = ConcurrentInsertRunner(
db=get_milvus_db(f"bench_conc_b{batch_size}_w{workers}"),
dataset=prepare_dataset(),
normalize=False,
max_workers=workers,
backend=ExecutorBackend.THREADING,
)
return timed_run(runner)


def bench_matrix():
batch_sizes = [100, 500, 1000, 5000]
worker_counts = [1, 2, 4, 8]

conc_headers = [f"conc({w}w)" for w in worker_counts]
speedup_headers = [f"speedup({w}w)" for w in worker_counts]
print(f"\n{'Batch':>6} {'#Bat':>5} {'serial':>8}", end="")
for h in conc_headers:
print(f" {h:>10}", end="")
for h in speedup_headers:
print(f" {h:>12}", end="")
print()
print("-" * (22 + 10 * len(worker_counts) + 12 * len(worker_counts)))

for bs in batch_sizes:
n_batches = DATASET_SIZE // bs
_, dur_s = run_serial(bs)

conc_durs = []
for w in worker_counts:
_, dur_c = run_concurrent(bs, w)
conc_durs.append(dur_c)

print(f"{bs:>6} {n_batches:>5} {dur_s:>7.2f}s", end="")
for dur_c in conc_durs:
print(f" {dur_c:>9.2f}s", end="")
for dur_c in conc_durs:
print(f" {dur_s / dur_c:>11.2f}x", end="")
print()

# restore default
set_batch_size(100)


if __name__ == "__main__":
bench_matrix()
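A note on the env-var mechanism the docstring above relies on: with the spawn start method, each worker process starts a fresh interpreter and re-imports `vectordb_bench`, so `config` re-reads `os.environ`. A minimal standalone sketch of that mechanism (illustrative, not part of this PR):

```python
import multiprocessing as mp
import os


def child() -> None:
    # A spawned child re-imports the package, so env.int("NUM_PER_BATCH", 100)
    # in vectordb_bench's config is re-evaluated against the inherited environment.
    from vectordb_bench import config

    print(config.NUM_PER_BATCH)  # prints 500, not the parent's import-time value


if __name__ == "__main__":
    os.environ["NUM_PER_BATCH"] = "500"  # children inherit the parent's environment
    p = mp.get_context("spawn").Process(target=child)
    p.start()
    p.join()
```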
1 change: 1 addition & 0 deletions vectordb_bench/__init__.py
@@ -20,6 +20,7 @@ class config:
DATASET_SOURCE = env.str("DATASET_SOURCE", "S3") # Options "S3" or "AliyunOSS"
DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset")
NUM_PER_BATCH = env.int("NUM_PER_BATCH", 100)
+    LOAD_CONCURRENCY = env.int("LOAD_CONCURRENCY", 0)  # 0 = cpu_count
TIME_PER_BATCH = 1 # 1s. for streaming insertion.
MAX_INSERT_RETRY = 5
MAX_SEARCH_RETRY = 5
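How `LOAD_CONCURRENCY = 0` maps to `cpu_count` is not shown in this diff; a hedged sketch of the resolution a runner might perform (`resolve_load_concurrency` is an illustrative name, not from this PR):

```python
import multiprocessing

from vectordb_bench import config


def resolve_load_concurrency() -> int:
    # 0 (the default) is a sentinel meaning "use all available cores".
    return config.LOAD_CONCURRENCY or multiprocessing.cpu_count()
```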
8 changes: 3 additions & 5 deletions vectordb_bench/backend/clients/alisql/alisql.py
@@ -107,15 +107,13 @@ def init(self):
self.cursor.execute(f"SET SESSION vidx_hnsw_ef_search = {search_param['ef_search']}")
self.cursor.execute("COMMIT")

-        self.insert_sql = (
-            f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)'  # noqa: S608
-        )
+        self.insert_sql = f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)'
self.select_sql = (
-            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} '  # noqa: S608
+            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} '
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
)
self.select_sql_with_filter = (
-            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s '  # noqa: S608
+            f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s '
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s"
)

5 changes: 5 additions & 0 deletions vectordb_bench/backend/clients/api.py
@@ -140,6 +140,11 @@ class VectorDB(ABC):
supported_filter_types: list[FilterOp] = [FilterOp.NonFilter]
name: str = ""

+    # Whether the client can share a single connection across threads.
+    # If False, concurrent runners will deep-copy the instance and call
+    # init() per thread instead of sharing the parent connection.
+    thread_safe: bool = True
+
@classmethod
def filter_supported(cls, filters: Filter) -> bool:
"""Ensure that the filters are supported before testing filtering cases."""
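A sketch of how a concurrent runner could honor this flag, following the comment above; `run_with_thread_local_db` is an illustrative helper, and it assumes `init()` is the per-connection context manager that VectorDB clients already expose:

```python
import copy
from typing import Callable


def run_with_thread_local_db(db, task: Callable):
    """Give a worker thread a usable DB handle, honoring VectorDB.thread_safe."""
    if db.thread_safe:
        return task(db)  # the parent's connection can be shared as-is
    # Deep-copy before any connection is opened, so no live socket is copied.
    local = copy.deepcopy(db)
    with local.init():  # open a fresh connection in the worker thread
        return task(local)
```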
2 changes: 2 additions & 0 deletions vectordb_bench/backend/clients/doris/doris.py
@@ -13,6 +13,8 @@


class Doris(VectorDB):
+    thread_safe: bool = False
+
def __init__(
self,
dim: int,
6 changes: 3 additions & 3 deletions vectordb_bench/backend/clients/mariadb/mariadb.py
@@ -108,13 +108,13 @@ def init(self):
self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
self.cursor.execute("COMMIT")

-        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
self.select_sql = (
f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608
f"SELECT id FROM {self.db_name}.{self.table_name}"
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
)
self.select_sql_with_filter = (
f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d " # noqa: S608
f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "
f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
)

11 changes: 11 additions & 0 deletions vectordb_bench/backend/clients/milvus/milvus.py
@@ -137,6 +137,16 @@ def init(self):
self.client.close()
self.client = None

+    def _wait_for_segments_sorted(self):
+        while True:
+            segments = self.client.list_persistent_segments(self.collection_name)
+            unsorted = [s for s in segments if not s.is_sorted]
+            if not unsorted:
+                log.info(f"{self.name} all persistent segments are sorted.")
+                break
+            log.debug(f"{self.name} waiting for {len(unsorted)} segments to be sorted...")
+            time.sleep(5)
+
def _wait_for_index(self):
while True:
info = self.client.describe_index(self.collection_name, self._vector_index_name)
@@ -155,6 +165,7 @@ def _optimize(self):
log.info(f"{self.name} optimizing before search")
try:
self.client.flush(self.collection_name)
+            self._wait_for_segments_sorted()
self._wait_for_index()
if self.case_config.is_gpu_index:
log.debug("skip force merge compaction for gpu index type.")
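The new wait loop polls indefinitely; if segment sorting can stall in a given deployment, a bounded variant built from the same calls shown above might look like this (illustrative, not part of the PR):

```python
import time


def wait_for_segments_sorted(client, collection_name: str, timeout_s: float = 3600.0) -> None:
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        segments = client.list_persistent_segments(collection_name)
        if all(s.is_sorted for s in segments):
            return  # every persistent segment is sorted
        time.sleep(5)  # same poll interval as the PR's loop
    raise TimeoutError(f"segments of {collection_name} not sorted within {timeout_s}s")
```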
4 changes: 2 additions & 2 deletions vectordb_bench/backend/clients/oceanbase/oceanbase.py
@@ -186,7 +186,7 @@ def insert_embeddings(
batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
self._cursor.execute(
f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}" # noqa: S608
f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"
)
insert_count += len(batch)
except mysql.Error:
@@ -217,7 +217,7 @@ def search_embedding(
packed = struct.pack(f"<{len(query)}f", *query)
hex_vec = packed.hex()
query_str = (
f"SELECT id FROM {self.table_name} " # noqa: S608
f"SELECT id FROM {self.table_name} "
f"{self.expr} ORDER BY "
f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
f"APPROXIMATE LIMIT {k}"
1 change: 1 addition & 0 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -21,6 +21,7 @@
class PgVector(VectorDB):
"""Use psycopg instructions"""

+    thread_safe: bool = False
supported_filter_types: list[FilterOp] = [
FilterOp.NonFilter,
FilterOp.NumGE,
10 changes: 5 additions & 5 deletions vectordb_bench/backend/clients/tidb/tidb.py
@@ -119,7 +119,7 @@ def _optimize_check_tiflash_replica_progress(self):
cursor.execute(f"""
SELECT PROGRESS FROM information_schema.tiflash_replica
WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}"
""") # noqa: S608
""")
result = cursor.fetchone()
return result[0]
except Exception as e:
@@ -131,7 +131,7 @@ def _optimize_wait_tiflash_catch_up(self):
with self._get_connection() as (conn, cursor):
cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"')
conn.commit()
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") # noqa: S608
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
result = cursor.fetchone()
return result[0]
except Exception as e:
@@ -155,7 +155,7 @@ def _optimize_get_tiflash_index_pending_rows(self):
SELECT SUM(ROWS_STABLE_NOT_INDEXED)
FROM information_schema.tiflash_indexes
WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}"
""") # noqa: S608
""")
result = cursor.fetchone()
return result[0]
except Exception as e:
@@ -172,7 +172,7 @@ def _insert_embeddings_serial(
try:
with self._get_connection() as (conn, cursor):
buf = io.StringIO()
buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") # noqa: S608
buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ")
for i in range(offset, offset + size):
if i > offset:
buf.write(",")
@@ -220,6 +220,6 @@ def search_embedding(
self.cursor.execute(f"""
SELECT id FROM {self.table_name}
ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k};
""") # noqa: S608
""")
result = self.cursor.fetchall()
return [int(i[0]) for i in result]
2 changes: 1 addition & 1 deletion vectordb_bench/backend/clients/vespa/vespa.py
@@ -107,7 +107,7 @@ def search_embedding(
embedding_field = "embedding" if self.case_config.quantization_type == "none" else "embedding_binary"

yql = (
f"select id from {self.schema_name} where " # noqa: S608
f"select id from {self.schema_name} where "
f"{{targetHits: {k}, hnsw.exploreAdditionalHits: {extra_ef}}}"
f"nearestNeighbor({embedding_field}, query_embedding)"
)
2 changes: 2 additions & 0 deletions vectordb_bench/backend/runner/__init__.py
@@ -1,8 +1,10 @@
+from .concurrent_runner import ConcurrentInsertRunner
from .mp_runner import MultiProcessingSearchRunner
from .read_write_runner import ReadWriteRunner
from .serial_runner import SerialInsertRunner, SerialSearchRunner

__all__ = [
"ConcurrentInsertRunner",
"MultiProcessingSearchRunner",
"ReadWriteRunner",
"SerialInsertRunner",
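With the new export in place, the runner is importable from the package root; a usage sketch mirroring the tests above (requires Milvus at localhost:19530, and reuses the test helpers rather than re-deriving the setup):

```python
from tests.test_concurrent_runner import get_milvus_db, prepare_dataset
from vectordb_bench.backend.runner import ConcurrentInsertRunner
from vectordb_bench.backend.runner.concurrent_runner import ExecutorBackend

runner = ConcurrentInsertRunner(
    db=get_milvus_db("demo_concurrent"),
    dataset=prepare_dataset(),
    normalize=False,
    max_workers=8,
    backend=ExecutorBackend.THREADING,
)
print(runner.run())  # total rows inserted
```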