Skip to content

Commit 21bd412

Browse files
committed
Address review feedback: lazy import, dimension probe, batch pipeline, tag escaping, typed exceptions
Signed-off-by: Daria Korenieva <daric2612@gmail.com>
1 parent da4159d commit 21bd412

2 files changed

Lines changed: 72 additions & 42 deletions

File tree

application/vectorstore/valkey.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
import uuid
77
from typing import Any, Dict, List, Optional
88

9+
_GLIDE_AVAILABLE = False
910
try:
1011
from glide_sync import (
12+
Batch,
13+
ConnectionError as GlideConnectionError,
1114
DataType,
1215
DistanceMetricType,
1316
Field,
@@ -17,31 +20,33 @@
1720
GlideClient,
1821
GlideClientConfiguration,
1922
NodeAddress,
23+
RequestError,
2024
ReturnField,
2125
ServerCredentials,
2226
TagField,
2327
TextField,
28+
TimeoutError as GlideTimeoutError,
2429
VectorAlgorithm,
2530
VectorField,
2631
VectorFieldAttributesFlat,
2732
VectorFieldAttributesHnsw,
2833
VectorType,
2934
ft,
3035
)
36+
37+
_GLIDE_AVAILABLE = True
3138
except ImportError:
32-
raise ImportError(
33-
"Could not import valkey-glide-sync. "
34-
"Please install with `pip install valkey-glide-sync`."
35-
)
39+
pass
3640

37-
from application.core.settings import settings
38-
from application.vectorstore.base import BaseVectorStore
39-
from application.vectorstore.document_class import Document
41+
from application.core.settings import settings # noqa: E402
42+
from application.vectorstore.base import BaseVectorStore # noqa: E402
43+
from application.vectorstore.document_class import Document # noqa: E402
4044

4145
logger = logging.getLogger(__name__)
4246

4347
# Characters that must be escaped in Valkey tag field query values.
44-
_TAG_SPECIAL_CHARS = set(r".,<>{}[]\"':;!@#$%^&*()-+=~ /|")
48+
# Includes '?' which is a single-character wildcard in valkey-search TAG queries.
49+
_TAG_SPECIAL_CHARS = set(r".,<>{}[]\"':;!@#$%^&*()-+=~ /|?")
4550

4651
# Batch size for DELETE operations in delete_index.
4752
_DELETE_BATCH_SIZE = 100
@@ -70,7 +75,15 @@ def __init__(
7075
source_id: Identifier for the document source, used to
7176
namespace and filter documents.
7277
embeddings_key: Key name or API key for the embeddings provider.
78+
79+
Raises:
80+
ImportError: If valkey-glide-sync is not installed.
7381
"""
82+
if not _GLIDE_AVAILABLE:
83+
raise ImportError(
84+
"Could not import valkey-glide-sync. "
85+
"Please install with `pip install valkey-glide-sync`."
86+
)
7487
super().__init__()
7588
self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
7689
self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
@@ -127,7 +140,10 @@ def _ensure_index_exists(self):
127140
Uses VALKEY_DISTANCE_METRIC, VALKEY_VECTOR_TYPE, and VALKEY_VECTOR_ALGORITHM
128141
from settings. Falls back to cosine/float32/hnsw if values are unrecognized.
129142
"""
130-
embedding_dim = getattr(self._embedding, "dimension", 768)
143+
embedding_dim = getattr(self._embedding, "dimension", None)
144+
if embedding_dim is None:
145+
probe = self._embedding.embed_query("dimension probe")
146+
embedding_dim = len(probe)
131147

132148
distance_metric = self._resolve_distance_metric(settings.VALKEY_DISTANCE_METRIC)
133149
vector_type = self._resolve_vector_type(settings.VALKEY_VECTOR_TYPE)
@@ -310,7 +326,7 @@ def search(self, question: str, k: int = 2, *args, **kwargs) -> List[Document]:
310326

311327
return self._parse_search_results(results)
312328

313-
except Exception as e:
329+
except (RequestError, GlideConnectionError, GlideTimeoutError) as e:
314330
logger.error(f"Error searching Valkey: {e}", exc_info=True)
315331
return []
316332

@@ -406,6 +422,8 @@ def add_texts(
406422
metadatas = metadatas or [{}] * len(texts)
407423
doc_ids: List[str] = []
408424

425+
# Use non-atomic Batch (pipeline) to reduce network round trips.
426+
batch = Batch(False)
409427
for text, embedding, metadata in zip(texts, embeddings, metadatas):
410428
doc_id = str(uuid.uuid4())
411429
key = self._doc_key(doc_id)
@@ -418,15 +436,17 @@ def add_texts(
418436
"embedding": vector_bytes,
419437
}
420438

421-
try:
422-
self._client.hset(key, fields)
423-
doc_ids.append(doc_id)
424-
except Exception as e:
425-
logger.error(
426-
f"Error adding document to Valkey (wrote {len(doc_ids)}/{len(texts)} "
427-
f"before failure): {e}"
428-
)
429-
raise
439+
batch.hset(key, fields)
440+
doc_ids.append(doc_id)
441+
442+
try:
443+
self._client.exec(batch, raise_on_error=True)
444+
except (RequestError, GlideConnectionError, GlideTimeoutError) as e:
445+
logger.error(
446+
f"Error adding documents to Valkey via pipeline "
447+
f"({len(doc_ids)} documents): {e}"
448+
)
449+
raise
430450

431451
return doc_ids
432452

@@ -488,7 +508,7 @@ def delete_index(self, *args, **kwargs):
488508
batch = keys[i : i + _DELETE_BATCH_SIZE]
489509
self._client.delete(batch)
490510

491-
except Exception as e:
511+
except (RequestError, GlideConnectionError, GlideTimeoutError) as e:
492512
logger.error(f"Error deleting index from Valkey: {e}", exc_info=True)
493513

494514
def save_local(self, *args, **kwargs):
@@ -550,7 +570,7 @@ def get_chunks(self) -> List[Dict[str, Any]]:
550570

551571
return chunks
552572

553-
except Exception as e:
573+
except (RequestError, GlideConnectionError, GlideTimeoutError) as e:
554574
logger.error(f"Error getting chunks from Valkey: {e}", exc_info=True)
555575
return []
556576

@@ -586,7 +606,7 @@ def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str
586606
try:
587607
self._client.hset(key, fields)
588608
return doc_id
589-
except Exception as e:
609+
except (RequestError, GlideConnectionError, GlideTimeoutError) as e:
590610
logger.error(f"Error adding chunk to Valkey: {e}")
591611
raise
592612

@@ -603,6 +623,6 @@ def delete_chunk(self, chunk_id: str) -> bool:
603623
key = self._doc_key(chunk_id)
604624
result = self._client.delete([key])
605625
return result > 0
606-
except Exception as e:
626+
except (RequestError, GlideConnectionError, GlideTimeoutError) as e:
607627
logger.error(f"Error deleting chunk from Valkey: {e}", exc_info=True)
608628
return False

tests/vectorstore/test_valkey.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import pytest
77

8+
from glide_sync import RequestError
9+
810

911
def _make_store(source_id="test-source", embeddings_key="key"):
1012
"""Helper to create a ValkeyStore with all external deps mocked."""
@@ -136,7 +138,7 @@ def test_search_returns_documents(self):
136138
def test_search_returns_empty_on_error(self):
137139
store, mock_client, _ = _make_store()
138140

139-
with patch("application.vectorstore.valkey.ft.search", side_effect=Exception("connection lost")):
141+
with patch("application.vectorstore.valkey.ft.search", side_effect=RequestError("connection lost")):
140142
results = store.search("query")
141143
assert results == []
142144

@@ -178,12 +180,12 @@ class TestValkeyStoreAddTexts:
178180
def test_add_texts_returns_ids(self):
179181
store, mock_client, mock_emb = _make_store()
180182
mock_emb.embed_documents.return_value = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
181-
mock_client.hset = Mock(return_value=3)
183+
mock_client.exec = Mock(return_value=["OK", "OK"])
182184

183185
ids = store.add_texts(["text1", "text2"], [{"a": 1}, {"b": 2}])
184186

185187
assert len(ids) == 2
186-
assert mock_client.hset.call_count == 2
188+
mock_client.exec.assert_called_once()
187189

188190
def test_add_texts_empty_returns_empty(self):
189191
store, _, _ = _make_store()
@@ -192,35 +194,43 @@ def test_add_texts_empty_returns_empty(self):
192194
def test_add_texts_default_metadatas(self):
193195
store, mock_client, mock_emb = _make_store()
194196
mock_emb.embed_documents.return_value = [[0.1, 0.2, 0.3]]
195-
mock_client.hset = Mock(return_value=3)
197+
mock_client.exec = Mock(return_value=["OK"])
196198

197199
ids = store.add_texts(["text1"])
198200
assert len(ids) == 1
199201

200202
def test_add_texts_raises_on_error(self):
201203
store, mock_client, mock_emb = _make_store()
202204
mock_emb.embed_documents.return_value = [[0.1, 0.2, 0.3]]
203-
mock_client.hset = Mock(side_effect=Exception("write failed"))
205+
mock_client.exec = Mock(side_effect=RequestError("write failed"))
204206

205-
with pytest.raises(Exception, match="write failed"):
207+
with pytest.raises(RequestError, match="write failed"):
206208
store.add_texts(["text1"])
207209

208210
def test_add_texts_stores_correct_fields(self):
209211
store, mock_client, mock_emb = _make_store(source_id="src1")
210212
mock_emb.embed_documents.return_value = [[0.1, 0.2, 0.3]]
211-
mock_client.hset = Mock(return_value=3)
213+
mock_client.exec = Mock(return_value=["OK"])
212214

213-
store.add_texts(["hello"], [{"key": "val"}])
215+
with patch("application.vectorstore.valkey.Batch") as MockBatch:
216+
mock_batch_instance = MagicMock()
217+
MockBatch.return_value = mock_batch_instance
214218

215-
call_args = mock_client.hset.call_args
216-
key = call_args[0][0]
217-
fields = call_args[0][1]
219+
store.add_texts(["hello"], [{"key": "val"}])
220+
221+
# Verify the batch was created as non-atomic (pipeline)
222+
MockBatch.assert_called_once_with(False)
223+
224+
# Verify hset was called on the batch with correct fields
225+
call_args = mock_batch_instance.hset.call_args
226+
key = call_args[0][0]
227+
fields = call_args[0][1]
218228

219-
assert key.startswith("doc:")
220-
assert fields["content"] == "hello"
221-
assert fields["source_id"] == "src1"
222-
assert json.loads(fields["metadata"]) == {"key": "val"}
223-
assert isinstance(fields["embedding"], bytes)
229+
assert key.startswith("doc:")
230+
assert fields["content"] == "hello"
231+
assert fields["source_id"] == "src1"
232+
assert json.loads(fields["metadata"]) == {"key": "val"}
233+
assert isinstance(fields["embedding"], bytes)
224234

225235

226236
@pytest.mark.unit
@@ -277,7 +287,7 @@ def test_delete_index_paginates_large_sets(self):
277287
def test_delete_index_handles_error(self):
278288
store, mock_client, _ = _make_store()
279289

280-
with patch("application.vectorstore.valkey.ft.search", side_effect=Exception("fail")):
290+
with patch("application.vectorstore.valkey.ft.search", side_effect=RequestError("fail")):
281291
# Should not raise
282292
store.delete_index()
283293

@@ -310,7 +320,7 @@ def test_get_chunks(self):
310320
def test_get_chunks_returns_empty_on_error(self):
311321
store, mock_client, _ = _make_store()
312322

313-
with patch("application.vectorstore.valkey.ft.search", side_effect=Exception("fail")):
323+
with patch("application.vectorstore.valkey.ft.search", side_effect=RequestError("fail")):
314324
assert store.get_chunks() == []
315325

316326
def test_get_chunks_uses_return_fields(self):
@@ -378,7 +388,7 @@ def test_delete_chunk_not_found(self):
378388

379389
def test_delete_chunk_returns_false_on_error(self):
380390
store, mock_client, _ = _make_store()
381-
mock_client.delete = Mock(side_effect=Exception("fail"))
391+
mock_client.delete = Mock(side_effect=RequestError("fail"))
382392

383393
result = store.delete_chunk("uuid-123")
384394
assert result is False

0 commit comments

Comments
 (0)