Skip to content

Commit 29a99fc

Browse files
committed
reducing duplicated code
1 parent f1bb822 commit 29a99fc

1 file changed

Lines changed: 110 additions & 111 deletions

File tree

  • integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch

integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py

Lines changed: 110 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ async def count_documents_async(self) -> int:
303303
Asynchronously returns how many documents are present in the document store.
304304
:returns: Number of documents in the document store.
305305
"""
306-
self._ensure_initialized()
306+
self._ensure_initialized() # ensures _async_client is not None
307307
result = await self._async_client.count(index=self._index) # type: ignore
308308
return result["count"]
309309

@@ -602,6 +602,103 @@ def _prepare_delete_all_request(self, *, is_async: bool, refresh: bool) -> dict[
602602
"refresh": refresh,
603603
}
604604

605+
@staticmethod
606+
def _normalize_metadata_field(field: str) -> str:
607+
"""
608+
Removes the "meta." prefix from a field name if present.
609+
Documents are flattened in Elasticsearch, so metadata fields don't need the prefix.
610+
611+
:param field: The field name to normalize.
612+
:returns: The normalized field name without "meta." prefix.
613+
"""
614+
if field.startswith("meta."):
615+
return field[5:]
616+
return field
617+
618+
@staticmethod
619+
def _convert_sql_result(result: Any) -> dict[str, Any]:
620+
"""
621+
Converts Elasticsearch SQL query result to a dictionary.
622+
Handles ObjectApiResponse which behaves like a dict but isinstance() returns False.
623+
624+
:param result: The result from Elasticsearch SQL query.
625+
:returns: A dictionary containing the query results.
626+
"""
627+
if isinstance(result, dict):
628+
return result
629+
# Convert ObjectApiResponse to dict
630+
return dict(result.items()) # type: ignore
631+
632+
def _build_unique_values_query(
633+
self, metadata_field: str, search_term: Optional[str] = None, from_: int = 0, size: int = 10
634+
) -> dict[str, Any]:
635+
"""
636+
Builds the Elasticsearch query body for getting unique values.
637+
638+
:param metadata_field: The normalized metadata field name.
639+
:param search_term: Optional search term to filter the unique values.
640+
:param from_: The starting index for pagination.
641+
:param size: The number of unique values to return.
642+
:returns: The Elasticsearch query body.
643+
"""
644+
# Terms aggregation doesn't support 'from_' directly, so we fetch from_ + size and slice
645+
fetch_size = from_ + size if from_ > 0 else size
646+
647+
body: dict[str, Any] = {
648+
"aggs": {
649+
"unique_values": {
650+
"terms": {
651+
"field": metadata_field,
652+
"size": fetch_size,
653+
}
654+
}
655+
},
656+
"size": 0,
657+
}
658+
659+
if search_term:
660+
body["query"] = {
661+
"bool": {
662+
"filter": [
663+
{
664+
"prefix": {
665+
metadata_field: {
666+
"value": search_term,
667+
"case_insensitive": True,
668+
}
669+
}
670+
}
671+
]
672+
}
673+
}
674+
675+
return body
676+
677+
@staticmethod
678+
def _process_unique_values_result(result: Any, from_: int) -> dict[str, Any]:
679+
"""
680+
Processes the Elasticsearch aggregation result for unique values.
681+
682+
:param result: The Elasticsearch search result (ObjectApiResponse that behaves like a dict).
683+
:param from_: The starting index for pagination.
684+
:returns: A dictionary containing 'values' (list of unique values) and 'total' (total count).
685+
"""
686+
buckets = result["aggregations"]["unique_values"]["buckets"]
687+
688+
# Slice to handle from_ parameter
689+
if from_ > 0:
690+
buckets = buckets[from_:]
691+
692+
values = [bucket["key"] for bucket in buckets]
693+
# Get total distinct count (approximate if sum_other_doc_count > 0)
694+
sum_other = result["aggregations"]["unique_values"].get("sum_other_doc_count", 0)
695+
total = sum_other + len(result["aggregations"]["unique_values"]["buckets"])
696+
697+
return {
698+
"values": values,
699+
"total": total,
700+
}
701+
605702
async def delete_documents_async(
606703
self, document_ids: list[str], refresh: Literal["wait_for", True, False] = "wait_for"
607704
) -> None:
@@ -1189,8 +1286,7 @@ def get_field_min_max(self, metadata_field: str) -> dict[str, Any]:
11891286

11901287
try:
11911288
# Remove "meta." prefix if present, as documents are flattened in Elasticsearch
1192-
if metadata_field.startswith("meta."):
1193-
metadata_field = metadata_field[5:]
1289+
metadata_field = self._normalize_metadata_field(metadata_field)
11941290

11951291
body = {
11961292
"query": {"match_all": {}},
@@ -1223,8 +1319,7 @@ async def get_field_min_max_async(self, metadata_field: str) -> dict[str, Any]:
12231319

12241320
try:
12251321
# Remove "meta." prefix if present, as documents are flattened in Elasticsearch
1226-
if metadata_field.startswith("meta."):
1227-
metadata_field = metadata_field[5:]
1322+
metadata_field = self._normalize_metadata_field(metadata_field)
12281323

12291324
body = {
12301325
"query": {"match_all": {}},
@@ -1263,56 +1358,13 @@ def get_field_unique_values(
12631358

12641359
try:
12651360
# Remove "meta." prefix if present, as documents are flattened in Elasticsearch
1266-
if metadata_field.startswith("meta."):
1267-
metadata_field = metadata_field[5:]
1361+
metadata_field = self._normalize_metadata_field(metadata_field)
12681362

1269-
# Terms aggregation doesn't support 'from_' directly, so we fetch from_ + size and slice
1270-
fetch_size = from_ + size if from_ > 0 else size
1271-
1272-
body: dict[str, Any] = {
1273-
"aggs": {
1274-
"unique_values": {
1275-
"terms": {
1276-
"field": metadata_field,
1277-
"size": fetch_size,
1278-
}
1279-
}
1280-
},
1281-
"size": 0,
1282-
}
1283-
1284-
if search_term:
1285-
body["query"] = {
1286-
"bool": {
1287-
"filter": [
1288-
{
1289-
"prefix": {
1290-
metadata_field: {
1291-
"value": search_term,
1292-
"case_insensitive": True,
1293-
}
1294-
}
1295-
}
1296-
]
1297-
}
1298-
}
1363+
# Build the search body
1364+
body = self._build_unique_values_query(metadata_field, search_term, from_, size)
12991365

13001366
result = self.client.search(index=self._index, body=body) # type: ignore
1301-
buckets = result["aggregations"]["unique_values"]["buckets"]
1302-
1303-
# Slice to handle from_ parameter
1304-
if from_ > 0:
1305-
buckets = buckets[from_:]
1306-
1307-
values = [bucket["key"] for bucket in buckets]
1308-
# Get total distinct count (approximate if sum_other_doc_count > 0)
1309-
sum_other = result["aggregations"]["unique_values"].get("sum_other_doc_count", 0)
1310-
total = sum_other + len(result["aggregations"]["unique_values"]["buckets"])
1311-
1312-
return {
1313-
"values": values,
1314-
"total": total,
1315-
}
1367+
return self._process_unique_values_result(result, from_)
13161368
except Exception as e:
13171369
msg = f"Failed to get field unique values from Elasticsearch: {e!s}"
13181370
raise DocumentStoreError(msg) from e
@@ -1335,56 +1387,13 @@ async def get_field_unique_values_async(
13351387

13361388
try:
13371389
# Remove "meta." prefix if present, as documents are flattened in Elasticsearch
1338-
if metadata_field.startswith("meta."):
1339-
metadata_field = metadata_field[5:]
1340-
1341-
# Terms aggregation doesn't support 'from_' directly, so we fetch from_ + size and slice
1342-
fetch_size = from_ + size if from_ > 0 else size
1343-
1344-
body: dict[str, Any] = {
1345-
"aggs": {
1346-
"unique_values": {
1347-
"terms": {
1348-
"field": metadata_field,
1349-
"size": fetch_size,
1350-
}
1351-
}
1352-
},
1353-
"size": 0,
1354-
}
1390+
metadata_field = self._normalize_metadata_field(metadata_field)
13551391

1356-
if search_term:
1357-
body["query"] = {
1358-
"bool": {
1359-
"filter": [
1360-
{
1361-
"prefix": {
1362-
metadata_field: {
1363-
"value": search_term,
1364-
"case_insensitive": True,
1365-
}
1366-
}
1367-
}
1368-
]
1369-
}
1370-
}
1392+
# Build the search body
1393+
body = self._build_unique_values_query(metadata_field, search_term, from_, size)
13711394

13721395
result = await self.async_client.search(index=self._index, body=body) # type: ignore
1373-
buckets = result["aggregations"]["unique_values"]["buckets"]
1374-
1375-
# Slice to handle from_ parameter
1376-
if from_ > 0:
1377-
buckets = buckets[from_:]
1378-
1379-
values = [bucket["key"] for bucket in buckets]
1380-
# Get total distinct count (approximate if sum_other_doc_count > 0)
1381-
sum_other = result["aggregations"]["unique_values"].get("sum_other_doc_count", 0)
1382-
total = sum_other + len(result["aggregations"]["unique_values"]["buckets"])
1383-
1384-
return {
1385-
"values": values,
1386-
"total": total,
1387-
}
1396+
return self._process_unique_values_result(result, from_)
13881397
except Exception as e:
13891398
msg = f"Failed to get field unique values from Elasticsearch: {e!s}"
13901399
raise DocumentStoreError(msg) from e
@@ -1401,12 +1410,7 @@ def query_sql(self, query: str) -> dict[str, Any]:
14011410
try:
14021411
body = {"query": query}
14031412
result = self.client.sql.query(body=body) # type: ignore
1404-
# ObjectApiResponse is dict-like, convert to dict for runtime
1405-
# It behaves like a dict but isinstance() returns False
1406-
if isinstance(result, dict):
1407-
return result
1408-
# Convert ObjectApiResponse to dict
1409-
return dict(result.items()) # type: ignore
1413+
return self._convert_sql_result(result)
14101414
except Exception as e:
14111415
msg = f"Failed to execute SQL query in Elasticsearch: {e!s}"
14121416
raise DocumentStoreError(msg) from e
@@ -1423,12 +1427,7 @@ async def query_sql_async(self, query: str) -> dict[str, Any]:
14231427
try:
14241428
body = {"query": query}
14251429
result = await self.async_client.sql.query(body=body) # type: ignore
1426-
# ObjectApiResponse is dict-like, convert to dict for runtime
1427-
# It behaves like a dict but isinstance() returns False
1428-
if isinstance(result, dict):
1429-
return result
1430-
# Convert ObjectApiResponse to dict
1431-
return dict(result.items()) # type: ignore
1430+
return self._convert_sql_result(result)
14321431
except Exception as e:
14331432
msg = f"Failed to execute SQL query in Elasticsearch: {e!s}"
14341433
raise DocumentStoreError(msg) from e

0 commit comments

Comments
 (0)