Skip to content

Commit f186c9c

Browse files
committed
feat: add FAISSEmbeddingRetriever component
- Add components/retrievers/faiss/embedding_retriever.py with @component decorator, run(), run_async(), to_dict(), from_dict() with FilterPolicy support and backward-compat deserialization guard
- Add components/__init__.py, components/retrievers/__init__.py, components/retrievers/faiss/__init__.py namespace packages
- Add tests/test_embedding_retriever.py with 8 tests covering: basic run, runtime filters, top_k override, to_dict/from_dict roundtrip, FilterPolicy REPLACE/MERGE, ValueError on wrong store type, and end-to-end pipeline execution
- Update pyproject.toml types script to also typecheck haystack_integrations.components.retrievers.faiss
1 parent c42e049 commit f186c9c

6 files changed

Lines changed: 301 additions & 1 deletion

File tree

integrations/faiss/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ integration = 'pytest -m "integration" {args:tests}'
6767
all = 'pytest {args:tests}'
6868
cov-retry = 'pytest --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x {args:tests}'
6969

70-
types = "mypy -p haystack_integrations.document_stores.faiss {args}"
70+
types = "mypy -p haystack_integrations.document_stores.faiss -p haystack_integrations.components.retrievers.faiss {args}"
7171

7272
[tool.mypy]
7373
install_types = true
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
from .embedding_retriever import FAISSEmbeddingRetriever
5+
6+
__all__ = ["FAISSEmbeddingRetriever"]
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
from typing import Any
5+
6+
from haystack import component, default_from_dict, default_to_dict
7+
from haystack.dataclasses import Document
8+
from haystack.document_stores.types import FilterPolicy
9+
from haystack.document_stores.types.filter_policy import apply_filter_policy
10+
11+
from haystack_integrations.document_stores.faiss import FAISSDocumentStore
12+
13+
14+
@component
class FAISSEmbeddingRetriever:
    """
    Embedding-based retriever that queries a `FAISSDocumentStore`.

    The retriever forwards a query embedding, a result limit and a set of filters to the store's
    `search()` method and returns whatever Documents the store yields.

    Example usage:
    ```python
    from haystack import Document, Pipeline
    from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
    from haystack.document_stores.types import DuplicatePolicy

    from haystack_integrations.document_stores.faiss import FAISSDocumentStore
    from haystack_integrations.components.retrievers.faiss import FAISSEmbeddingRetriever

    document_store = FAISSDocumentStore(embedding_dim=768)

    documents = [
        Document(content="There are over 7,000 languages spoken around the world today."),
        Document(content="Elephants have been observed to behave in a way that indicates a high level of intelligence."),
        Document(content="In certain places, you can witness the phenomenon of bioluminescent waves."),
    ]

    document_embedder = SentenceTransformersDocumentEmbedder()
    document_embedder.warm_up()
    documents_with_embeddings = document_embedder.run(documents)["documents"]

    document_store.write_documents(documents_with_embeddings, policy=DuplicatePolicy.OVERWRITE)

    query_pipeline = Pipeline()
    query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder())
    query_pipeline.add_component("retriever", FAISSEmbeddingRetriever(document_store=document_store))
    query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

    query = "How many languages are there?"
    res = query_pipeline.run({"text_embedder": {"text": query}})

    assert res["retriever"]["documents"][0].content == "There are over 7,000 languages spoken around the world today."
    ```
    """

    def __init__(
        self,
        *,
        document_store: FAISSDocumentStore,
        filters: dict[str, Any] | None = None,
        top_k: int = 10,
        filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
    ):
        """
        Create the retriever.

        :param document_store: The `FAISSDocumentStore` to search.
        :param filters: Default filters configured on the retriever. On each call they are combined with the
            filters passed to `run()` as dictated by `filter_policy`. `None` is stored as an empty dict.
        :param top_k: Default maximum number of Documents to return.
        :param filter_policy: How init-time and runtime filters are combined, given either as a `FilterPolicy`
            member or as its string form. Defaults to `FilterPolicy.REPLACE`.
        :raises ValueError: If `document_store` is not an instance of `FAISSDocumentStore`.
        """
        if not isinstance(document_store, FAISSDocumentStore):
            msg = "document_store must be an instance of FAISSDocumentStore"
            raise ValueError(msg)

        # Accept the policy in string form as well and normalise it to the enum.
        if not isinstance(filter_policy, FilterPolicy):
            filter_policy = FilterPolicy.from_str(filter_policy)

        self.document_store = document_store
        self.filters = filters or {}
        self.top_k = top_k
        self.filter_policy = filter_policy

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The document store is serialized through its own `to_dict()` and the filter policy is stored
        as its enum value.

        :returns: Dictionary with serialized data.
        """
        init_parameters = {
            "filters": self.filters,
            "top_k": self.top_k,
            "filter_policy": self.filter_policy.value,
            "document_store": self.document_store.to_dict(),
        }
        return default_to_dict(self, **init_parameters)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "FAISSEmbeddingRetriever":
        """
        Build a component from its dictionary representation.

        The nested entries of `data["init_parameters"]` are converted in place before the
        dictionary is handed to `default_from_dict`.

        :param data: Dictionary to deserialize from.
        :returns: Deserialized component.
        """
        init_params = data["init_parameters"]
        init_params["document_store"] = FAISSDocumentStore.from_dict(init_params["document_store"])

        # Pipelines serialized with old versions of the component might not
        # have the filter_policy field, so only convert it when it is present.
        serialized_policy = init_params.get("filter_policy")
        if serialized_policy:
            init_params["filter_policy"] = FilterPolicy.from_str(serialized_policy)

        return default_from_dict(cls, data)

    @component.output_types(documents=list[Document])
    def run(
        self,
        query_embedding: list[float],
        filters: dict[str, Any] | None = None,
        top_k: int | None = None,
    ) -> dict[str, list[Document]]:
        """
        Search the `FAISSDocumentStore` for Documents matching a query embedding.

        :param query_embedding: Embedding of the query.
        :param filters: Runtime filters. They are combined with the filters given at initialization
            according to the retriever's `filter_policy`.
        :param top_k: Maximum number of Documents to return. When omitted (or any other falsy value,
            such as `0`), the value set at initialization is used.
        :returns: A dictionary with the following keys:
            - `documents`: List of `Document`s returned by the store for `query_embedding`.
        """
        effective_filters = apply_filter_policy(self.filter_policy, self.filters, filters)
        limit = top_k or self.top_k
        documents = self.document_store.search(
            query_embedding=query_embedding,
            top_k=limit,
            filters=effective_filters,
        )
        return {"documents": documents}

    @component.output_types(documents=list[Document])
    async def run_async(
        self,
        query_embedding: list[float],
        filters: dict[str, Any] | None = None,
        top_k: int | None = None,
    ) -> dict[str, list[Document]]:
        """
        Coroutine counterpart of `run()` with the same parameters and result.

        This calls the synchronous `run()` directly and does not await anything, so the search
        runs to completion before control returns to the event loop.

        :param query_embedding: Embedding of the query.
        :param filters: Runtime filters. They are combined with the filters given at initialization
            according to the retriever's `filter_policy`.
        :param top_k: Maximum number of Documents to return. When omitted (or any other falsy value,
            such as `0`), the value set at initialization is used.
        :returns: A dictionary with the following keys:
            - `documents`: List of `Document`s returned by the store for `query_embedding`.
        """
        result = self.run(query_embedding=query_embedding, filters=filters, top_k=top_k)
        return result
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import pytest
6+
from haystack import Pipeline
7+
from haystack.dataclasses import Document
8+
from haystack.document_stores.types import FilterPolicy
9+
10+
from haystack_integrations.components.retrievers.faiss import FAISSEmbeddingRetriever
11+
from haystack_integrations.document_stores.faiss import FAISSDocumentStore
12+
13+
EMBEDDING_DIM = 3
14+
15+
16+
@pytest.fixture
def document_store():
    """Provide a new, empty FAISSDocumentStore whose embedding size is EMBEDDING_DIM."""
    store = FAISSDocumentStore(embedding_dim=EMBEDDING_DIM)
    return store
20+
21+
22+
@pytest.fixture
def populated_store(document_store):
    """Return the `document_store` fixture after writing three embedded, categorised documents."""
    # (content, embedding, category): one unit vector per axis, two docs in "A" and one in "B".
    specs = [
        ("alpha", [1.0, 0.0, 0.0], "A"),
        ("beta", [0.0, 1.0, 0.0], "B"),
        ("gamma", [0.0, 0.0, 1.0], "A"),
    ]
    seeded = [
        Document(content=text, embedding=vector, meta={"category": category})
        for text, vector, category in specs
    ]
    document_store.write_documents(seeded)
    return document_store
32+
33+
34+
class TestFAISSEmbeddingRetriever:
    """Unit tests for `FAISSEmbeddingRetriever`: retrieval, filtering, serialization and pipeline use."""

    def test_run_with_query_embedding_only(self, populated_store):
        """Without filters, run() returns `top_k` Document instances under the "documents" key."""
        retriever = FAISSEmbeddingRetriever(document_store=populated_store, top_k=2)
        result = retriever.run(query_embedding=[1.0, 0.0, 0.0])

        assert "documents" in result
        assert isinstance(result["documents"], list)
        assert len(result["documents"]) == 2
        # All returned items must be Document instances
        assert all(isinstance(d, Document) for d in result["documents"])

    def test_run_with_filters(self, populated_store):
        """A runtime filter restricts the returned documents to the matching category."""
        retriever = FAISSEmbeddingRetriever(document_store=populated_store, top_k=3)
        filters = {"field": "meta.category", "operator": "==", "value": "A"}
        result = retriever.run(query_embedding=[1.0, 0.0, 0.0], filters=filters)

        assert "documents" in result
        # all() is True for an empty list, so require at least one hit before checking categories.
        assert len(result["documents"]) >= 1
        contents = [d.content for d in result["documents"]]
        # Only category-A docs should be returned
        assert all(d.meta["category"] == "A" for d in result["documents"])
        assert "beta" not in contents

    def test_run_with_top_k_override(self, populated_store):
        """A `top_k` passed to run() takes precedence over the init-time value."""
        retriever = FAISSEmbeddingRetriever(document_store=populated_store, top_k=3)
        result = retriever.run(query_embedding=[1.0, 0.0, 0.0], top_k=1)

        assert len(result["documents"]) == 1

    def test_to_dict_from_dict_roundtrip(self, document_store):
        """to_dict() output can be fed to from_dict() to rebuild an equivalent retriever."""
        retriever = FAISSEmbeddingRetriever(
            document_store=document_store,
            filters={"field": "meta.category", "operator": "==", "value": "A"},
            top_k=5,
            filter_policy=FilterPolicy.MERGE,
        )

        serialized = retriever.to_dict()
        assert serialized["type"] == (
            "haystack_integrations.components.retrievers.faiss.embedding_retriever.FAISSEmbeddingRetriever"
        )
        assert serialized["init_parameters"]["top_k"] == 5
        assert serialized["init_parameters"]["filter_policy"] == FilterPolicy.MERGE.value
        assert "document_store" in serialized["init_parameters"]

        restored = FAISSEmbeddingRetriever.from_dict(serialized)
        assert restored.top_k == 5
        assert restored.filter_policy == FilterPolicy.MERGE
        assert isinstance(restored.document_store, FAISSDocumentStore)

    def test_filter_policy_replace(self, populated_store):
        """REPLACE: runtime filters fully replace init-time filters."""
        init_filters = {"field": "meta.category", "operator": "==", "value": "A"}
        runtime_filters = {"field": "meta.category", "operator": "==", "value": "B"}

        retriever = FAISSEmbeddingRetriever(
            document_store=populated_store,
            filters=init_filters,
            top_k=3,
            filter_policy=FilterPolicy.REPLACE,
        )
        result = retriever.run(query_embedding=[0.0, 1.0, 0.0], filters=runtime_filters)

        # all() is True for an empty list, so require at least one hit before checking categories.
        assert len(result["documents"]) >= 1
        # Only category B docs should appear — the init filter was replaced
        assert all(d.meta["category"] == "B" for d in result["documents"])

    def test_filter_policy_merge(self, populated_store):
        """MERGE with no runtime filters: the init-time filter alone is applied."""
        init_filters = {"field": "meta.category", "operator": "==", "value": "A"}

        retriever = FAISSEmbeddingRetriever(
            document_store=populated_store,
            filters=init_filters,
            top_k=3,
            filter_policy=FilterPolicy.MERGE,
        )
        # Run without any runtime filter — init filter alone should apply
        result = retriever.run(query_embedding=[1.0, 0.0, 0.0])

        assert len(result["documents"]) >= 1
        assert all(d.meta["category"] == "A" for d in result["documents"])

    def test_invalid_document_store_type(self):
        """Passing anything other than a FAISSDocumentStore raises ValueError."""
        with pytest.raises(ValueError, match="document_store must be an instance of FAISSDocumentStore"):
            FAISSEmbeddingRetriever(document_store="not_a_store")  # type: ignore[arg-type]

    def test_run_in_pipeline(self, populated_store):
        """End-to-end: FAISSEmbeddingRetriever wired into a Haystack Pipeline."""
        retriever = FAISSEmbeddingRetriever(document_store=populated_store, top_k=2)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)

        result = pipeline.run({"retriever": {"query_embedding": [1.0, 0.0, 0.0]}})

        assert "retriever" in result
        assert "documents" in result["retriever"]
        assert isinstance(result["retriever"]["documents"], list)
        assert len(result["retriever"]["documents"]) == 2
        assert all(isinstance(d, Document) for d in result["retriever"]["documents"])

0 commit comments

Comments
 (0)