Skip to content

Commit 2ba545c

Browse files
authored
feat: support additional search_kwargs with OpenSearchEmbeddingRetriever (#2825)
* feat: support additional with * fix lint * apply feedback * add support in OpenSearchHybridRetriever as well * fix tests * fix tests * fix tests * fix tests * fix tests * apply feedback
1 parent b047094 commit 2ba545c

File tree

5 files changed

+81
-2
lines changed

5 files changed

+81
-2
lines changed

integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
custom_query: dict[str, Any] | None = None,
3535
raise_on_failure: bool = True,
3636
efficient_filtering: bool = False,
37+
search_kwargs: dict[str, Any] | None = None,
3738
):
3839
"""
3940
Create the OpenSearchEmbeddingRetriever component.
@@ -90,6 +91,18 @@ def __init__(
9091
If `False`, logs a warning and returns an empty list.
9192
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
9293
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
94+
:param search_kwargs: Additional keyword arguments for finetuning the embedding search.
95+
E.g., to specify `k` and `ef_search`
96+
```python
97+
{
98+
"k": 20, # See https://docs.opensearch.org/latest/vector-search/vector-search-techniques/approximate-knn/#the-number-of-returned-results
99+
"method_parameters": {
100+
"ef_search": 512, # See https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
101+
}
102+
}
103+
```
104+
For a full list of available parameters, see the OpenSearch documentation:
105+
https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#request-body-fields
93106
94107
:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.
95108
"""
@@ -106,6 +119,7 @@ def __init__(
106119
self._custom_query = custom_query
107120
self._raise_on_failure = raise_on_failure
108121
self._efficient_filtering = efficient_filtering
122+
self._search_kwargs = search_kwargs
109123

110124
def to_dict(self) -> dict[str, Any]:
111125
"""
@@ -123,6 +137,7 @@ def to_dict(self) -> dict[str, Any]:
123137
custom_query=self._custom_query,
124138
raise_on_failure=self._raise_on_failure,
125139
efficient_filtering=self._efficient_filtering,
140+
search_kwargs=self._search_kwargs,
126141
)
127142

128143
@classmethod
@@ -155,6 +170,7 @@ def run(
155170
custom_query: dict[str, Any] | None = None,
156171
efficient_filtering: bool | None = None,
157172
document_store: OpenSearchDocumentStore | None = None,
173+
search_kwargs: dict[str, Any] | None = None,
158174
) -> dict[str, list[Document]]:
159175
"""
160176
Retrieve documents using a vector similarity metric.
@@ -208,6 +224,19 @@ def run(
208224
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
209225
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
210226
:param document_store: Optional instance of OpenSearchDocumentStore to use with the Retriever.
227+
:param search_kwargs: Additional keyword arguments for finetuning the embedding search. If not provided,
228+
defaults to the parameter set at initialization (if any).
229+
E.g., to specify `k` and `ef_search`
230+
```python
231+
{
232+
"k": 20, # See https://docs.opensearch.org/latest/vector-search/vector-search-techniques/approximate-knn/#the-number-of-returned-results
233+
"method_parameters": {
234+
"ef_search": 512, # See https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
235+
}
236+
}
237+
```
238+
For a full list of available parameters, see the OpenSearch documentation:
239+
https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#request-body-fields
211240
212241
:returns:
213242
Dictionary with key "documents" containing the retrieved Documents.
@@ -223,6 +252,8 @@ def run(
223252
custom_query = self._custom_query
224253
if efficient_filtering is None:
225254
efficient_filtering = self._efficient_filtering
255+
if search_kwargs is None:
256+
search_kwargs = self._search_kwargs
226257

227258
docs: list[Document] = []
228259

@@ -241,6 +272,7 @@ def run(
241272
top_k=top_k,
242273
custom_query=custom_query,
243274
efficient_filtering=efficient_filtering,
275+
search_kwargs=search_kwargs,
244276
)
245277
except Exception as e:
246278
if self._raise_on_failure:
@@ -264,6 +296,7 @@ async def run_async(
264296
custom_query: dict[str, Any] | None = None,
265297
efficient_filtering: bool | None = None,
266298
document_store: OpenSearchDocumentStore | None = None,
299+
search_kwargs: dict[str, Any] | None = None,
267300
) -> dict[str, list[Document]]:
268301
"""
269302
Asynchronously retrieve documents using a vector similarity metric.
@@ -317,6 +350,19 @@ async def run_async(
317350
:param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search.
318351
This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib".
319352
:param document_store: Optional instance of OpenSearchDocumentStore to use with the Retriever.
353+
:param search_kwargs: Additional keyword arguments for finetuning the embedding search. If not provided,
354+
defaults to the parameter set at initialization (if any).
355+
E.g., to specify `k` and `ef_search`
356+
```python
357+
{
358+
"k": 20, # See https://docs.opensearch.org/latest/vector-search/vector-search-techniques/approximate-knn/#the-number-of-returned-results
359+
"method_parameters": {
360+
"ef_search": 512, # See https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
361+
}
362+
}
363+
```
364+
For a full list of available parameters, see the OpenSearch documentation:
365+
https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#request-body-fields
320366
321367
:returns:
322368
Dictionary with key "documents" containing the retrieved Documents.
@@ -332,6 +378,8 @@ async def run_async(
332378
custom_query = self._custom_query
333379
if efficient_filtering is None:
334380
efficient_filtering = self._efficient_filtering
381+
if search_kwargs is None:
382+
search_kwargs = self._search_kwargs
335383

336384
docs: list[Document] = []
337385

@@ -350,6 +398,7 @@ async def run_async(
350398
top_k=top_k,
351399
custom_query=custom_query,
352400
efficient_filtering=efficient_filtering,
401+
search_kwargs=search_kwargs,
353402
)
354403
except Exception as e:
355404
if self._raise_on_failure:

integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/open_search_hybrid_retriever.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__(
104104
top_k_embedding: int = 10,
105105
filter_policy_embedding: str | FilterPolicy = FilterPolicy.REPLACE,
106106
custom_query_embedding: dict[str, Any] | None = None,
107+
search_kwargs_embedding: dict[str, Any] | None = None,
107108
# DocumentJoiner
108109
join_mode: str | JoinMode = JoinMode.RECIPROCAL_RANK_FUSION,
109110
weights: list[float] | None = None,
@@ -153,6 +154,8 @@ def __init__(
153154
The filter policy for the embedding retriever.
154155
:param custom_query_embedding:
155156
A custom query for the embedding retriever.
157+
:param search_kwargs_embedding:
158+
Additional search kwargs for the embedding retriever.
156159
:param join_mode:
157160
The mode to use for joining the results from the BM25 and embedding retrievers.
158161
:param weights:
@@ -185,6 +188,7 @@ def __init__(
185188
self.top_k_embedding = top_k_embedding
186189
self.filter_policy_embedding = filter_policy_embedding
187190
self.custom_query_embedding = custom_query_embedding
191+
self.search_kwargs_embedding = search_kwargs_embedding
188192

189193
# DocumentJoiner
190194
self.join_mode = join_mode
@@ -209,6 +213,7 @@ def __init__(
209213
"top_k": self.top_k_embedding,
210214
"filter_policy": self.filter_policy_embedding,
211215
"custom_query": self.custom_query_embedding,
216+
"search_kwargs": self.search_kwargs_embedding,
212217
},
213218
"document_joiner": {
214219
"join_mode": self.join_mode,
@@ -311,6 +316,7 @@ def to_dict(self):
311316
else self.filter_policy_embedding
312317
),
313318
custom_query_embedding=self.custom_query_embedding,
319+
search_kwargs_embedding=self.search_kwargs_embedding,
314320
# DocumentJoiner
315321
join_mode=(self.join_mode.value if isinstance(self.join_mode, JoinMode) else self.join_mode),
316322
weights=self.weights,

integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,7 @@ def _prepare_embedding_search_request(
14091409
top_k: int,
14101410
custom_query: dict[str, Any] | None,
14111411
efficient_filtering: bool = False,
1412+
search_kwargs: dict[str, Any] | None = None,
14121413
) -> dict[str, Any]:
14131414
if not query_embedding:
14141415
msg = "query_embedding must be a non-empty list of floats"
@@ -1434,6 +1435,7 @@ def _prepare_embedding_search_request(
14341435
"embedding": {
14351436
"vector": query_embedding,
14361437
"k": top_k,
1438+
**(search_kwargs or {}),
14371439
}
14381440
}
14391441
}
@@ -1465,6 +1467,7 @@ def _embedding_retrieval(
14651467
top_k: int = 10,
14661468
custom_query: dict[str, Any] | None = None,
14671469
efficient_filtering: bool = False,
1470+
search_kwargs: dict[str, Any] | None = None,
14681471
) -> list[Document]:
14691472
"""
14701473
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -1484,6 +1487,7 @@ def _embedding_retrieval(
14841487
top_k=top_k,
14851488
custom_query=custom_query,
14861489
efficient_filtering=efficient_filtering,
1490+
search_kwargs=search_kwargs,
14871491
)
14881492
return self._search_documents(search_params)
14891493

@@ -1495,6 +1499,7 @@ async def _embedding_retrieval_async(
14951499
top_k: int = 10,
14961500
custom_query: dict[str, Any] | None = None,
14971501
efficient_filtering: bool = False,
1502+
search_kwargs: dict[str, Any] | None = None,
14981503
) -> list[Document]:
14991504
"""
15001505
Asynchronously retrieves documents that are most similar to the query embedding using a vector similarity
@@ -1515,6 +1520,7 @@ async def _embedding_retrieval_async(
15151520
top_k=top_k,
15161521
custom_query=custom_query,
15171522
efficient_filtering=efficient_filtering,
1523+
search_kwargs=search_kwargs,
15181524
)
15191525
return await self._search_documents_async(search_params)
15201526

integrations/opensearch/tests/test_embedding_retriever.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def test_to_dict(_mock_opensearch_client):
8585
"custom_query": {"some": "custom query"},
8686
"raise_on_failure": True,
8787
"efficient_filtering": False,
88+
"search_kwargs": None,
8889
},
8990
}
9091

@@ -145,6 +146,7 @@ def test_run():
145146
top_k=10,
146147
custom_query=None,
147148
efficient_filtering=False,
149+
search_kwargs=None,
148150
)
149151
assert len(res) == 1
150152
assert len(res["documents"]) == 1
@@ -164,6 +166,7 @@ async def test_run_async():
164166
top_k=10,
165167
custom_query=None,
166168
efficient_filtering=False,
169+
search_kwargs=None,
167170
)
168171
assert len(res) == 1
169172
assert len(res["documents"]) == 1
@@ -180,6 +183,7 @@ def test_run_init_params():
180183
top_k=11,
181184
custom_query="custom_query",
182185
efficient_filtering=True,
186+
search_kwargs={"k": 10},
183187
)
184188
res = retriever.run(query_embedding=[0.5, 0.7])
185189
mock_store._embedding_retrieval.assert_called_once_with(
@@ -188,6 +192,7 @@ def test_run_init_params():
188192
top_k=11,
189193
custom_query="custom_query",
190194
efficient_filtering=True,
195+
search_kwargs={"k": 10},
191196
)
192197
assert len(res) == 1
193198
assert len(res["documents"]) == 1
@@ -204,6 +209,7 @@ async def test_run_async_init_params():
204209
filters={"from": "init"},
205210
top_k=11,
206211
custom_query="custom_query",
212+
search_kwargs={"k": 10},
207213
)
208214
res = await retriever.run_async(query_embedding=[0.5, 0.7])
209215
mock_store._embedding_retrieval_async.assert_called_once_with(
@@ -212,6 +218,7 @@ async def test_run_async_init_params():
212218
top_k=11,
213219
custom_query="custom_query",
214220
efficient_filtering=False,
221+
search_kwargs={"k": 10},
215222
)
216223
assert len(res) == 1
217224
assert len(res["documents"]) == 1
@@ -222,14 +229,19 @@ async def test_run_async_init_params():
222229
def test_run_time_params():
223230
mock_store = Mock(spec=OpenSearchDocumentStore)
224231
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
225-
retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11)
226-
res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True)
232+
retriever = OpenSearchEmbeddingRetriever(
233+
document_store=mock_store, filters={"from": "init"}, top_k=11, search_kwargs={"k": 10}
234+
)
235+
res = retriever.run(
236+
query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True, search_kwargs={"k": 9}
237+
)
227238
mock_store._embedding_retrieval.assert_called_once_with(
228239
query_embedding=[0.5, 0.7],
229240
filters={"from": "run"},
230241
top_k=9,
231242
custom_query=None,
232243
efficient_filtering=True,
244+
search_kwargs={"k": 9},
233245
)
234246
assert len(res) == 1
235247
assert len(res["documents"]) == 1
@@ -249,6 +261,7 @@ async def test_run_async_time_params():
249261
top_k=9,
250262
custom_query=None,
251263
efficient_filtering=False,
264+
search_kwargs=None,
252265
)
253266
assert len(res) == 1
254267
assert len(res["documents"]) == 1
@@ -288,6 +301,7 @@ def test_run_with_runtime_document_store():
288301
top_k=10,
289302
custom_query=None,
290303
efficient_filtering=False,
304+
search_kwargs=None,
291305
)
292306
initial_store._embedding_retrieval.assert_not_called()
293307

@@ -325,6 +339,7 @@ async def test_run_async_with_runtime_document_store():
325339
top_k=10,
326340
custom_query=None,
327341
efficient_filtering=False,
342+
search_kwargs=None,
328343
)
329344
initial_store._embedding_retrieval_async.assert_not_called()
330345

integrations/opensearch/tests/test_open_search_hybrid_retriever.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TestOpenSearchHybridRetriever:
8888
"weights": None,
8989
"top_k": None,
9090
"sort_by_score": True,
91+
"search_kwargs_embedding": None,
9192
},
9293
}
9394

@@ -224,6 +225,7 @@ def test_run_with_extra_runtime_params(self, mock_embedder):
224225
top_k=1,
225226
custom_query=None,
226227
efficient_filtering=False,
228+
search_kwargs=None,
227229
)
228230

229231
def test_run_in_pipeline(self, mock_embedder):
@@ -256,4 +258,5 @@ def test_run_in_pipeline(self, mock_embedder):
256258
top_k=10,
257259
custom_query=None,
258260
efficient_filtering=False,
261+
search_kwargs=None,
259262
)

0 commit comments

Comments
 (0)