diff --git a/pyproject.toml b/pyproject.toml index 905996f5e..a1472fb4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,7 @@ lint.ignore = [ "INP001", # TODO "TID252", # TODO "N801", "N802", "N815", - "S101", "S108", "S603", "S311", + "S101", "S108", "S603", "S311", "S608", "PLR2004", "RUF017", "C416", diff --git a/tests/test_concurrent_runner.py b/tests/test_concurrent_runner.py new file mode 100644 index 000000000..c9e5d9267 --- /dev/null +++ b/tests/test_concurrent_runner.py @@ -0,0 +1,159 @@ +"""Tests for ConcurrentInsertRunner against a running Milvus instance. + +Includes: + - Correctness tests (threading & async backends) + - Parameterized benchmark: serial vs concurrent across (batch_size, workers) matrix + +NUM_PER_BATCH is set via os.environ before each run. Since runners execute +task() in a spawn subprocess that re-imports config, the env var takes effect. + +Requires: + - Milvus running at localhost:19530 + - Network access to download OpenAI 50K dataset + +Usage: + pytest tests/test_concurrent_runner.py -v -s # correctness tests only + python tests/test_concurrent_runner.py # full benchmark matrix +""" + +# ruff: noqa: T201 + +from __future__ import annotations + +import logging +import os +import time + +from vectordb_bench.backend.clients import DB +from vectordb_bench.backend.clients.milvus.config import FLATConfig +from vectordb_bench.backend.dataset import Dataset, DatasetSource +from vectordb_bench.backend.runner.concurrent_runner import ConcurrentInsertRunner, ExecutorBackend +from vectordb_bench.backend.runner.serial_runner import SerialInsertRunner + +log = logging.getLogger("vectordb_bench") +log.setLevel(logging.INFO) + +DATASET_SIZE = 50_000 + + +# ── Shared helpers ────────────────────────────────────────────────────── + + +def get_milvus_db(collection_name: str): + return DB.Milvus.init_cls( + dim=1536, + db_config={"uri": "http://localhost:19530", "user": "", "password": ""}, + db_case_config=FLATConfig(metric_type="COSINE"), + collection_name=collection_name, + drop_old=True, + ) + + +def prepare_dataset(): + dataset = Dataset.OPENAI.manager(DATASET_SIZE) + dataset.prepare(DatasetSource.AliyunOSS) + return dataset + + +def set_batch_size(batch_size: int) -> None: + os.environ["NUM_PER_BATCH"] = str(batch_size) + + +def timed_run(runner: SerialInsertRunner | ConcurrentInsertRunner) -> tuple[int, float]: + start = time.perf_counter() + count = runner.run() + return count, time.perf_counter() - start + + +# ── Correctness tests (pytest) ────────────────────────────────────────── + + +def test_concurrent_insert_threading(): + """Test concurrent insert with threading backend.""" + db = get_milvus_db("test_conc_threading") + runner = ConcurrentInsertRunner( + db=db, + dataset=prepare_dataset(), + normalize=False, + max_workers=4, + backend=ExecutorBackend.THREADING, + ) + count = runner.run() + assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}" + + +def test_concurrent_insert_async(): + """Test concurrent insert with async backend.""" + db = get_milvus_db("test_conc_async") + runner = ConcurrentInsertRunner( + db=db, + dataset=prepare_dataset(), + normalize=False, + max_workers=4, + backend=ExecutorBackend.ASYNC, + ) + count = runner.run() + assert count == DATASET_SIZE, f"Expected {DATASET_SIZE}, got {count}" + + +# ── Parameterized benchmark ──────────────────────────────────────────── + + +def run_serial(batch_size: int) -> tuple[int, float]: + set_batch_size(batch_size) + runner = SerialInsertRunner( + db=get_milvus_db(f"bench_serial_b{batch_size}"), + dataset=prepare_dataset(), + normalize=False, + ) + return timed_run(runner) + + +def run_concurrent(batch_size: int, workers: int) -> tuple[int, float]: + set_batch_size(batch_size) + runner = ConcurrentInsertRunner( + db=get_milvus_db(f"bench_conc_b{batch_size}_w{workers}"), + dataset=prepare_dataset(), + normalize=False, + max_workers=workers, + backend=ExecutorBackend.THREADING, + ) + return timed_run(runner) + + +def bench_matrix(): + batch_sizes = [100, 500, 1000, 5000] + worker_counts = [1, 2, 4, 8] + + conc_headers = [f"conc({w}w)" for w in worker_counts] + speedup_headers = [f"speedup({w}w)" for w in worker_counts] + print(f"\n{'Batch':>6} {'#Bat':>5} {'serial':>8}", end="") + for h in conc_headers: + print(f" {h:>10}", end="") + for h in speedup_headers: + print(f" {h:>12}", end="") + print() + print("-" * (22 + 10 * len(worker_counts) + 12 * len(worker_counts))) + + for bs in batch_sizes: + n_batches = DATASET_SIZE // bs + _, dur_s = run_serial(bs) + + conc_durs = [] + for w in worker_counts: + _, dur_c = run_concurrent(bs, w) + conc_durs.append(dur_c) + + print(f"{bs:>6} {n_batches:>5} {dur_s:>7.2f}s", end="") + for dur_c in conc_durs: + print(f" {dur_c:>9.2f}s", end="") + for dur_c in conc_durs: + print(f" {dur_s / dur_c:>11.2f}x", end="") + print() + + # restore default + set_batch_size(100) + + +if __name__ == "__main__": + bench_matrix() diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py index 07f77bb02..fc1813b38 100644 --- a/vectordb_bench/__init__.py +++ b/vectordb_bench/__init__.py @@ -20,6 +20,7 @@ class config: DATASET_SOURCE = env.str("DATASET_SOURCE", "S3") # Options "S3" or "AliyunOSS" DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset") NUM_PER_BATCH = env.int("NUM_PER_BATCH", 100) + LOAD_CONCURRENCY = env.int("LOAD_CONCURRENCY", 0) # 0 = cpu_count TIME_PER_BATCH = 1 # 1s. for streaming insertion. MAX_INSERT_RETRY = 5 MAX_SEARCH_RETRY = 5 diff --git a/vectordb_bench/backend/clients/alisql/alisql.py b/vectordb_bench/backend/clients/alisql/alisql.py index f88cf9d88..6d1fbaefe 100644 --- a/vectordb_bench/backend/clients/alisql/alisql.py +++ b/vectordb_bench/backend/clients/alisql/alisql.py @@ -107,15 +107,13 @@ def init(self): self.cursor.execute(f"SET SESSION vidx_hnsw_ef_search = {search_param['ef_search']}") self.cursor.execute("COMMIT") - self.insert_sql = ( - f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)' # noqa: S608 - ) + self.insert_sql = f'INSERT INTO {self.db_config["database"]}.{self.table_name} (id, v) VALUES (%s, %s)' self.select_sql = ( - f'SELECT id FROM {self.db_config["database"]}.{self.table_name} ' # noqa: S608 + f'SELECT id FROM {self.db_config["database"]}.{self.table_name} ' f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s" ) self.select_sql_with_filter = ( - f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s ' # noqa: S608 + f'SELECT id FROM {self.db_config["database"]}.{self.table_name} WHERE id >= %s ' f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %s" ) diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index 82eda1824..80709e8e3 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -140,6 +140,11 @@ class VectorDB(ABC): supported_filter_types: list[FilterOp] = [FilterOp.NonFilter] name: str = "" + # Whether the client can share a single connection across threads. + # If False, concurrent runners will deep-copy the instance and call + # init() per thread instead of sharing the parent connection. + thread_safe: bool = True + @classmethod def filter_supported(cls, filters: Filter) -> bool: """Ensure that the filters are supported before testing filtering cases.""" diff --git a/vectordb_bench/backend/clients/doris/doris.py b/vectordb_bench/backend/clients/doris/doris.py index 82b3a12da..01984d665 100644 --- a/vectordb_bench/backend/clients/doris/doris.py +++ b/vectordb_bench/backend/clients/doris/doris.py @@ -13,6 +13,8 @@ class Doris(VectorDB): + thread_safe: bool = False + def __init__( self, dim: int, diff --git a/vectordb_bench/backend/clients/mariadb/mariadb.py b/vectordb_bench/backend/clients/mariadb/mariadb.py index db3863c85..e6053a0d8 100644 --- a/vectordb_bench/backend/clients/mariadb/mariadb.py +++ b/vectordb_bench/backend/clients/mariadb/mariadb.py @@ -108,13 +108,13 @@ def init(self): self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}") self.cursor.execute("COMMIT") - self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" # noqa: S608 + self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)" self.select_sql = ( - f"SELECT id FROM {self.db_name}.{self.table_name}" # noqa: S608 + f"SELECT id FROM {self.db_name}.{self.table_name}" f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d" ) self.select_sql_with_filter = ( - f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d " # noqa: S608 + f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d " f"ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d" ) diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index ead2979ff..9e9dfb7f9 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -137,6 +137,16 @@ def init(self): self.client.close() self.client = None + def _wait_for_segments_sorted(self): + while True: + segments = self.client.list_persistent_segments(self.collection_name) + unsorted = [s for s in segments if not s.is_sorted] + if not unsorted: + log.info(f"{self.name} all persistent segments are sorted.") + break + log.debug(f"{self.name} waiting for {len(unsorted)} segments to be sorted...") + time.sleep(5) + def _wait_for_index(self): while True: info = self.client.describe_index(self.collection_name, self._vector_index_name) @@ -155,6 +165,7 @@ def _optimize(self): log.info(f"{self.name} optimizing before search") try: self.client.flush(self.collection_name) + self._wait_for_segments_sorted() self._wait_for_index() if self.case_config.is_gpu_index: log.debug("skip force merge compaction for gpu index type.") diff --git a/vectordb_bench/backend/clients/oceanbase/oceanbase.py b/vectordb_bench/backend/clients/oceanbase/oceanbase.py index 93c42aac1..bf615e4d0 100644 --- a/vectordb_bench/backend/clients/oceanbase/oceanbase.py +++ b/vectordb_bench/backend/clients/oceanbase/oceanbase.py @@ -186,7 +186,7 @@ def insert_embeddings( batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)] values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch) self._cursor.execute( - f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}" # noqa: S608 + f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}" ) insert_count += len(batch) except mysql.Error: @@ -217,7 +217,7 @@ def search_embedding( packed = struct.pack(f"<{len(query)}f", *query) hex_vec = packed.hex() query_str = ( - f"SELECT id FROM {self.table_name} " # noqa: S608 + f"SELECT id FROM {self.table_name} " f"{self.expr} ORDER BY " f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') " f"APPROXIMATE LIMIT {k}" diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 42fa7533d..30c797c38 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -21,6 +21,7 @@ class PgVector(VectorDB): """Use psycopg instructions""" + thread_safe: bool = False supported_filter_types: list[FilterOp] = [ FilterOp.NonFilter, FilterOp.NumGE, diff --git a/vectordb_bench/backend/clients/tidb/tidb.py b/vectordb_bench/backend/clients/tidb/tidb.py index a5c99bbe4..fba60d41c 100644 --- a/vectordb_bench/backend/clients/tidb/tidb.py +++ b/vectordb_bench/backend/clients/tidb/tidb.py @@ -119,7 +119,7 @@ def _optimize_check_tiflash_replica_progress(self): cursor.execute(f""" SELECT PROGRESS FROM information_schema.tiflash_replica WHERE TABLE_SCHEMA = "{database}" AND TABLE_NAME = "{self.table_name}" - """) # noqa: S608 + """) result = cursor.fetchone() return result[0] except Exception as e: @@ -131,7 +131,7 @@ def _optimize_wait_tiflash_catch_up(self): with self._get_connection() as (conn, cursor): cursor.execute('SET @@TIDB_ISOLATION_READ_ENGINES="tidb,tiflash"') conn.commit() - cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") # noqa: S608 + cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") result = cursor.fetchone() return result[0] except Exception as e: @@ -155,7 +155,7 @@ def _optimize_get_tiflash_index_pending_rows(self): SELECT SUM(ROWS_STABLE_NOT_INDEXED) FROM information_schema.tiflash_indexes WHERE TIDB_DATABASE = "{database}" AND TIDB_TABLE = "{self.table_name}" - """) # noqa: S608 + """) result = cursor.fetchone() return result[0] except Exception as e: @@ -172,7 +172,7 @@ def _insert_embeddings_serial( try: with self._get_connection() as (conn, cursor): buf = io.StringIO() - buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") # noqa: S608 + buf.write(f"INSERT INTO {self.table_name} (id, embedding) VALUES ") for i in range(offset, offset + size): if i > offset: buf.write(",") @@ -220,6 +220,6 @@ def search_embedding( self.cursor.execute(f""" SELECT id FROM {self.table_name} ORDER BY {self.search_fn}(embedding, "{query!s}") LIMIT {k}; - """) # noqa: S608 + """) result = self.cursor.fetchall() return [int(i[0]) for i in result] diff --git a/vectordb_bench/backend/clients/vespa/vespa.py b/vectordb_bench/backend/clients/vespa/vespa.py index 5288bc04c..1f2e1b883 100644 --- a/vectordb_bench/backend/clients/vespa/vespa.py +++ b/vectordb_bench/backend/clients/vespa/vespa.py @@ -107,7 +107,7 @@ def search_embedding( embedding_field = "embedding" if self.case_config.quantization_type == "none" else "embedding_binary" yql = ( - f"select id from {self.schema_name} where " # noqa: S608 + f"select id from {self.schema_name} where " f"{{targetHits: {k}, hnsw.exploreAdditionalHits: {extra_ef}}}" f"nearestNeighbor({embedding_field}, query_embedding)" ) diff --git a/vectordb_bench/backend/runner/__init__.py b/vectordb_bench/backend/runner/__init__.py index 4af583773..d56fe0ff8 100644 --- a/vectordb_bench/backend/runner/__init__.py +++ b/vectordb_bench/backend/runner/__init__.py @@ -1,8 +1,10 @@ +from .concurrent_runner import ConcurrentInsertRunner from .mp_runner import MultiProcessingSearchRunner from .read_write_runner import ReadWriteRunner from .serial_runner import SerialInsertRunner, SerialSearchRunner __all__ = [ + "ConcurrentInsertRunner", "MultiProcessingSearchRunner", "ReadWriteRunner", "SerialInsertRunner", diff --git a/vectordb_bench/backend/runner/concurrent_runner.py b/vectordb_bench/backend/runner/concurrent_runner.py new file mode 100644 index 000000000..6ed8e39fb --- /dev/null +++ b/vectordb_bench/backend/runner/concurrent_runner.py @@ -0,0 +1,278 @@ +"""Concurrent insert runner with configurable executor backend. + +Replaces SerialInsertRunner for faster data loading in performance cases. + +Auto-detects thread-unsafe DBs via VectorDB.thread_safe and +falls back to single-worker mode. +""" + +from __future__ import annotations + +import concurrent.futures +import logging +import multiprocessing as mp +import threading +import time +from copy import deepcopy +from enum import StrEnum +from typing import TYPE_CHECKING + +import numpy as np + +from vectordb_bench.backend.filter import Filter, FilterOp, non_filter +from vectordb_bench.backend.utils import kill_proc_tree, time_it + +from ... import config +from ...models import PerformanceTimeoutError +from .executor import AsyncExecutor, ThreadExecutor + +if TYPE_CHECKING: + from vectordb_bench.backend.clients import api + from vectordb_bench.backend.dataset import DatasetManager + + from .executor import TaskExecutor + +log = logging.getLogger(__name__) + + +class ExecutorBackend(StrEnum): + THREADING = "threading" + ASYNC = "async" + + +class ConcurrentInsertRunner: + """Concurrent insert runner with pluggable executor backend. + + Thread-safety: If db.thread_safe is False, max_workers is clamped to 1 + and each worker thread gets a deep-copied DB instance with its own connection. + + Args: + db: VectorDB instance. + dataset: DatasetManager for batch iteration. + normalize: Whether to L2-normalize embeddings. + filters: Filter configuration. + timeout: Timeout in seconds for the overall operation. + max_workers: Number of concurrent workers (default: cpu_count). + backend: Executor backend to use ('threading' or 'async'). + """ + + def __init__( + self, + db: api.VectorDB, + dataset: DatasetManager, + normalize: bool, + filters: Filter = non_filter, + timeout: float | None = None, + max_workers: int | None = None, + backend: ExecutorBackend = ExecutorBackend.THREADING, + ): + self.timeout = timeout if isinstance(timeout, int | float) else None + self.dataset: DatasetManager = dataset + self.db = db + self.normalize = normalize + self.filters = filters + self.backend = backend + + effective_workers = max_workers or mp.cpu_count() + if not db.thread_safe: + log.info(f"DB {db.name} is not thread-safe, falling back to max_workers=1") + effective_workers = 1 + self.max_workers = effective_workers + + def __getstate__(self): + """Exclude unpicklable thread-local state for ProcessPoolExecutor(spawn).""" + state = self.__dict__.copy() + state.pop("_local", None) + state.pop("_ctx_lock", None) + state.pop("_thread_contexts", None) + state.pop("_iter_lock", None) + state.pop("_dataset_iter", None) + return state + + def __setstate__(self, state: dict): + self.__dict__.update(state) + self._local = threading.local() + self._ctx_lock = threading.Lock() + self._thread_contexts = [] + + def _create_executor(self) -> TaskExecutor: + if self.backend == ExecutorBackend.ASYNC: + return AsyncExecutor(max_workers=self.max_workers) + return ThreadExecutor(max_workers=self.max_workers) + + def _get_thread_db(self) -> api.VectorDB: + """Get or create a per-thread DB instance. + + Thread-safe DBs reuse self.db (connection opened in task()). + Non-thread-safe DBs get a deep-copied instance with its own connection, + cached in thread-local storage so it is created once per thread. + """ + if not hasattr(self._local, "db"): + if self.db.thread_safe: + self._local.db = self.db + else: + db = deepcopy(self.db) + # Manual __enter__/__exit__ because enter and exit happen in + # different scopes (here vs _cleanup_thread_contexts). + ctx = db.init() + ctx.__enter__() + self._local.db = db + with self._ctx_lock: + self._thread_contexts.append(ctx) + return self._local.db + + def _cleanup_thread_contexts(self) -> None: + """Close per-thread DB connections opened for non-thread-safe clients.""" + for ctx in self._thread_contexts: + try: + ctx.__exit__(None, None, None) + except Exception: + log.warning("Failed to close per-thread DB connection", exc_info=True) + self._thread_contexts.clear() + + def _insert_batch_with_retry( + self, + db: api.VectorDB, + embeddings: list[list[float]], + metadata: list[int], + labels_data: list[str] | None = None, + retry_idx: int = 0, + ) -> int: + """Insert a single batch with retry logic. Returns inserted count.""" + insert_count, error = db.insert_embeddings( + embeddings=embeddings, + metadata=metadata, + labels_data=labels_data, + ) + if error is not None: + log.warning(f"Insert failed, try_idx={retry_idx}, Exception: {error}") + retry_idx += 1 + if retry_idx <= config.MAX_INSERT_RETRY: + time.sleep(retry_idx) + return self._insert_batch_with_retry(db, embeddings, metadata, labels_data, retry_idx) + msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times" + raise RuntimeError(msg) + return insert_count + + def _worker_insert( + self, + embeddings: list[list[float]], + metadata: list[int], + labels_data: list[str] | None = None, + ) -> int: + """Worker function: insert a batch with retry. + + Thread-safe DBs: reuse self.db whose connection is already open + via task()'s `with self.db.init()` — all threads share it safely. + + Non-thread-safe DBs: use a per-thread deep-copied instance with + its own connection, cached via threading.local. + """ + db = self._get_thread_db() + return self._insert_batch_with_retry(db, embeddings, metadata, labels_data) + + def _next_batch(self) -> tuple[list[list[float]], list[int], list[str] | None] | None: + """Pull the next batch from the shared dataset iterator. + + Thread-safe: only one thread reads from the iterator at a time. + Returns None when the iterator is exhausted. + """ + with self._iter_lock: + try: + data_df = next(self._dataset_iter) + except StopIteration: + return None + + all_metadata = data_df[self.dataset.data.train_id_field].tolist() + emb_np = np.stack(data_df[self.dataset.data.train_vector_field]) + if self.normalize: + all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist() + else: + all_embeddings = emb_np.tolist() + del emb_np + + labels_data = None + if self.filters.type == FilterOp.StrEqual: + if self.dataset.data.scalar_labels_file_separated: + labels_data = self.dataset.scalar_labels[self.filters.label_field][all_metadata].to_list() + else: + labels_data = data_df[self.filters.label_field].tolist() + + return all_embeddings, all_metadata, labels_data + + def _worker_loop(self) -> int: + """Worker loop: pull batches from the shared iterator and insert them.""" + total = 0 + while True: + batch = self._next_batch() + if batch is None: + break + embeddings, metadata, labels_data = batch + total += self._worker_insert(embeddings, metadata, labels_data) + return total + + def task(self) -> int: + """Insert entire dataset using concurrent executor. Runs in subprocess.""" + count = 0 + self._local = threading.local() + self._ctx_lock = threading.Lock() + self._thread_contexts = [] + self._iter_lock = threading.Lock() + self._dataset_iter = iter(self.dataset) + + with self.db.init(): + log.info( + f"({mp.current_process().name:16}) Start concurrent insert, " + f"batch_size={config.NUM_PER_BATCH}, max_workers={self.max_workers}" + ) + start = time.perf_counter() + + try: + with self._create_executor() as executor: + for _ in range(self.max_workers): + executor.submit(self._worker_loop) + + batch_results = executor.wait_all() + + # Log all errors, then raise the first one + errors = [r.error for r in batch_results if r.error is not None] + if errors: + for err in errors: + log.warning(f"Batch insert error: {err}") + raise errors[0] + + count = sum(r.value for r in batch_results) + finally: + self._cleanup_thread_contexts() + + log.info( + f"({mp.current_process().name:16}) Finish concurrent insert, " + f"count={count}, dur={time.perf_counter() - start:.2f}s" + ) + return count + + @time_it + def _insert_all_batches(self) -> int: + """Performance case only: run task() in subprocess with timeout.""" + with concurrent.futures.ProcessPoolExecutor( + mp_context=mp.get_context("spawn"), + max_workers=1, + ) as executor: + future = executor.submit(self.task) + try: + count = future.result(timeout=self.timeout) + except TimeoutError as e: + msg = f"VectorDB load dataset timeout in {self.timeout}" + log.warning(msg) + kill_proc_tree(pids=list(executor._processes.keys())) + raise PerformanceTimeoutError(msg) from e + except Exception as e: + log.warning(f"VectorDB load dataset error: {e}") + raise e from e + else: + return count + + def run(self) -> int: + """Insert full dataset concurrently. Returns total inserted count.""" + count, _ = self._insert_all_batches() + return count diff --git a/vectordb_bench/backend/runner/executor.py b/vectordb_bench/backend/runner/executor.py new file mode 100644 index 000000000..0bff2dc2a --- /dev/null +++ b/vectordb_bench/backend/runner/executor.py @@ -0,0 +1,170 @@ +"""Task executor abstraction with threading and async backends. + +Provides a unified interface for submitting callables with controlled +concurrency. Two implementations: + - ThreadExecutor: backed by ThreadPoolExecutor + - AsyncExecutor: backed by asyncio with semaphore-based concurrency control +""" + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC, abstractmethod +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + +log = logging.getLogger(__name__) + + +@dataclass +class TaskResult: + """Result of a single submitted task.""" + + value: Any = None + error: Exception | None = None + + @property + def success(self) -> bool: + return self.error is None + + +class TaskExecutor(ABC): + """Abstract executor that accepts callables and controls concurrency.""" + + @abstractmethod + def start(self) -> None: + """Initialize executor resources.""" + raise NotImplementedError + + @abstractmethod + def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + """Submit a task for execution.""" + raise NotImplementedError + + @abstractmethod + def wait_all(self) -> list[TaskResult]: + """Block until all submitted tasks complete. Return results in submission order.""" + raise NotImplementedError + + @abstractmethod + def shutdown(self) -> None: + """Release executor resources. Safe to call multiple times.""" + raise NotImplementedError + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> bool: + self.shutdown() + return False + + +class ThreadExecutor(TaskExecutor): + """ThreadPoolExecutor-backed implementation.""" + + def __init__(self, max_workers: int): + self._max_workers = max(1, max_workers) + self._executor: ThreadPoolExecutor | None = None + self._futures: list[Future] = [] + + def start(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=self._max_workers) + self._futures = [] + + def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + if self._executor is None: + raise RuntimeError("Executor not started. Call start() or use as context manager.") + future = self._executor.submit(fn, *args, **kwargs) + self._futures.append(future) + + def wait_all(self) -> list[TaskResult]: + results = [] + for future in self._futures: + try: + value = future.result() + results.append(TaskResult(value=value)) + except Exception as e: + results.append(TaskResult(error=e)) + self._futures = [] + return results + + def shutdown(self) -> None: + if self._executor is not None: + self._executor.shutdown(wait=True) + self._executor = None + + +class AsyncExecutor(TaskExecutor): + """asyncio-backed implementation for async DB clients. + + Accepts coroutine functions (async def), runs them on a single event + loop thread with semaphore-based concurrency control. No thread pool. + """ + + def __init__(self, max_workers: int): + self._max_workers = max(1, max_workers) + self._loop: asyncio.AbstractEventLoop | None = None + self._semaphore: asyncio.Semaphore | None = None + self._coros: list = [] + self._owns_loop = False + + def start(self) -> None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + self._owns_loop = True + self._semaphore = asyncio.Semaphore(self._max_workers) + self._coros = [] + + def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + """Submit a callable for execution. + + Accepts both coroutine functions (async def) and regular functions. + Sync functions are offloaded to a thread via run_in_executor. + """ + if self._loop is None or self._semaphore is None: + raise RuntimeError("Executor not started. Call start() or use as context manager.") + + async def _run(): + async with self._semaphore: + if asyncio.iscoroutinefunction(fn): + return await fn(*args, **kwargs) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: fn(*args, **kwargs)) + + self._coros.append(_run()) + + def wait_all(self) -> list[TaskResult]: + if not self._coros: + return [] + + async def _gather(): + gathered = await asyncio.gather(*self._coros, return_exceptions=True) + results = [] + for item in gathered: + if isinstance(item, Exception): + results.append(TaskResult(error=item)) + else: + results.append(TaskResult(value=item)) + return results + + if self._owns_loop: + results = self._loop.run_until_complete(_gather()) + else: + results = asyncio.run_coroutine_threadsafe(_gather(), self._loop).result() + + self._coros = [] + return results + + def shutdown(self) -> None: + if self._owns_loop and self._loop is not None: + self._loop.close() + self._loop = None + self._semaphore = None diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 300553a4e..be0c6322d 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -1,4 +1,4 @@ -import concurrent +import concurrent.futures import logging import math import multiprocessing as mp @@ -6,14 +6,13 @@ import traceback import numpy as np -import psutil from vectordb_bench.backend.dataset import DatasetManager -from vectordb_bench.backend.filter import Filter, FilterOp, non_filter +from vectordb_bench.backend.filter import Filter, non_filter from ... import config from ...metric import calc_ndcg, calc_recall, get_ideal_dcg -from ...models import LoadTimeoutError, PerformanceTimeoutError +from ...models import LoadTimeoutError from .. import utils from ..clients import api @@ -38,66 +37,6 @@ def __init__( self.normalize = normalize self.filters = filters - def retry_insert(self, db: api.VectorDB, retry_idx: int = 0, **kwargs): - _, error = db.insert_embeddings(**kwargs) - if error is not None: - log.warning(f"Insert Failed, try_idx={retry_idx}, Exception: {error}") - retry_idx += 1 - if retry_idx <= config.MAX_INSERT_RETRY: - time.sleep(retry_idx) - self.retry_insert(db, retry_idx=retry_idx, **kwargs) - else: - msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times" - raise RuntimeError(msg) from None - - def task(self) -> int: - count = 0 - with self.db.init(): - log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}") - start = time.perf_counter() - for data_df in self.dataset: - all_metadata = data_df[self.dataset.data.train_id_field].tolist() - - emb_np = np.stack(data_df[self.dataset.data.train_vector_field]) - if self.normalize: - log.debug("normalize the 100k train data") - all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist() - else: - all_embeddings = emb_np.tolist() - del emb_np - log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}") - - labels_data = None - if self.filters.type == FilterOp.StrEqual: - if self.dataset.data.scalar_labels_file_separated: - labels_data = self.dataset.scalar_labels[self.filters.label_field][all_metadata].to_list() - else: - labels_data = data_df[self.filters.label_field].tolist() - - insert_count, error = self.db.insert_embeddings( - embeddings=all_embeddings, - metadata=all_metadata, - labels_data=labels_data, - ) - if error is not None: - self.retry_insert( - self.db, - embeddings=all_embeddings, - metadata=all_metadata, - labels_data=labels_data, - ) - - assert insert_count == len(all_metadata) - count += insert_count - if count % 100_000 == 0: - log.info(f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB") - - log.info( - f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, " - f"dur={time.perf_counter() - start}" - ) - return count - def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: int = 0) -> int: with self.db.init(): # unique id for endlessness insertion @@ -147,28 +86,6 @@ def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: ) return count - @utils.time_it - def _insert_all_batches(self) -> int: - """Performance case only""" - with concurrent.futures.ProcessPoolExecutor( - mp_context=mp.get_context("spawn"), - max_workers=1, - ) as executor: - future = executor.submit(self.task) - try: - count = future.result(timeout=self.timeout) - except TimeoutError as e: - msg = f"VectorDB load dataset timeout in {self.timeout}" - log.warning(msg) - for pid, _ in executor._processes.items(): - psutil.Process(pid).kill() - raise PerformanceTimeoutError(msg) from e - except Exception as e: - log.warning(f"VectorDB load dataset error: {e}") - raise e from e - else: - return count - def run_endlessness(self) -> int: """run forever util DB raises exception or crash""" # datasets for load tests are quite small, can fit into memory @@ -204,10 +121,6 @@ def run_endlessness(self) -> int: else: raise LoadTimeoutError(self.timeout) - def run(self) -> int: - count, _ = self._insert_all_batches() - return count - class SerialSearchRunner: def __init__( diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 8224a0415..6b51d1277 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -6,7 +6,6 @@ from enum import Enum, auto import numpy as np -import psutil from ..base import BaseModel from ..metric import Metric @@ -15,7 +14,14 @@ from .cases import Case, CaseLabel, StreamingPerformanceCase from .clients import DB, MetricType, api from .data_source import DatasetSource -from .runner import MultiProcessingSearchRunner, ReadWriteRunner, SerialInsertRunner, SerialSearchRunner +from .runner import ( + ConcurrentInsertRunner, + MultiProcessingSearchRunner, + ReadWriteRunner, + SerialInsertRunner, + SerialSearchRunner, +) +from .utils import kill_proc_tree log = logging.getLogger(__name__) @@ -241,14 +247,15 @@ def _run_streaming_case(self) -> Metric: @utils.time_it def _load_train_data(self): - """Insert train data and get the insert_duration""" + """Insert train data concurrently and get the insert_duration""" try: - runner = SerialInsertRunner( + runner = ConcurrentInsertRunner( self.db, self.ca.dataset, self.normalize, self.ca.filters, self.ca.load_timeout, + max_workers=self.config.load_concurrency or None, ) runner.run() except Exception as e: @@ -299,8 +306,7 @@ def _optimize(self) -> float: return future.result(timeout=self.ca.optimize_timeout)[1] except TimeoutError as e: log.warning(f"VectorDB optimize timeout in {self.ca.optimize_timeout}") - for pid, _ in executor._processes.items(): - psutil.Process(pid).kill() + kill_proc_tree(pids=list(executor._processes.keys())) raise PerformanceTimeoutError from e except Exception as e: log.warning(f"VectorDB optimize error: {e}") diff --git a/vectordb_bench/backend/utils.py b/vectordb_bench/backend/utils.py index 86c4faf5e..432f0d1d1 100644 --- a/vectordb_bench/backend/utils.py +++ b/vectordb_bench/backend/utils.py @@ -1,6 +1,47 @@ +import contextlib +import logging +import signal import time from functools import wraps +import psutil + +log = logging.getLogger(__name__) + + +def kill_proc_tree(pids: list[int] | None = None, grace: float = 2, timeout: float = 3): + """Kill child processes with SIGTERM, then SIGKILL for survivors. + + Args: + pids: Specific PIDs to kill. If None, kills all children of the + current process (recursive). + grace: Seconds to wait after SIGTERM before sending SIGKILL. + timeout: Seconds to wait for processes to fully exit after SIGKILL. + """ + if pids is not None: + targets = [] + for pid in pids: + with contextlib.suppress(psutil.NoSuchProcess): + targets.append(psutil.Process(pid)) + else: + targets = psutil.Process().children(recursive=True) + + for p in targets: + try: + log.warning(f"sending SIGTERM to child process: {p}") + p.send_signal(signal.SIGTERM) + except psutil.NoSuchProcess: + pass + + _, alive = psutil.wait_procs(targets, timeout=grace) + for p in alive: + try: + log.warning(f"force killing child process: {p}") + p.kill() + except psutil.NoSuchProcess: + pass + psutil.wait_procs(alive, timeout=timeout) + def numerize(n: int) -> str: """display positive number n for readability diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index 12bb4be9b..94b13762a 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -1,7 +1,6 @@ import logging import time from collections.abc import Callable -from concurrent.futures import wait from datetime import datetime from pathlib import Path from pprint import pformat @@ -20,7 +19,7 @@ from .. import config from ..backend.clients import DB from ..backend.clients.api import MetricType -from ..interface import benchmark_runner, global_result_future +from ..interface import benchmark_runner from ..models import ( CaseConfig, CaseType, @@ -231,6 +230,16 @@ class CommonTypedDict(TypedDict): show_default=True, ), ] + load_concurrency: Annotated[ + int, + click.option( + "--load-concurrency", + type=int, + default=config.LOAD_CONCURRENCY, + show_default=True, + help="Number of concurrent workers for data loading in performance cases (0 = cpu_count)", + ), + ] search_serial: Annotated[ bool, click.option( @@ -643,15 +652,16 @@ def run( parameters["search_serial"], parameters["search_concurrent"], ), + load_concurrency=parameters["load_concurrency"], ) task_label = parameters["task_label"] log.info(f"Task:\n{pformat(task)}\n") if not parameters["dry_run"]: benchmark_runner.run([task], task_label) - time.sleep(5) - if global_result_future: - wait([global_result_future]) - - while benchmark_runner.has_running(): - time.sleep(1) + try: + while benchmark_runner.has_running(): + time.sleep(1) + except KeyboardInterrupt: + log.warning("Ctrl+C received, stopping benchmark...") + benchmark_runner.stop_running() diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py index 01d0c5876..e5c2a1e42 100644 --- a/vectordb_bench/frontend/components/run_test/submitTask.py +++ b/vectordb_bench/frontend/components/run_test/submitTask.py @@ -61,11 +61,17 @@ def advancedSettings(st): "Concurrency Duration", value=config.CONCURRENCY_DURATION, label_visibility="collapsed" ) container[1].caption("concurrency duration for each concurrency search test") - return index_already_exists, use_aliyun, k, concurrentInput, concurrency_duration + + container = st.columns([1, 2]) + load_concurrency = container[0].number_input( + "Load Concurrency", min_value=0, value=config.LOAD_CONCURRENCY, label_visibility="collapsed" + ) + container[1].caption("number of concurrent workers for data loading in performance cases (0 = cpu_count)") + return index_already_exists, use_aliyun, k, concurrentInput, concurrency_duration, load_concurrency def controlPanel(st, tasks: list[TaskConfig], taskLabel, isAllValid): - index_already_exists, use_aliyun, k, concurrentInput, concurrency_duration = advancedSettings(st) + index_already_exists, use_aliyun, k, concurrentInput, concurrency_duration, load_concurrency = advancedSettings(st) def runHandler(): benchmark_runner.set_drop_old(not index_already_exists) @@ -80,6 +86,7 @@ def runHandler(): task.case_config.k = k task.case_config.concurrency_search_config.num_concurrency = concurrentInput_list task.case_config.concurrency_search_config.concurrency_duration = concurrency_duration + task.load_concurrency = load_concurrency benchmark_runner.set_download_address(use_aliyun) benchmark_runner.run(tasks, taskLabel) diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py index 42dc876b0..0d4119e93 100644 --- a/vectordb_bench/interface.py +++ b/vectordb_bench/interface.py @@ -2,20 +2,17 @@ import logging import multiprocessing as mp import pathlib -import signal import traceback import uuid -from collections.abc import Callable from enum import Enum from multiprocessing.connection import Connection -import psutil - from . import config from .backend.assembler import Assembler, FilterNotSupportedError from .backend.data_source import DatasetSource from .backend.result_collector import ResultCollector from .backend.task_runner import TaskRunner +from .backend.utils import kill_proc_tree from .metric import Metric from .models import ( CaseResult, @@ -240,7 +237,7 @@ def _clear_running_task(self): for r in self.running_task.case_runners: r.stop() - self.kill_proc_tree(timeout=5) + kill_proc_tree() self.running_task = None if self.receive_conn: @@ -261,29 +258,5 @@ def _run_async(self, conn: Connection) -> bool: return True - def kill_proc_tree( - self, - sig: int = signal.SIGTERM, - timeout: float | None = None, - on_terminate: Callable | None = None, - ): - """Kill a process tree (including grandchildren) with signal - "sig" and return a (gone, still_alive) tuple. - "on_terminate", if specified, is a callback function which is - called as soon as a child terminates. - """ - children = psutil.Process().children(recursive=True) - for p in children: - try: - log.warning(f"sending SIGTERM to child process: {p}") - p.send_signal(sig) - except psutil.NoSuchProcess: - pass - _, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate) - - for p in alive: - log.warning(f"force killing child process: {p}") - p.kill() - benchmark_runner = BenchMarkRunner() diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 27be85b6e..ffe8abd56 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -238,6 +238,7 @@ class TaskConfig(BaseModel): db_case_config: DBCaseConfig case_config: CaseConfig stages: list[TaskStage] = ALL_TASK_STAGES + load_concurrency: int = config.LOAD_CONCURRENCY @property def db_name(self):