Skip to content

Commit cbe9f4f

Browse files
authored
Merge pull request #7 from atasoglu/unique-text
Unique text
2 parents c605cd6 + ba3bcd6 commit cbe9f4f

8 files changed

Lines changed: 297 additions & 14 deletions

File tree

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [2.4.0] - 2026-04-21
9+
10+
### Added
11+
- `unique_text` parameter in `create_table` to enforce a uniqueness constraint on the text column
12+
- `on_conflict` parameter in the `add` method to control duplicate-entry handling with options: `"error"`, `"ignore"`, and `"replace"`
13+
- Validation for `on_conflict` values to ensure only accepted options are used
14+
- Tests covering unique text constraints and all conflict resolution strategies
15+
816
## [2.3.0] - 2025-02-15
917

1018
### Added

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "sqlite-vec-client"
7-
version = "2.3.0"
7+
version = "2.4.0"
88
description = "A lightweight Python client around sqlite-vec for CRUD and similarity search."
99
readme = "README.md"
1010
requires-python = ">=3.9"

sqlite_vec_client/base.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
validate_limit,
3131
validate_metadata_filters,
3232
validate_offset,
33+
validate_on_conflict,
3334
validate_table_name,
3435
validate_top_k,
3536
)
@@ -144,12 +145,15 @@ def create_table(
144145
self,
145146
dim: int,
146147
distance: Literal["L1", "L2", "cosine"] = "cosine",
148+
unique_text: bool = False,
147149
) -> None:
148150
"""Create base table, vector table, and triggers to keep them in sync.
149151
150152
Args:
151153
dim: Embedding dimension (must be positive)
152154
distance: Distance metric for similarity search
155+
unique_text: If True, enforce uniqueness on the text column.
156+
This enables ``on_conflict`` options in :meth:`add`.
153157
154158
Raises:
155159
TableNameError: If table name is invalid
@@ -172,6 +176,14 @@ def create_table(
172176
;
173177
"""
174178
)
179+
if unique_text:
180+
self.connection.execute(
181+
f"""
182+
CREATE UNIQUE INDEX IF NOT EXISTS {self.table}_text_unique
183+
ON {self.table}(text)
184+
;
185+
"""
186+
)
175187
self.connection.execute(
176188
f"""
177189
CREATE VIRTUAL TABLE IF NOT EXISTS {self.table}_vec USING vec0(
@@ -325,21 +337,30 @@ def add(
325337
texts: list[Text],
326338
embeddings: list[Embeddings],
327339
metadata: list[Metadata] | None = None,
340+
on_conflict: Literal["error", "ignore", "replace"] = "error",
328341
) -> Rowids:
329342
"""Insert texts with embeddings (and optional metadata) and return rowids.
330343
331344
Args:
332345
texts: List of text strings
333346
embeddings: List of embedding vectors
334347
metadata: Optional list of metadata dicts
348+
on_conflict: How to handle duplicate texts when a UNIQUE index on
349+
``text`` exists (see ``create_table(unique_text=True)``).
350+
351+
- ``"error"`` (default): raise on conflict.
352+
- ``"ignore"``: silently skip duplicate texts.
353+
- ``"replace"``: update metadata and embedding of
354+
existing records that share the same text.
335355
336356
Returns:
337-
List of rowids for inserted records
357+
List of rowids for inserted (or upserted) records
338358
339359
Raises:
340-
ValidationError: If list lengths don't match
360+
ValidationError: If list lengths don't match or on_conflict is invalid
341361
TableNotFoundError: If table doesn't exist
342362
"""
363+
validate_on_conflict(on_conflict)
343364
validate_embeddings_match(texts, embeddings, metadata)
344365
expected_dim = self._ensure_dimension()
345366
for embedding in embeddings:
@@ -356,19 +377,49 @@ def add(
356377

357378
cur = self.connection.cursor()
358379

359-
# Get max rowid before insert
360-
max_before = cur.execute(
361-
f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}"
362-
).fetchone()[0]
380+
if on_conflict == "ignore":
381+
sql = (
382+
f"INSERT OR IGNORE INTO {self.table}"
383+
f"(text, metadata, text_embedding) VALUES (?,?,?)"
384+
)
385+
elif on_conflict == "replace":
386+
sql = (
387+
f"INSERT INTO {self.table}(text, metadata, text_embedding) "
388+
f"VALUES (?,?,?) "
389+
f"ON CONFLICT(text) DO UPDATE SET "
390+
f"metadata=excluded.metadata, "
391+
f"text_embedding=excluded.text_embedding"
392+
)
393+
else:
394+
sql = (
395+
f"INSERT INTO {self.table}"
396+
f"(text, metadata, text_embedding) VALUES (?,?,?)"
397+
)
363398

364-
cur.executemany(
365-
f"""INSERT INTO {self.table}(text, metadata, text_embedding)
366-
VALUES (?,?,?)""",
367-
data_input,
368-
)
399+
if on_conflict in ("error", "ignore"):
400+
max_before = cur.execute(
401+
f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}"
402+
).fetchone()[0]
403+
404+
cur.executemany(sql, data_input)
369405

370-
# Calculate rowids from max_before
371-
rowids = list(range(max_before + 1, max_before + len(texts) + 1))
406+
if on_conflict == "error":
407+
rowids = list(range(max_before + 1, max_before + len(texts) + 1))
408+
elif on_conflict == "ignore":
409+
cur.execute(
410+
f"SELECT rowid FROM {self.table} "
411+
f"WHERE rowid > ? ORDER BY rowid",
412+
[max_before],
413+
)
414+
rowids = [row[0] for row in cur.fetchall()]
415+
else:
416+
placeholders = ",".join(["?"] * len(texts))
417+
cur.execute(
418+
f"SELECT rowid FROM {self.table} "
419+
f"WHERE text IN ({placeholders}) ORDER BY rowid",
420+
texts,
421+
)
422+
rowids = [row[0] for row in cur.fetchall()]
372423

373424
if not self._in_transaction:
374425
self.connection.commit()

sqlite_vec_client/validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,25 @@ def validate_embedding_dimension(embedding: list[float], expected_dim: int) -> N
123123
)
124124

125125

126+
_VALID_ON_CONFLICT = frozenset({"error", "ignore", "replace"})
127+
128+
129+
def validate_on_conflict(on_conflict: str) -> None:
130+
"""Validate the on_conflict parameter for add().
131+
132+
Args:
133+
on_conflict: Conflict resolution strategy
134+
135+
Raises:
136+
ValidationError: If the value is not one of 'error', 'ignore', 'replace'
137+
"""
138+
if on_conflict not in _VALID_ON_CONFLICT:
139+
raise ValidationError(
140+
f"on_conflict must be one of {sorted(_VALID_ON_CONFLICT)}, "
141+
f"got '{on_conflict}'"
142+
)
143+
144+
126145
def validate_metadata_filters(filters: dict[str, Any]) -> None:
127146
"""Validate metadata filters dictionary.
128147

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ def client_with_table(client: SQLiteVecClient) -> SQLiteVecClient:
3131
return client
3232

3333

34+
@pytest.fixture
35+
def client_with_unique_table(client: SQLiteVecClient) -> SQLiteVecClient:
36+
"""Provide a client with table created with unique_text=True."""
37+
client.create_table(dim=3, distance="cosine", unique_text=True)
38+
return client
39+
40+
3441
@pytest.fixture
3542
def sample_embeddings() -> list[list[float]]:
3643
"""Provide sample 3D embeddings for testing."""

tests/test_client.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Integration tests for SQLiteVecClient."""
22

3+
import sqlite3
4+
35
import pytest
46

57
from sqlite_vec_client import (
@@ -94,6 +96,155 @@ def test_add_invalid_embedding_dimension(
9496
client_with_table.add(texts=sample_texts, embeddings=invalid_embeddings)
9597

9698

99+
@pytest.mark.integration
100+
class TestUniqueText:
101+
"""Tests for unique_text constraint and on_conflict parameter."""
102+
103+
def test_unique_text_rejects_duplicates_by_default(
104+
self, client_with_unique_table, sample_embeddings
105+
):
106+
"""Duplicate text raises IntegrityError when on_conflict='error'."""
107+
client_with_unique_table.add(
108+
texts=["hello"], embeddings=[sample_embeddings[0]]
109+
)
110+
with pytest.raises(sqlite3.IntegrityError):
111+
client_with_unique_table.add(
112+
texts=["hello"], embeddings=[sample_embeddings[1]]
113+
)
114+
115+
def test_on_conflict_ignore_skips_duplicates(
116+
self, client_with_unique_table, sample_embeddings
117+
):
118+
"""Duplicate texts are silently skipped with on_conflict='ignore'."""
119+
client_with_unique_table.add(
120+
texts=["hello"], embeddings=[sample_embeddings[0]]
121+
)
122+
rowids = client_with_unique_table.add(
123+
texts=["hello", "world"],
124+
embeddings=[sample_embeddings[1], sample_embeddings[2]],
125+
on_conflict="ignore",
126+
)
127+
assert len(rowids) == 1
128+
assert client_with_unique_table.count() == 2
129+
130+
def test_on_conflict_ignore_all_duplicates(
131+
self, client_with_unique_table, sample_embeddings
132+
):
133+
"""All duplicates skipped returns empty rowids."""
134+
client_with_unique_table.add(
135+
texts=["hello"], embeddings=[sample_embeddings[0]]
136+
)
137+
rowids = client_with_unique_table.add(
138+
texts=["hello"],
139+
embeddings=[sample_embeddings[1]],
140+
on_conflict="ignore",
141+
)
142+
assert rowids == []
143+
assert client_with_unique_table.count() == 1
144+
145+
def test_on_conflict_ignore_preserves_original(
146+
self, client_with_unique_table, sample_embeddings
147+
):
148+
"""Ignored duplicates do not overwrite the original record."""
149+
client_with_unique_table.add(
150+
texts=["hello"],
151+
embeddings=[sample_embeddings[0]],
152+
metadata=[{"version": 1}],
153+
)
154+
client_with_unique_table.add(
155+
texts=["hello"],
156+
embeddings=[sample_embeddings[1]],
157+
metadata=[{"version": 2}],
158+
on_conflict="ignore",
159+
)
160+
record = client_with_unique_table.get(1)
161+
assert record[2] == {"version": 1}
162+
163+
def test_on_conflict_replace_updates_existing(
164+
self, client_with_unique_table, sample_embeddings
165+
):
166+
"""Duplicate texts are updated with on_conflict='replace'."""
167+
client_with_unique_table.add(
168+
texts=["hello"],
169+
embeddings=[sample_embeddings[0]],
170+
metadata=[{"version": 1}],
171+
)
172+
rowids = client_with_unique_table.add(
173+
texts=["hello"],
174+
embeddings=[sample_embeddings[1]],
175+
metadata=[{"version": 2}],
176+
on_conflict="replace",
177+
)
178+
assert len(rowids) == 1
179+
assert client_with_unique_table.count() == 1
180+
record = client_with_unique_table.get(rowids[0])
181+
assert record[2] == {"version": 2}
182+
assert record[3] == pytest.approx(sample_embeddings[1], abs=1e-6)
183+
184+
def test_on_conflict_replace_mixed_insert_and_update(
185+
self, client_with_unique_table, sample_embeddings
186+
):
187+
"""Replace mode handles a mix of new and existing texts."""
188+
client_with_unique_table.add(
189+
texts=["hello"], embeddings=[sample_embeddings[0]]
190+
)
191+
rowids = client_with_unique_table.add(
192+
texts=["hello", "world"],
193+
embeddings=[sample_embeddings[1], sample_embeddings[2]],
194+
on_conflict="replace",
195+
)
196+
assert len(rowids) == 2
197+
assert client_with_unique_table.count() == 2
198+
199+
def test_on_conflict_replace_keeps_rowid(
200+
self, client_with_unique_table, sample_embeddings
201+
):
202+
"""Replace mode preserves the original rowid."""
203+
original_rowids = client_with_unique_table.add(
204+
texts=["hello"], embeddings=[sample_embeddings[0]]
205+
)
206+
new_rowids = client_with_unique_table.add(
207+
texts=["hello"],
208+
embeddings=[sample_embeddings[1]],
209+
on_conflict="replace",
210+
)
211+
assert new_rowids == original_rowids
212+
213+
def test_on_conflict_replace_vec_table_synced(
214+
self, client_with_unique_table, sample_embeddings
215+
):
216+
"""Replace mode keeps the vector table in sync for similarity search."""
217+
client_with_unique_table.add(
218+
texts=["hello"], embeddings=[sample_embeddings[0]]
219+
)
220+
client_with_unique_table.add(
221+
texts=["hello"],
222+
embeddings=[sample_embeddings[1]],
223+
on_conflict="replace",
224+
)
225+
results = client_with_unique_table.similarity_search(
226+
embedding=sample_embeddings[1], top_k=1
227+
)
228+
assert results[0][1] == "hello"
229+
230+
def test_on_conflict_invalid_value(self, client_with_unique_table):
231+
"""Invalid on_conflict value raises ValidationError."""
232+
with pytest.raises(ValidationError):
233+
client_with_unique_table.add(
234+
texts=["hello"],
235+
embeddings=[[0.1, 0.2, 0.3]],
236+
on_conflict="bad",
237+
)
238+
239+
def test_without_unique_text_allows_duplicates(
240+
self, client_with_table, sample_embeddings
241+
):
242+
"""Without unique_text, duplicate texts are allowed."""
243+
client_with_table.add(texts=["hello"], embeddings=[sample_embeddings[0]])
244+
client_with_table.add(texts=["hello"], embeddings=[sample_embeddings[1]])
245+
assert client_with_table.count() == 2
246+
247+
97248
@pytest.mark.integration
98249
class TestSimilaritySearch:
99250
"""Tests for similarity_search method."""

tests/test_validation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
validate_limit,
1515
validate_metadata_filters,
1616
validate_offset,
17+
validate_on_conflict,
1718
validate_table_name,
1819
validate_top_k,
1920
)
@@ -239,3 +240,23 @@ def test_non_string_keys(self):
239240
"""Test that non-string keys raise error."""
240241
with pytest.raises(ValidationError, match="must be string"):
241242
validate_metadata_filters({123: "value"})
243+
244+
245+
@pytest.mark.unit
246+
class TestValidateOnConflict:
247+
"""Tests for validate_on_conflict function."""
248+
249+
def test_valid_values(self):
250+
"""Test that valid on_conflict values pass validation."""
251+
for value in ["error", "ignore", "replace"]:
252+
validate_on_conflict(value)
253+
254+
def test_invalid_value(self):
255+
"""Test that invalid on_conflict value raises error."""
256+
with pytest.raises(ValidationError, match="on_conflict must be one of"):
257+
validate_on_conflict("bad")
258+
259+
def test_empty_value(self):
260+
"""Test that empty string raises error."""
261+
with pytest.raises(ValidationError, match="on_conflict must be one of"):
262+
validate_on_conflict("")

0 commit comments

Comments
 (0)