Skip to content

Commit c93287c

Browse files
feat: add sparse vector storage to ElasticsearchDocumentStore (#2989)
* feat: add sparse vector storage to ElasticsearchDocumentStore (#2939) * test: update retriever tests for new ElasticsearchDocumentStore serialization - Update test_bm25_retriever.py and test_embedding_retriever.py to include sparse_vector_field in serialized document_store init parameters. * test: add sync and async tests for sparse vector storage - Add test_write_documents_with_sparse_vectors and test_write_documents_with_sparse_embedding_warning to test_document_store.py - Add test_write_documents_async_with_sparse_vectors to test_document_store_async.py - Update existing warning test in test_document_store_async.py - Add test_init_with_sparse_vector_field and update serialization tests. * style: fix B905 (strict zip) and E501 (line length) linting errors * style: fix mypy type inference for _default_mappings * refactor: address PR review feedback for sparse vector storage - Add SPECIAL_FIELDS validation for sparse_vector_field in __init__ - Add sparse_vector_field to __init__ docstring - Inject sparse_vector mapping into custom_mapping when both provided - Extract _handle_sparse_embedding helper to deduplicate write methods - Convert _deserialize_document to reconstruct SparseEmbedding on read * test: address PR review feedback for sparse vector tests - Add SPECIAL_FIELDS validation test - Add custom_mapping injection test - Add legacy from_dict backward compat test - Fix async test to use async_client for index deletion - Add retrieval reconstruction assertions to sync and async sparse tests * fixing docstrings * just as a safeguard original custom_mapping dict is left unchanged * organising imports * formatting * adding more tests + fixing typing issues * formatting * updating unit tests * adding unit tests for _handle_sparse_embedding function --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent 0172236 commit c93287c

File tree

6 files changed

+309
-31
lines changed

6 files changed

+309
-31
lines changed

integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
import copy
45

56
# ruff: noqa: FBT002, FBT001 boolean-type-hint-positional-argument and boolean-default-value-positional-argument
67
# ruff: noqa: B008 function-call-in-default-argument
78
# ruff: noqa: S101 disable checks for uses of the assert keyword
8-
9-
109
from collections.abc import Mapping
1110
from dataclasses import replace
1211
from typing import Any, Literal
@@ -86,6 +85,7 @@ def __init__(
8685
api_key: Secret | str | None = Secret.from_env_var("ELASTIC_API_KEY", strict=False),
8786
api_key_id: Secret | str | None = Secret.from_env_var("ELASTIC_API_KEY_ID", strict=False),
8887
embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine",
88+
sparse_vector_field: str | None = None,
8989
**kwargs: Any,
9090
) -> None:
9191
"""
@@ -117,6 +117,9 @@ def __init__(
117117
To choose the most appropriate function, look for information about your embedding model.
118118
To understand how document scores are computed, see the Elasticsearch
119119
[documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params)
120+
:param sparse_vector_field: If set, the name of the Elasticsearch field where sparse embeddings
121+
will be stored using the `sparse_vector` field type. When not set, any `sparse_embedding`
122+
data on Documents is silently dropped during writes.
120123
:param **kwargs: Optional arguments that `Elasticsearch` takes.
121124
"""
122125
self._hosts = hosts
@@ -126,16 +129,26 @@ def __init__(
126129
self._api_key = api_key
127130
self._api_key_id = api_key_id
128131
self._embedding_similarity_function = embedding_similarity_function
132+
self._sparse_vector_field = sparse_vector_field
129133
self._custom_mapping = custom_mapping
130134
self._kwargs = kwargs
131135
self._initialized = False
132136

137+
if self._sparse_vector_field and self._sparse_vector_field in SPECIAL_FIELDS:
138+
msg = f"sparse_vector_field '{self._sparse_vector_field}' conflicts with a reserved field name."
139+
raise ValueError(msg)
140+
133141
if self._custom_mapping and not isinstance(self._custom_mapping, dict):
134142
msg = "custom_mapping must be a dictionary"
135143
raise ValueError(msg)
136144

145+
if self._custom_mapping and self._sparse_vector_field:
146+
self._custom_mapping = copy.deepcopy(custom_mapping) # original custom_mapping dict is left unchanged
147+
self._custom_mapping.setdefault("properties", {}) # type: ignore # can't be None here
148+
self._custom_mapping["properties"][self._sparse_vector_field] = {"type": "sparse_vector"} # type: ignore # can't be None here
149+
137150
if not self._custom_mapping:
138-
self._default_mappings = {
151+
self._default_mappings: dict[str, Any] = {
139152
"properties": {
140153
"embedding": {
141154
"type": "dense_vector",
@@ -156,6 +169,8 @@ def __init__(
156169
}
157170
],
158171
}
172+
if self._sparse_vector_field:
173+
self._default_mappings["properties"][self._sparse_vector_field] = {"type": "sparse_vector"}
159174

160175
def _ensure_initialized(self) -> None:
161176
"""
@@ -277,6 +292,7 @@ def to_dict(self) -> dict[str, Any]:
277292
api_key=self._api_key.to_dict() if isinstance(self._api_key, Secret) else None,
278293
api_key_id=self._api_key_id.to_dict() if isinstance(self._api_key_id, Secret) else None,
279294
embedding_similarity_function=self._embedding_similarity_function,
295+
sparse_vector_field=self._sparse_vector_field,
280296
**self._kwargs,
281297
)
282298

@@ -404,12 +420,11 @@ async def filter_documents_async(self, filters: dict[str, Any] | None = None) ->
404420
documents = await self._search_documents_async(query=query)
405421
return documents
406422

407-
@staticmethod
408-
def _deserialize_document(hit: dict[str, Any]) -> Document:
423+
def _deserialize_document(self, hit: dict[str, Any]) -> Document:
409424
"""
410425
Creates a `Document` from the search hit provided.
411426
412-
This is mostly useful in self.filter_documents().
427+
This is mostly useful in self.filter_documents() and self.filter_documents_async().
413428
414429
:param hit: A search hit from Elasticsearch.
415430
:returns: `Document` created from the search hit.
@@ -420,8 +435,40 @@ def _deserialize_document(hit: dict[str, Any]) -> Document:
420435
data["metadata"]["highlighted"] = hit["highlight"]
421436
data["score"] = hit["_score"]
422437

438+
if self._sparse_vector_field and self._sparse_vector_field in data:
439+
es_sparse = data.pop(self._sparse_vector_field)
440+
sorted_items = sorted(es_sparse.items(), key=lambda x: int(x[0]))
441+
data["sparse_embedding"] = {
442+
"indices": [int(k) for k, _ in sorted_items],
443+
"values": [v for _, v in sorted_items],
444+
}
445+
423446
return Document.from_dict(data)
424447

448+
def _handle_sparse_embedding(self, doc_dict: dict[str, Any], doc_id: str) -> None:
449+
"""
450+
Extracts the sparse_embedding from a document dict and converts it to the Elasticsearch sparse_vector format.
451+
452+
:param doc_dict: The dictionary representation of the document.
453+
:param doc_id: The document ID, used for warning messages.
454+
"""
455+
if "sparse_embedding" not in doc_dict:
456+
return
457+
sparse_embedding = doc_dict.pop("sparse_embedding")
458+
if not sparse_embedding:
459+
return
460+
if self._sparse_vector_field:
461+
doc_dict[self._sparse_vector_field] = {
462+
str(idx): val for idx, val in zip(sparse_embedding["indices"], sparse_embedding["values"], strict=True)
463+
}
464+
else:
465+
logger.warning(
466+
"Document {doc_id} has the `sparse_embedding` field set, "
467+
"but `sparse_vector_field` is not configured for this ElasticsearchDocumentStore. "
468+
"The `sparse_embedding` field will be ignored.",
469+
doc_id=doc_id,
470+
)
471+
425472
def write_documents(
426473
self,
427474
documents: list[Document],
@@ -457,16 +504,7 @@ def write_documents(
457504
elasticsearch_actions = []
458505
for doc in documents:
459506
doc_dict = doc.to_dict()
460-
461-
if "sparse_embedding" in doc_dict:
462-
sparse_embedding = doc_dict.pop("sparse_embedding", None)
463-
if sparse_embedding:
464-
logger.warning(
465-
"Document {doc_id} has the `sparse_embedding` field set,"
466-
"but storing sparse embeddings in Elasticsearch is not currently supported."
467-
"The `sparse_embedding` field will be ignored.",
468-
doc_id=doc.id,
469-
)
507+
self._handle_sparse_embedding(doc_dict, doc.id)
470508
elasticsearch_actions.append(
471509
{
472510
"_op_type": action,
@@ -544,16 +582,7 @@ async def write_documents_async(
544582
actions = []
545583
for doc in documents:
546584
doc_dict = doc.to_dict()
547-
548-
if "sparse_embedding" in doc_dict:
549-
sparse_embedding = doc_dict.pop("sparse_embedding", None)
550-
if sparse_embedding:
551-
logger.warning(
552-
"Document {doc_id} has the `sparse_embedding` field set,"
553-
"but storing sparse embeddings in Elasticsearch is not currently supported."
554-
"The `sparse_embedding` field will be ignored.",
555-
doc_id=doc.id,
556-
)
585+
self._handle_sparse_embedding(doc_dict, doc.id)
557586

558587
action = {
559588
"_op_type": "create" if policy == DuplicatePolicy.FAIL else "index",

integrations/elasticsearch/tests/test_bm25_retriever.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_to_dict(_mock_elasticsearch_client):
5656
"custom_mapping": None,
5757
"index": "default",
5858
"embedding_similarity_function": "cosine",
59+
"sparse_vector_field": None,
5960
},
6061
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
6162
},
@@ -74,7 +75,7 @@ def test_from_dict(_mock_elasticsearch_client):
7475
"type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever",
7576
"init_parameters": {
7677
"document_store": {
77-
"init_parameters": {"hosts": "some fake host", "index": "default"},
78+
"init_parameters": {"hosts": "some fake host", "index": "default", "sparse_vector_field": None},
7879
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
7980
},
8081
"filters": {},
@@ -99,7 +100,7 @@ def test_from_dict_no_filter_policy(_mock_elasticsearch_client):
99100
"type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever",
100101
"init_parameters": {
101102
"document_store": {
102-
"init_parameters": {"hosts": "some fake host", "index": "default"},
103+
"init_parameters": {"hosts": "some fake host", "index": "default", "sparse_vector_field": None},
103104
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
104105
},
105106
"filters": {},

0 commit comments

Comments
 (0)