Skip to content

Commit 0c0f31c

Browse files
committed
updating tests
1 parent 5e7cd90 commit 0c0f31c

3 files changed

Lines changed: 292 additions & 5 deletions

File tree

integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py

Lines changed: 166 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import json
56
from collections.abc import Mapping
67
from math import exp
7-
from typing import Any, Optional, Union
8+
from typing import Any, Literal, Optional, Union
89

10+
import requests
911
from haystack import default_from_dict, default_to_dict, logging
1012
from haystack.dataclasses import Document
1113
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
1214
from haystack.document_stores.types import DuplicatePolicy
1315
from haystack.utils.auth import Secret
1416
from opensearchpy import AsyncHttpConnection, AsyncOpenSearch, OpenSearch
17+
from opensearchpy.exceptions import SerializationError
1518
from opensearchpy.helpers import async_bulk, bulk
1619

1720
from haystack_integrations.document_stores.opensearch.auth import AsyncAWSAuth, AWSAuth
@@ -21,6 +24,8 @@
2124

2225
Hosts = Union[str, list[Union[str, Mapping[str, Union[str, int]]]]]
2326

27+
ResponseFormat = Literal["json", "jdbc", "csv", "raw"]
28+
2429
# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to
2530
# True. Scaling uses the expit function (inverse of the logit function) after applying a scaling factor
2631
# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method).
@@ -1309,7 +1314,7 @@ def get_field_unique_values(
13091314
field_name = self._normalize_metadata_field_name(metadata_field)
13101315

13111316
# filter by search_term if provided
1312-
query = {"match_all": {}}
1317+
query: dict[str, Any] = {"match_all": {}}
13131318
if search_term:
13141319
# Use match_phrase for exact phrase matching to avoid tokenization issues
13151320
query = {"match_phrase": {"content": search_term}}
@@ -1370,7 +1375,7 @@ async def get_field_unique_values_async(
13701375
field_name = self._normalize_metadata_field_name(metadata_field)
13711376

13721377
# filter by search_term if provided
1373-
query = {"match_all": {}}
1378+
query: dict[str, Any] = {"match_all": {}}
13741379
if search_term:
13751380
# Use match_phrase for exact phrase matching to avoid tokenization issues
13761381
query = {"match_phrase": {"content": search_term}}
@@ -1413,5 +1418,161 @@ async def get_field_unique_values_async(
14131418

14141419
return unique_values, total_count
14151420

1416-
def query_sql(self, query: str):
1417-
pass
1421+
def query_sql(self, query: str, response_format: ResponseFormat = "json") -> Any:
1422+
"""
1423+
Execute a raw OpenSearch SQL query against the index.
1424+
1425+
:param query: The OpenSearch SQL query to execute
1426+
:param response_format: The format of the response. See https://docs.opensearch.org/latest/search-plugins/sql/response-formats/
1427+
:returns: The query results in the specified format. For JSON format, returns a list of dictionaries
1428+
(the _source from each hit). For other formats (csv, jdbc, raw), returns the response as text.
1429+
"""
1430+
self._ensure_initialized()
1431+
assert self._client is not None
1432+
1433+
# For non-JSON formats, use requests directly to avoid deserialization issues
1434+
if response_format != "json":
1435+
try:
1436+
# Get connection info from the transport
1437+
connection = self._client.transport.get_connection()
1438+
base_url = connection.host
1439+
url = f"{base_url}/_plugins/_sql?format={response_format}"
1440+
1441+
headers = {"Content-Type": "application/json"}
1442+
auth = None
1443+
if self._http_auth:
1444+
if isinstance(self._http_auth, tuple):
1445+
auth = self._http_auth
1446+
elif isinstance(self._http_auth, AWSAuth):
1447+
# For AWS auth, we need to use the opensearchpy client
1448+
# Fall through to the try/except below
1449+
pass
1450+
1451+
verify = self._verify_certs if self._verify_certs is not None else True
1452+
timeout = self._timeout if self._timeout is not None else 30.0
1453+
response = requests.post(
1454+
url,
1455+
json={"query": query},
1456+
headers=headers,
1457+
auth=auth,
1458+
verify=verify,
1459+
timeout=timeout,
1460+
)
1461+
response.raise_for_status()
1462+
return response.text
1463+
except Exception as e:
1464+
# If requests fails (e.g., AWS auth), fall back to opensearchpy
1465+
# which will raise SerializationError that we can handle
1466+
pass
1467+
1468+
try:
1469+
body = {"query": query}
1470+
params = {"format": response_format}
1471+
1472+
response_data = self._client.transport.perform_request(
1473+
method="POST",
1474+
url="/_plugins/_sql",
1475+
params=params,
1476+
body=body,
1477+
)
1478+
1479+
if response_format == "json":
1480+
# extract only the query results
1481+
if isinstance(response_data, dict) and "hits" in response_data:
1482+
hits = response_data.get("hits", {}).get("hits", [])
1483+
# extract _source from each hit, which contains the actual document data
1484+
return [hit.get("_source", {}) for hit in hits]
1485+
return response_data
1486+
else:
1487+
return response_data if isinstance(response_data, str) else str(response_data)
1488+
except SerializationError:
1489+
# If we get here, it means requests failed above (likely AWS auth)
1490+
# and opensearchpy can't deserialize the response
1491+
# Re-raise as DocumentStoreError with a helpful message
1492+
msg = f"Failed to execute SQL query in OpenSearch: Unable to deserialize {response_format} response. This format may not be supported with the current authentication method."
1493+
raise DocumentStoreError(msg) from None
1494+
except Exception as e:
1495+
msg = f"Failed to execute SQL query in OpenSearch: {e!s}"
1496+
raise DocumentStoreError(msg) from e
1497+
1498+
async def query_sql_async(self, query: str, response_format: ResponseFormat = "json") -> Any:
1499+
"""
1500+
Asynchronously execute a raw OpenSearch SQL query against the index.
1501+
1502+
:param query: The OpenSearch SQL query to execute
1503+
:param response_format: The format of the response. See https://docs.opensearch.org/latest/search-plugins/sql/response-formats/
1504+
:returns: The query results in the specified format. For JSON format, returns a list of dictionaries
1505+
(the _source from each hit). For other formats (csv, jdbc, raw), returns the response as text.
1506+
"""
1507+
await self._ensure_initialized_async()
1508+
assert self._async_client is not None
1509+
1510+
# For non-JSON formats, use httpx directly to avoid deserialization issues
1511+
if response_format != "json":
1512+
try:
1513+
import httpx
1514+
1515+
# Get connection info from the transport
1516+
connection = self._async_client.transport.get_connection()
1517+
base_url = connection.host
1518+
url = f"{base_url}/_plugins/_sql?format={response_format}"
1519+
1520+
headers = {"Content-Type": "application/json"}
1521+
auth = None
1522+
if self._http_auth:
1523+
if isinstance(self._http_auth, tuple):
1524+
auth = self._http_auth
1525+
elif isinstance(self._http_auth, AWSAuth):
1526+
# For AWS auth, we need to use the opensearchpy client
1527+
# Fall through to the try/except below
1528+
pass
1529+
1530+
verify = self._verify_certs if self._verify_certs is not None else True
1531+
timeout = httpx.Timeout(self._timeout if self._timeout else 30.0)
1532+
1533+
async with httpx.AsyncClient(verify=verify, timeout=timeout) as client:
1534+
response = await client.post(
1535+
url,
1536+
json={"query": query},
1537+
headers=headers,
1538+
auth=auth,
1539+
)
1540+
response.raise_for_status()
1541+
return response.text
1542+
except ImportError:
1543+
# httpx not available, fall through to opensearchpy
1544+
pass
1545+
except Exception as e:
1546+
# If httpx fails (e.g., AWS auth), fall back to opensearchpy
1547+
# which will raise SerializationError that we can handle
1548+
pass
1549+
1550+
try:
1551+
body = {"query": query}
1552+
params = {"format": response_format}
1553+
1554+
response_data = await self._async_client.transport.perform_request(
1555+
method="POST",
1556+
url="/_plugins/_sql",
1557+
params=params,
1558+
body=body,
1559+
)
1560+
1561+
if response_format == "json":
1562+
# extract only the query results
1563+
if isinstance(response_data, dict) and "hits" in response_data:
1564+
hits = response_data.get("hits", {}).get("hits", [])
1565+
# extract _source from each hit, which contains the actual document data
1566+
return [hit.get("_source", {}) for hit in hits]
1567+
return response_data
1568+
else:
1569+
return response_data if isinstance(response_data, str) else str(response_data)
1570+
except SerializationError:
1571+
# If we get here, it means httpx failed above (likely AWS auth or not installed)
1572+
# and opensearchpy can't deserialize the response
1573+
# Re-raise as DocumentStoreError with a helpful message
1574+
msg = f"Failed to execute SQL query in OpenSearch: Unable to deserialize {response_format} response. This format may not be supported with the current authentication method. Consider installing httpx for better support."
1575+
raise DocumentStoreError(msg) from None
1576+
except Exception as e:
1577+
msg = f"Failed to execute SQL query in OpenSearch: {e!s}"
1578+
raise DocumentStoreError(msg) from e

integrations/opensearch/tests/test_document_store.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,65 @@ def test_get_field_unique_values(self, document_store: OpenSearchDocumentStore):
772772
)
773773
assert set(unique_priorities_filtered) == {"1"}
774774
assert priority_count == 1
775+
776+
def test_query_sql(self, document_store: OpenSearchDocumentStore):
777+
"""
778+
Test executing SQL queries against the OpenSearch index.
779+
"""
780+
docs = [
781+
Document(content="Python programming", meta={"category": "A", "status": "active", "priority": 1}),
782+
Document(content="Java programming", meta={"category": "B", "status": "active", "priority": 2}),
783+
Document(content="Python scripting", meta={"category": "A", "status": "inactive", "priority": 3}),
784+
Document(content="JavaScript development", meta={"category": "C", "status": "active", "priority": 1}),
785+
]
786+
document_store.write_documents(docs)
787+
time.sleep(1) # Wait for documents to be indexed
788+
789+
# Test SQL query with JSON format (default)
790+
sql_query = (
791+
f"SELECT content, category, status, priority FROM {document_store._index} " # noqa: S608
792+
f"WHERE category = 'A' ORDER BY priority"
793+
)
794+
result = document_store.query_sql(sql_query, response_format="json")
795+
796+
# New format returns a list of dictionaries (the _source from each hit)
797+
assert len(result) == 2 # Two documents with category A
798+
assert isinstance(result, list)
799+
assert all(isinstance(row, dict) for row in result)
800+
801+
# Verify data contains expected values
802+
categories = [row.get("category") for row in result]
803+
assert all(cat == "A" for cat in categories)
804+
805+
# Verify all expected fields are present
806+
for row in result:
807+
assert "content" in row
808+
assert "category" in row
809+
assert "status" in row
810+
assert "priority" in row
811+
812+
# Test SQL query with CSV format
813+
result_csv = document_store.query_sql(sql_query, response_format="csv")
814+
assert isinstance(result_csv, str)
815+
assert "content" in result_csv
816+
assert "category" in result_csv
817+
818+
# Test SQL query with JDBC format
819+
result_jdbc = document_store.query_sql(sql_query, response_format="jdbc")
820+
# JDBC format can be dict or str depending on OpenSearch version
821+
assert result_jdbc is not None
822+
823+
# Test SQL query with RAW format
824+
result_raw = document_store.query_sql(sql_query, response_format="raw")
825+
assert isinstance(result_raw, str)
826+
827+
# Test COUNT query
828+
count_query = f"SELECT COUNT(*) as total FROM {document_store._index}" # noqa: S608
829+
count_result = document_store.query_sql(count_query, response_format="json")
830+
# COUNT query may return different format, check it's a valid response
831+
assert count_result is not None
832+
833+
# Test error handling for invalid SQL query
834+
invalid_query = "SELECT * FROM non_existent_index"
835+
with pytest.raises(DocumentStoreError, match="Failed to execute SQL query"):
836+
document_store.query_sql(invalid_query)

integrations/opensearch/tests/test_document_store_async.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
from haystack.dataclasses import Document
9+
from haystack.document_stores.errors import DocumentStoreError
910
from haystack.document_stores.types import DuplicatePolicy
1011

1112
from haystack_integrations.document_stores.opensearch.document_store import OpenSearchDocumentStore
@@ -572,3 +573,66 @@ async def test_get_field_unique_values(self, document_store: OpenSearchDocumentS
572573
)
573574
assert set(unique_priorities_filtered) == {"1"}
574575
assert priority_count == 1
576+
577+
@pytest.mark.asyncio
578+
async def test_query_sql(self, document_store: OpenSearchDocumentStore):
579+
"""
580+
Test executing SQL queries against the OpenSearch index.
581+
"""
582+
docs = [
583+
Document(content="Python programming", meta={"category": "A", "status": "active", "priority": 1}),
584+
Document(content="Java programming", meta={"category": "B", "status": "active", "priority": 2}),
585+
Document(content="Python scripting", meta={"category": "A", "status": "inactive", "priority": 3}),
586+
Document(content="JavaScript development", meta={"category": "C", "status": "active", "priority": 1}),
587+
]
588+
await document_store.write_documents_async(docs)
589+
time.sleep(1) # Wait for documents to be indexed
590+
591+
# Test SQL query with JSON format (default)
592+
sql_query = (
593+
f"SELECT content, category, status, priority FROM {document_store._index} " # noqa: S608
594+
f"WHERE category = 'A' ORDER BY priority"
595+
)
596+
result = await document_store.query_sql_async(sql_query, response_format="json")
597+
598+
# New format returns a list of dictionaries (the _source from each hit)
599+
assert len(result) == 2 # Two documents with category A
600+
assert isinstance(result, list)
601+
assert all(isinstance(row, dict) for row in result)
602+
603+
# Verify data contains expected values
604+
categories = [row.get("category") for row in result]
605+
assert all(cat == "A" for cat in categories)
606+
607+
# Verify all expected fields are present
608+
for row in result:
609+
assert "content" in row
610+
assert "category" in row
611+
assert "status" in row
612+
assert "priority" in row
613+
614+
# Test SQL query with CSV format
615+
result_csv = await document_store.query_sql_async(sql_query, response_format="csv")
616+
assert isinstance(result_csv, str)
617+
assert "content" in result_csv
618+
assert "category" in result_csv
619+
620+
# Test SQL query with JDBC format
621+
result_jdbc = await document_store.query_sql_async(sql_query, response_format="jdbc")
622+
# JDBC format can be dict or str depending on OpenSearch version
623+
assert result_jdbc is not None
624+
625+
# Test SQL query with RAW format
626+
result_raw = await document_store.query_sql_async(sql_query, response_format="raw")
627+
assert isinstance(result_raw, str)
628+
629+
# Test COUNT query
630+
count_query = f"SELECT COUNT(*) as total FROM {document_store._index}" # noqa: S608
631+
count_result = await document_store.query_sql_async(count_query, response_format="json")
632+
# COUNT query may return different format, check it's a valid response
633+
assert count_result is not None
634+
635+
# Test error handling for invalid SQL query
636+
invalid_query = "SELECT * FROM non_existent_index"
637+
with pytest.raises(DocumentStoreError, match="Failed to execute SQL query"):
638+
await document_store.query_sql_async(invalid_query)

0 commit comments

Comments
 (0)