Commit d8765b2

fix(pgvector): fix ConcurrentInsertRunner for non-thread-safe DBs
For non-thread-safe DBs (e.g. PgVector), ConcurrentInsertRunner clamps
max_workers to 1, so there is always exactly one worker thread. There is no
need to deepcopy self.db per thread — the single worker can use self.db
directly via the connection already opened by task()'s `with self.db.init():`.

The original code called deepcopy(self.db) inside _get_thread_db() after
task() had already opened a live psycopg C-extension Connection on self.db.
C-extension objects cannot be deep-copied, causing:

    TypeError: no default __reduce__ due to non-trivial __cinit__

Fix: remove the deepcopy branch entirely. All workers (thread-safe or not)
now use self.db directly; thread-safety is guaranteed for non-thread-safe
DBs by the max_workers=1 clamp.

Also clean up stale comments in pgvector.py left over from
zilliztech#760/zilliztech#763.

Adds tests/test_pgvector.py with:

- unit test that reproduces the bug (fails on original, passes on fix)
- e2e regression test via ConcurrentInsertRunner + OpenAI 50K dataset

See also: zilliztech#756

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
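The failure is reproducible outside the runner entirely. A minimal sketch, assuming a reachable Postgres and psycopg 3 installed with its C extension (the DSN below is a placeholder matching the test config):

from copy import deepcopy

import psycopg

conn = psycopg.connect("host=localhost dbname=vectordb user=vectordb password=vectordb")
try:
    # Deep-copying any object graph that contains a live C-extension
    # Connection fails: the Cython type defines no __reduce__, so the
    # copy machinery cannot reconstruct it.
    deepcopy(conn)
except TypeError as e:
    print(e)  # no default __reduce__ due to non-trivial __cinit__
finally:
    conn.close()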
1 parent b3613ff commit d8765b2

4 files changed

Lines changed: 211 additions & 77 deletions

tests/pytest.ini

Lines changed: 3 additions & 0 deletions
@@ -3,3 +3,6 @@
 filterwarnings =
     ignore::UserWarning
     ignore::DeprecationWarning
+
+markers =
+    integration: tests that require external services or network access (deselect with -m "not integration")
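With the marker registered, the suite splits cleanly into fast unit tests and the Docker-backed e2e test (commands follow the marker description above):

pytest tests/ -m "not integration"              # skip tests needing external services
pytest tests/test_pgvector.py -m integration    # run only the e2e regression test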

tests/test_pgvector.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
"""Tests for PgVector client and ConcurrentInsertRunner.

Reproduces issue #756: insert fails with
    TypeError: no default __reduce__ due to non-trivial __cinit__
when ConcurrentInsertRunner deep-copies a PgVector instance that has a live
psycopg connection open (the connection is opened by `with self.db.init():`
inside task() before the deepcopy in _get_thread_db()).

Requires:
    docker run -d --name pgvector-test \
        -e POSTGRES_USER=vectordb -e POSTGRES_PASSWORD=vectordb \
        -e POSTGRES_DB=vectordb -p 5432:5432 \
        pgvector/pgvector:pg17

Usage:
    pytest tests/test_pgvector.py -v -s
"""

from __future__ import annotations

import logging
import pickle
from unittest.mock import MagicMock

import numpy as np
import pytest

from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.pgvector.config import PgVectorHNSWConfig
from vectordb_bench.backend.dataset import Dataset, DatasetSource
from vectordb_bench.backend.filter import Filter, FilterOp, non_filter
from vectordb_bench.backend.runner.concurrent_runner import ConcurrentInsertRunner

log = logging.getLogger(__name__)

# ── Connection config ────────────────────────────────────────────────────────

DB_CONFIG = {
    "connect_config": {
        "host": "localhost",
        "port": 5432,
        "dbname": "vectordb",
        "user": "vectordb",
        "password": "vectordb",
    },
    "table_name": "test_pgvector",
}

DIM = 128
COUNT = 500
RNG = np.random.default_rng(42)


# ── Helpers ──────────────────────────────────────────────────────────────────


def make_hnsw_config(**kwargs) -> PgVectorHNSWConfig:
    return PgVectorHNSWConfig(
        metric_type="COSINE",
        m=16,
        ef_construction=64,
        ef_search=64,
        **kwargs,
    )


def make_db(table_name: str = "test_pgvector", drop_old: bool = True) -> DB.PgVector.init_cls:
    cfg = dict(DB_CONFIG)
    cfg["table_name"] = table_name
    return DB.PgVector.init_cls(
        dim=DIM,
        db_config=cfg,
        db_case_config=make_hnsw_config(),
        drop_old=drop_old,
    )


def random_embeddings(n: int = COUNT, d: int = DIM) -> list[list[float]]:
    return RNG.random((n, d)).tolist()


# ── Basic client tests ────────────────────────────────────────────────────────


class TestPgVectorBasic:
    """Unit tests for the PgVector client (no subprocess)."""

    def test_insert_and_search(self):
        db = make_db("test_basic")
        embeddings = random_embeddings()
        metadata = list(range(COUNT))

        with db.init():
            count, err = db.insert_embeddings(embeddings=embeddings, metadata=metadata)
            assert err is None, f"Insert error: {err}"
            assert count == COUNT

        with db.init():
            db.optimize()

        with db.init():
            db.prepare_filter(Filter(type=FilterOp.NonFilter))
            results = db.search_embedding(query=embeddings[0], k=10)
            assert len(results) > 0

    def test_db_is_not_thread_safe(self):
        db = make_db("test_thread_safe")
        assert db.thread_safe is False

    def test_db_picklable_after_init(self):
        """PgVector instance must be picklable after __init__ (conn/cursor are None).

        This is required for ConcurrentInsertRunner which spawns a subprocess
        and pickles self (which includes self.db).
        """
        db = make_db("test_pickle")
        data = pickle.dumps(db)
        db2 = pickle.loads(data)  # noqa: S301
        assert db2.dim == DIM

    def test_get_thread_db_with_open_connection(self):
        """Regression test for issue #756.

        ConcurrentInsertRunner.task() opens `with self.db.init()` before calling
        workers. For non-thread-safe DBs the original _get_thread_db() then called
        deepcopy(self.db) — but the live psycopg C-extension Connection is not
        deep-copyable, causing TypeError.

        Fixed code returns self.db directly (no deepcopy), so this test must pass
        without raising.
        """
        db = make_db("test_get_thread_db")
        runner = ConcurrentInsertRunner(db=db, dataset=MagicMock(), normalize=False)

        with db.init():
            assert db.conn is not None
            result = runner._get_thread_db()  # TypeError here on original code

        assert result is db


# ── ConcurrentInsertRunner tests ──────────────────────────────────────────────


class TestPgVectorConcurrentInsert:
    """Tests for ConcurrentInsertRunner with PgVector (reproduces issue #756)."""

    @pytest.mark.integration
    def test_concurrent_insert_e2e(self):
        """E2E regression test for issue #756 using the OpenAI 50K dataset.

        Exercises the full pipeline:
            ProcessPoolExecutor(spawn) → pickle runner → subprocess task()
            → with self.db.init() → worker _get_thread_db() → insert batches

        FAILS on original code (TypeError: deepcopy of live psycopg connection).
        PASSES on fixed code.
        """
        dataset = Dataset.OPENAI.manager(50_000)
        dataset.prepare(DatasetSource.AliyunOSS)

        cfg = dict(DB_CONFIG)
        cfg["table_name"] = "test_e2e_insert"
        db = DB.PgVector.init_cls(
            dim=dataset.data.dim,
            db_config=cfg,
            db_case_config=PgVectorHNSWConfig(
                metric_type="COSINE",
                m=16,
                ef_construction=64,
                ef_search=64,
            ),
            drop_old=True,
        )

        runner = ConcurrentInsertRunner(db=db, dataset=dataset, normalize=True, filters=non_filter)
        count = runner.run()

        assert count == 50_000, f"Expected 50000 rows, got {count}"
        log.info(f"E2E insert completed: {count} rows")

vectordb_bench/backend/clients/pgvector/pgvector.py

Lines changed: 3 additions & 12 deletions
@@ -335,16 +335,9 @@ def _create_index(self):
 
         index_param = self.case_config.index_param()
         self._set_parallel_index_build_param()
-        # [FIX] The index access method name registered by the PostgreSQL pgvector extension is in
-        # lowercase (e.g., "hnsw", "ivfflat"), but the index type passed from the frontend UI is
-        # uppercase "HNSW" via IndexType.HNSW.value, causing SQL syntax "USING 'HNSW'" to fail
-        # with error "access method HNSW does not exist". Here we uniformly convert it to lowercase
-        # to match PostgreSQL's access method name.
+        # pgvector registers access methods in lowercase ("hnsw", "ivfflat") but
+        # IndexType enum values are uppercase; also IVFFlat maps to "ivfflat" (no underscore).
         index_type_lower = index_param["index_type"].lower()
-        # [FIX] The pgvector access method name is "ivfflat" (no underscore), but IndexType.IVFFlat.value
-        # produces "IVF_FLAT" which becomes "ivf_flat" after lowercase conversion, causing SQL syntax
-        # "USING 'ivf_flat'" to fail with error "access method 'ivf_flat' does not exist".
-        # Here we map "ivf_flat" → "ivfflat" to match PostgreSQL pgvector's registered access method name.
         if index_type_lower == "ivf_flat":
             index_type_lower = "ivfflat"
         log.info(f"index_type (original={index_param['index_type']}, normalized={index_type_lower})")
@@ -374,9 +367,8 @@ def _create_index(self):
                 if index_param["quantization_type"] == "bit"
                 else sql.Identifier("embedding")
             ),
-            # [FIX] Use lowercase index_type_lower instead of original index_param["index_type"]
             index_type=sql.Identifier(index_type_lower),
-            # This assumes that the quantization_type value matches the quantization function name
+            # quantization_type value matches the quantization function name
             quantization_type=sql.SQL(index_param["quantization_type"]),
             dim=self.dim,
             embedding_metric=sql.Identifier(index_param["metric"]),
@@ -390,7 +382,6 @@ def _create_index(self):
         ).format(
             index_name=sql.Identifier(self._index_name),
             table_name=sql.Identifier(self.table_name),
-            # [FIX] Use lowercase index_type_lower instead of original index_param["index_type"]
             index_type=sql.Identifier(index_type_lower),
             embedding_metric=sql.Identifier(index_param["metric"]),
         )
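Pulled out of the diff for readability, the retained normalization is a two-step mapping. A minimal sketch (standalone function for illustration; the real code inlines these lines in _create_index):

def normalize_index_type(index_type: str) -> str:
    # pgvector registers access methods in lowercase ("hnsw", "ivfflat").
    index_type_lower = index_type.lower()  # "HNSW" -> "hnsw", "IVF_FLAT" -> "ivf_flat"
    # pgvector's IVF access method name has no underscore.
    if index_type_lower == "ivf_flat":
        index_type_lower = "ivfflat"
    return index_type_lower


assert normalize_index_type("HNSW") == "hnsw"
assert normalize_index_type("IVF_FLAT") == "ivfflat"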

vectordb_bench/backend/runner/concurrent_runner.py

Lines changed: 25 additions & 65 deletions
@@ -13,7 +13,6 @@
 import multiprocessing as mp
 import threading
 import time
-from copy import deepcopy
 from enum import StrEnum
 from typing import TYPE_CHECKING
 
@@ -44,7 +43,7 @@ class ConcurrentInsertRunner:
     """Concurrent insert runner with pluggable executor backend.
 
     Thread-safety: If db.thread_safe is False, max_workers is clamped to 1
-    and each worker thread gets a deep-copied DB instance with its own connection.
+    so the single worker thread uses self.db directly (no deepcopy needed).
 
     Args:
         db: VectorDB instance.
@@ -78,57 +77,31 @@ def __init__(
             log.info(f"DB {db.name} is not thread-safe, falling back to max_workers=1")
             effective_workers = 1
         self.max_workers = effective_workers
+        assert db.thread_safe or self.max_workers == 1, (
+            "Non-thread-safe DBs must use max_workers=1 — "
+            "_get_thread_db() relies on this to avoid concurrent access to self.db"
+        )
 
     def __getstate__(self):
         """Exclude unpicklable thread-local state for ProcessPoolExecutor(spawn)."""
         state = self.__dict__.copy()
-        state.pop("_local", None)
-        state.pop("_ctx_lock", None)
-        state.pop("_thread_contexts", None)
         state.pop("_iter_lock", None)
         state.pop("_dataset_iter", None)
         return state
 
-    def __setstate__(self, state: dict):
-        self.__dict__.update(state)
-        self._local = threading.local()
-        self._ctx_lock = threading.Lock()
-        self._thread_contexts = []
-
     def _create_executor(self) -> TaskExecutor:
         if self.backend == ExecutorBackend.ASYNC:
             return AsyncExecutor(max_workers=self.max_workers)
         return ThreadExecutor(max_workers=self.max_workers)
 
     def _get_thread_db(self) -> api.VectorDB:
-        """Get or create a per-thread DB instance.
+        """Return self.db.
 
-        Thread-safe DBs reuse self.db (connection opened in task()).
-        Non-thread-safe DBs get a deep-copied instance with its own connection,
-        cached in thread-local storage so it is created once per thread.
+        All workers share the connection opened by task()'s `with self.db.init()`.
+        Thread-safe DBs share it across multiple workers. Non-thread-safe DBs are
+        clamped to max_workers=1, so there is never concurrent access.
         """
-        if not hasattr(self._local, "db"):
-            if self.db.thread_safe:
-                self._local.db = self.db
-            else:
-                db = deepcopy(self.db)
-                # Manual __enter__/__exit__ because enter and exit happen in
-                # different scopes (here vs _cleanup_thread_contexts).
-                ctx = db.init()
-                ctx.__enter__()
-                self._local.db = db
-                with self._ctx_lock:
-                    self._thread_contexts.append(ctx)
-        return self._local.db
-
-    def _cleanup_thread_contexts(self) -> None:
-        """Close per-thread DB connections opened for non-thread-safe clients."""
-        for ctx in self._thread_contexts:
-            try:
-                ctx.__exit__(None, None, None)
-            except Exception:
-                log.warning("Failed to close per-thread DB connection", exc_info=True)
-        self._thread_contexts.clear()
+        return self.db
 
     def _insert_batch_with_retry(
         self,
@@ -160,14 +133,7 @@ def _worker_insert(
         metadata: list[int],
         labels_data: list[str] | None = None,
     ) -> int:
-        """Worker function: insert a batch with retry.
-
-        Thread-safe DBs: reuse self.db whose connection is already open
-        via task()'s `with self.db.init()` — all threads share it safely.
-
-        Non-thread-safe DBs: use a per-thread deep-copied instance with
-        its own connection, cached via threading.local.
-        """
+        """Worker function: insert a batch with retry."""
         db = self._get_thread_db()
        return self._insert_batch_with_retry(db, embeddings, metadata, labels_data)
 
@@ -214,9 +180,6 @@ def _worker_loop(self) -> int:
     def task(self) -> int:
         """Insert entire dataset using concurrent executor. Runs in subprocess."""
         count = 0
-        self._local = threading.local()
-        self._ctx_lock = threading.Lock()
-        self._thread_contexts = []
         self._iter_lock = threading.Lock()
         self._dataset_iter = iter(self.dataset)
 
@@ -227,23 +190,20 @@ def task(self) -> int:
         )
         start = time.perf_counter()
 
-        try:
-            with self._create_executor() as executor:
-                for _ in range(self.max_workers):
-                    executor.submit(self._worker_loop)
-
-                batch_results = executor.wait_all()
-
-                # Log all errors, then raise the first one
-                errors = [r.error for r in batch_results if r.error is not None]
-                if errors:
-                    for err in errors:
-                        log.warning(f"Batch insert error: {err}")
-                    raise errors[0]
-
-                count = sum(r.value for r in batch_results)
-        finally:
-            self._cleanup_thread_contexts()
+        with self._create_executor() as executor:
+            for _ in range(self.max_workers):
+                executor.submit(self._worker_loop)
+
+            batch_results = executor.wait_all()
+
+            # Log all errors, then raise the first one
+            errors = [r.error for r in batch_results if r.error is not None]
+            if errors:
+                for err in errors:
+                    log.warning(f"Batch insert error: {err}")
+                raise errors[0]
+
+            count = sum(r.value for r in batch_results)
 
         log.info(
             f"({mp.current_process().name:16}) Finish concurrent insert, "
