Skip to content

Commit 8f5200e

Browse files
committed
addressed comments (part_2): modified filtering logic
1 parent 88b75ab commit 8f5200e

6 files changed

Lines changed: 298 additions & 85 deletions

File tree

integrations/valkey/src/haystack_integrations/document_stores/valkey/document_store.py

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from glide_shared.commands.server_modules.ft_options.ft_search_options import (
4040
FtSearchOptions,
41+
FtSearchLimit,
4142
ReturnField,
4243
)
4344
from glide_sync import (
@@ -63,6 +64,7 @@
6364
from haystack.dataclasses.document import Document
6465
from haystack.document_stores.errors import DocumentStoreError
6566
from haystack.document_stores.types import DocumentStore, DuplicatePolicy
67+
from haystack.errors import FilterError
6668
from haystack.utils import Secret
6769

6870
from haystack_integrations.document_stores.valkey.filters import _normalize_filters, _validate_filters
@@ -157,6 +159,7 @@ def __init__(
157159
index_name: str = "haystack_document",
158160
distance_metric: Literal["l2", "cosine", "ip"] = "cosine",
159161
embedding_dim: int = 768,
162+
metadata_fields: dict[str, type[str] | type[int]] | None = None,
160163
):
161164
"""
162165
Creates a new ValkeyDocumentStore instance.
@@ -177,11 +180,18 @@ def __init__(
177180
:param distance_metric: Distance metric for vector similarity. Options: "l2", "cosine", "ip" (inner product).
178181
Defaults to "cosine".
179182
:param embedding_dim: Dimension of document embeddings. Defaults to 768.
183+
:param metadata_fields: Dictionary mapping metadata field names to Python types for filtering.
184+
Supported types: str (for exact matching), int (for numeric comparisons).
185+
Example: {"category": str, "priority": int}.
186+
If not provided, no metadata fields will be indexed for filtering.
180187
"""
181188
self._index_name = index_name
182189
self._distance_metric = self._parse_metric(distance_metric)
183190
self._embedding_dim = embedding_dim
184191
self._dummy_vector = [ValkeyDocumentStore._DUMMY_VALUE] * self._embedding_dim
192+
193+
# Validate and normalize metadata fields
194+
self._metadata_fields = self._validate_and_normalize_metadata_fields(metadata_fields or {})
185195

186196
self._nodes_list: list[tuple[str, int]] = nodes_list or [("localhost", 6379)]
187197
self._cluster_mode: bool = cluster_mode
@@ -310,7 +320,7 @@ async def _has_index_async(self) -> bool:
310320

311321
def _prepare_index_fields(self) -> list[Field]:
312322
"""Prepare index fields configuration."""
313-
return [
323+
fields = [
314324
TagField("$.id", alias="id"),
315325
VectorField(
316326
name="vector",
@@ -321,12 +331,18 @@ def _prepare_index_fields(self) -> list[Field]:
321331
type=VectorType.FLOAT32,
322332
),
323333
),
324-
TagField("$.meta_category", alias="meta_category"),
325-
TagField("$.meta_status", alias="meta_status"),
326-
NumericField("$.meta_priority", alias="meta_priority"),
327-
NumericField("$.meta_score", alias="meta_score"),
328-
NumericField("$.meta_timestamp", alias="meta_timestamp"),
329334
]
335+
336+
# _metadata_fields keys already have meta_ prefix
337+
for field_name, field_type in self._metadata_fields.items():
338+
field_path = f"$.{field_name}"
339+
340+
if field_type == "tag":
341+
fields.append(TagField(field_path, alias=field_name))
342+
elif field_type == "numeric":
343+
fields.append(NumericField(field_path, alias=field_name))
344+
345+
return fields
330346

331347
def _create_index(self) -> None:
332348
client = self._get_connection()
@@ -370,6 +386,12 @@ def to_dict(self) -> dict[str, Any]:
370386
"""
371387
Serializes this store to a dictionary.
372388
"""
389+
metadata_fields_for_ser = {}
390+
for field_name, field_type in self._metadata_fields.items():
391+
# Remove meta_ prefix: meta_category -> category
392+
clean_name = field_name[5:] if field_name.startswith("meta_") else field_name
393+
metadata_fields_for_ser[clean_name] = str if field_type == "tag" else int
394+
373395
return default_to_dict(
374396
self,
375397
nodes_list=self._nodes_list,
@@ -385,6 +407,7 @@ def to_dict(self) -> dict[str, Any]:
385407
index_name=self._index_name,
386408
distance_metric=self._distance_metric.name.lower(),
387409
embedding_dim=self._embedding_dim,
410+
metadata_fields=metadata_fields_for_ser if metadata_fields_for_ser else None,
388411
)
389412

390413
@classmethod
@@ -508,6 +531,8 @@ def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Docume
508531

509532
return docs_no_score
510533

534+
except FilterError:
535+
raise
511536
except Exception as e:
512537
msg = f"Error filtering documents in index '{self._index_name}'"
513538
raise ValkeyDocumentStoreError(msg) from e
@@ -558,6 +583,8 @@ async def filter_documents_async(self, filters: dict[str, Any] | None = None) ->
558583

559584
return docs_no_score
560585

586+
except FilterError:
587+
raise
561588
except Exception as e:
562589
msg = f"Error filtering documents in index '{self._index_name}'"
563590
raise ValkeyDocumentStoreError(msg) from e
@@ -881,12 +908,14 @@ def _embedding_retrieval(
881908

882909
try:
883910
query, query_options = self._build_search_query_and_options(
884-
embedding, filters, limit, with_embedding=with_embedding
911+
embedding, filters, limit, with_embedding=with_embedding, supported_fields=self._metadata_fields
885912
)
886913
results = sync_ft.search(client, self._index_name, query, query_options)
887914

888915
return self._parse_documents_from_ft(results, with_embedding=with_embedding)
889916

917+
except FilterError:
918+
raise
890919
except Exception as e:
891920
msg = f"Failed to retrieve documents by embedding: {e}"
892921
raise ValkeyDocumentStoreError(msg) from e
@@ -954,12 +983,14 @@ async def _embedding_retrieval_async(
954983

955984
try:
956985
query, query_options = self._build_search_query_and_options(
957-
embedding, filters, limit, with_embedding=with_embedding
986+
embedding, filters, limit, with_embedding=with_embedding, supported_fields=self._metadata_fields
958987
)
959988
results = await ft.search(client, self._index_name, query, query_options)
960989

961990
return self._parse_documents_from_ft(results, with_embedding=with_embedding)
962991

992+
except FilterError:
993+
raise
963994
except Exception as e:
964995
msg = f"Failed to retrieve documents by embedding: {e}"
965996
raise ValkeyDocumentStoreError(msg) from e
@@ -970,21 +1001,15 @@ def _prepare_document_dict(self, doc: Document) -> dict:
9701001
payload.pop("embedding", None)
9711002

9721003
meta = doc.meta or {}
973-
doc_dict = {
974-
"id": doc.id,
975-
"payload": payload,
976-
"meta_category": meta.get("category", ""),
977-
"meta_status": meta.get("status", ""),
978-
"meta_priority": meta.get("priority", 0),
979-
"meta_score": meta.get("score", 0.0),
980-
"meta_timestamp": meta.get("timestamp", 0),
981-
}
982-
983-
if doc.embedding is not None:
984-
doc_dict["vector"] = doc.embedding
985-
else:
986-
doc_dict["vector"] = [ValkeyDocumentStore._DUMMY_VALUE] * self._embedding_dim
987-
1004+
doc_dict = {"id": doc.id, "payload": payload}
1005+
1006+
# _metadata_fields keys already have meta_ prefix
1007+
for field_name_with_prefix, field_type in self._metadata_fields.items():
1008+
# Extract original field name: meta_category -> category
1009+
field_name = field_name_with_prefix[5:] # Remove "meta_"
1010+
doc_dict[field_name_with_prefix] = meta.get(field_name, None)
1011+
1012+
doc_dict["vector"] = doc.embedding if doc.embedding else [self._DUMMY_VALUE] * self._embedding_dim
9881013
return doc_dict
9891014

9901015
@staticmethod
@@ -1003,6 +1028,7 @@ def _parse_documents_from_ft(raw: Any, *, with_embedding: bool) -> list[Document
10031028
# This occurs when no documents match the query filters or the index is empty
10041029
if not raw or raw[0] == 0:
10051030
return documents
1031+
10061032
for doc_info in raw[1].values():
10071033
# Get payload from doc_info
10081034
payload_data = doc_info.get(b"payload")
@@ -1039,16 +1065,15 @@ def _parse_documents_from_ft(raw: Any, *, with_embedding: bool) -> list[Document
10391065

10401066
@staticmethod
10411067
def _build_search_query_and_options(
1042-
embedding: list[float], filters: dict[str, Any] | None, limit: int, *, with_embedding: bool
1068+
embedding: list[float], filters: dict[str, Any] | None, limit: int, *, with_embedding: bool,
1069+
supported_fields: dict[str, str]
10431070
) -> tuple[str, FtSearchOptions]:
1044-
# Validate and normalize filters
10451071
if filters:
10461072
_validate_filters(filters)
1047-
filter_query = _normalize_filters(filters)
1073+
filter_query = _normalize_filters(filters, supported_fields)
10481074
else:
10491075
filter_query = "*"
10501076

1051-
# Set return fields
10521077
return_fields = [
10531078
ReturnField("$.id", alias="id"),
10541079
ReturnField("$.payload", alias="payload"),
@@ -1057,14 +1082,12 @@ def _build_search_query_and_options(
10571082
if with_embedding:
10581083
return_fields.append(ReturnField("$.vector", alias="vector"))
10591084

1060-
vector_param_name = "query_vector"
1061-
1062-
# Combine filters with vector search
1063-
query = f"{filter_query}=>[KNN {limit} @vector ${vector_param_name}]"
1085+
query = f"{filter_query}=>[KNN {limit} @vector $query_vector]"
10641086
query_options = FtSearchOptions(
1065-
params={vector_param_name: ValkeyDocumentStore._to_float32_bytes(embedding)}, return_fields=return_fields
1087+
params={"query_vector": ValkeyDocumentStore._to_float32_bytes(embedding)},
1088+
return_fields=return_fields,
1089+
limit=FtSearchLimit(offset=0, count=limit)
10661090
)
1067-
10681091
return query, query_options
10691092

10701093
@staticmethod
@@ -1112,3 +1135,32 @@ def _parse_metric(metric: str) -> DistanceMetricType:
11121135
allowed_metrics = list(ValkeyDocumentStore._METRIC_MAP.keys())
11131136
msg = f"Unsupported metric: {metric}. Allowed: {allowed_metrics}"
11141137
raise ValueError(msg) from err
1138+
1139+
@staticmethod
1140+
def _validate_and_normalize_metadata_fields(metadata_fields: dict[str, type[str] | type[int]]) -> dict[str, str]:
1141+
"""
1142+
Validate and normalize metadata field definitions.
1143+
1144+
:param metadata_fields: User-provided metadata field definitions mapping field names to Python types.
1145+
:return: Normalized metadata fields with meta_ prefix mapping to Valkey field types ("tag" or "numeric").
1146+
:raises ValueError: If field definitions are invalid.
1147+
"""
1148+
if not isinstance(metadata_fields, dict):
1149+
msg = "metadata_fields must be a dictionary"
1150+
raise ValueError(msg)
1151+
1152+
TYPE_MAPPING = {str: "tag", int: "numeric"}
1153+
1154+
normalized = {}
1155+
for field_name, field_type in metadata_fields.items():
1156+
if not isinstance(field_name, str) or not field_name:
1157+
msg = f"Field name must be a non-empty string, got {field_name!r}"
1158+
raise ValueError(msg)
1159+
1160+
if field_type not in TYPE_MAPPING:
1161+
msg = f"Unsupported field type {field_type!r} for field '{field_name}'. Supported: {list(TYPE_MAPPING.keys())}"
1162+
raise ValueError(msg)
1163+
1164+
normalized[f"meta_{field_name}"] = TYPE_MAPPING[field_type]
1165+
1166+
return normalized

0 commit comments

Comments
 (0)