|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
| 5 | +import json |
5 | 6 | from collections.abc import Mapping |
6 | 7 | from math import exp |
7 | | -from typing import Any, Optional, Union |
| 8 | +from typing import Any, Literal, Optional, Union |
8 | 9 |
|
| 10 | +import requests |
9 | 11 | from haystack import default_from_dict, default_to_dict, logging |
10 | 12 | from haystack.dataclasses import Document |
11 | 13 | from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError |
12 | 14 | from haystack.document_stores.types import DuplicatePolicy |
13 | 15 | from haystack.utils.auth import Secret |
14 | 16 | from opensearchpy import AsyncHttpConnection, AsyncOpenSearch, OpenSearch |
| 17 | +from opensearchpy.exceptions import SerializationError |
15 | 18 | from opensearchpy.helpers import async_bulk, bulk |
16 | 19 |
|
17 | 20 | from haystack_integrations.document_stores.opensearch.auth import AsyncAWSAuth, AWSAuth |
|
21 | 24 |
|
22 | 25 | Hosts = Union[str, list[Union[str, Mapping[str, Union[str, int]]]]] |
23 | 26 |
|
| 27 | +ResponseFormat = Literal["json", "jdbc", "csv", "raw"] |
| 28 | + |
24 | 29 | # document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to |
25 | 30 | # True. Scaling uses the expit function (inverse of the logit function) after applying a scaling factor |
26 | 31 | # (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method). |
@@ -1309,7 +1314,7 @@ def get_field_unique_values( |
1309 | 1314 | field_name = self._normalize_metadata_field_name(metadata_field) |
1310 | 1315 |
|
1311 | 1316 | # filter by search_term if provided |
1312 | | - query = {"match_all": {}} |
| 1317 | + query: dict[str, Any] = {"match_all": {}} |
1313 | 1318 | if search_term: |
1314 | 1319 | # Use match_phrase for exact phrase matching to avoid tokenization issues |
1315 | 1320 | query = {"match_phrase": {"content": search_term}} |
@@ -1370,7 +1375,7 @@ async def get_field_unique_values_async( |
1370 | 1375 | field_name = self._normalize_metadata_field_name(metadata_field) |
1371 | 1376 |
|
1372 | 1377 | # filter by search_term if provided |
1373 | | - query = {"match_all": {}} |
| 1378 | + query: dict[str, Any] = {"match_all": {}} |
1374 | 1379 | if search_term: |
1375 | 1380 | # Use match_phrase for exact phrase matching to avoid tokenization issues |
1376 | 1381 | query = {"match_phrase": {"content": search_term}} |
@@ -1413,5 +1418,161 @@ async def get_field_unique_values_async( |
1413 | 1418 |
|
1414 | 1419 | return unique_values, total_count |
1415 | 1420 |
|
1416 | | - def query_sql(self, query: str): |
1417 | | - pass |
| 1421 | + def query_sql(self, query: str, response_format: ResponseFormat = "json") -> Any: |
| 1422 | + """ |
| 1423 | + Execute a raw OpenSearch SQL query against the index. |
| 1424 | +
|
| 1425 | + :param query: The OpenSearch SQL query to execute |
| 1426 | + :param response_format: The format of the response. See https://docs.opensearch.org/latest/search-plugins/sql/response-formats/ |
| 1427 | + :returns: The query results in the specified format. For JSON format, returns a list of dictionaries |
| 1428 | + (the _source from each hit). For other formats (csv, jdbc, raw), returns the response as text. |
| 1429 | + """ |
| 1430 | + self._ensure_initialized() |
| 1431 | + assert self._client is not None |
| 1432 | + |
| 1433 | + # For non-JSON formats, use requests directly to avoid deserialization issues |
| 1434 | + if response_format != "json": |
| 1435 | + try: |
| 1436 | + # Get connection info from the transport |
| 1437 | + connection = self._client.transport.get_connection() |
| 1438 | + base_url = connection.host |
| 1439 | + url = f"{base_url}/_plugins/_sql?format={response_format}" |
| 1440 | + |
| 1441 | + headers = {"Content-Type": "application/json"} |
| 1442 | + auth = None |
| 1443 | + if self._http_auth: |
| 1444 | + if isinstance(self._http_auth, tuple): |
| 1445 | + auth = self._http_auth |
| 1446 | + elif isinstance(self._http_auth, AWSAuth): |
| 1447 | + # For AWS auth, we need to use the opensearchpy client |
| 1448 | + # Fall through to the try/except below |
| 1449 | + pass |
| 1450 | + |
| 1451 | + verify = self._verify_certs if self._verify_certs is not None else True |
| 1452 | + timeout = self._timeout if self._timeout is not None else 30.0 |
| 1453 | + response = requests.post( |
| 1454 | + url, |
| 1455 | + json={"query": query}, |
| 1456 | + headers=headers, |
| 1457 | + auth=auth, |
| 1458 | + verify=verify, |
| 1459 | + timeout=timeout, |
| 1460 | + ) |
| 1461 | + response.raise_for_status() |
| 1462 | + return response.text |
| 1463 | + except Exception as e: |
| 1464 | + # If requests fails (e.g., AWS auth), fall back to opensearchpy |
| 1465 | + # which will raise SerializationError that we can handle |
| 1466 | + pass |
| 1467 | + |
| 1468 | + try: |
| 1469 | + body = {"query": query} |
| 1470 | + params = {"format": response_format} |
| 1471 | + |
| 1472 | + response_data = self._client.transport.perform_request( |
| 1473 | + method="POST", |
| 1474 | + url="/_plugins/_sql", |
| 1475 | + params=params, |
| 1476 | + body=body, |
| 1477 | + ) |
| 1478 | + |
| 1479 | + if response_format == "json": |
| 1480 | + # extract only the query results |
| 1481 | + if isinstance(response_data, dict) and "hits" in response_data: |
| 1482 | + hits = response_data.get("hits", {}).get("hits", []) |
| 1483 | + # extract _source from each hit, which contains the actual document data |
| 1484 | + return [hit.get("_source", {}) for hit in hits] |
| 1485 | + return response_data |
| 1486 | + else: |
| 1487 | + return response_data if isinstance(response_data, str) else str(response_data) |
| 1488 | + except SerializationError: |
| 1489 | + # If we get here, it means requests failed above (likely AWS auth) |
| 1490 | + # and opensearchpy can't deserialize the response |
| 1491 | + # Re-raise as DocumentStoreError with a helpful message |
| 1492 | + msg = f"Failed to execute SQL query in OpenSearch: Unable to deserialize {response_format} response. This format may not be supported with the current authentication method." |
| 1493 | + raise DocumentStoreError(msg) from None |
| 1494 | + except Exception as e: |
| 1495 | + msg = f"Failed to execute SQL query in OpenSearch: {e!s}" |
| 1496 | + raise DocumentStoreError(msg) from e |
| 1497 | + |
| 1498 | + async def query_sql_async(self, query: str, response_format: ResponseFormat = "json") -> Any: |
| 1499 | + """ |
| 1500 | + Asynchronously execute a raw OpenSearch SQL query against the index. |
| 1501 | +
|
| 1502 | + :param query: The OpenSearch SQL query to execute |
| 1503 | + :param response_format: The format of the response. See https://docs.opensearch.org/latest/search-plugins/sql/response-formats/ |
| 1504 | + :returns: The query results in the specified format. For JSON format, returns a list of dictionaries |
| 1505 | + (the _source from each hit). For other formats (csv, jdbc, raw), returns the response as text. |
| 1506 | + """ |
| 1507 | + await self._ensure_initialized_async() |
| 1508 | + assert self._async_client is not None |
| 1509 | + |
| 1510 | + # For non-JSON formats, use httpx directly to avoid deserialization issues |
| 1511 | + if response_format != "json": |
| 1512 | + try: |
| 1513 | + import httpx |
| 1514 | + |
| 1515 | + # Get connection info from the transport |
| 1516 | + connection = self._async_client.transport.get_connection() |
| 1517 | + base_url = connection.host |
| 1518 | + url = f"{base_url}/_plugins/_sql?format={response_format}" |
| 1519 | + |
| 1520 | + headers = {"Content-Type": "application/json"} |
| 1521 | + auth = None |
| 1522 | + if self._http_auth: |
| 1523 | + if isinstance(self._http_auth, tuple): |
| 1524 | + auth = self._http_auth |
| 1525 | + elif isinstance(self._http_auth, AWSAuth): |
| 1526 | + # For AWS auth, we need to use the opensearchpy client |
| 1527 | + # Fall through to the try/except below |
| 1528 | + pass |
| 1529 | + |
| 1530 | + verify = self._verify_certs if self._verify_certs is not None else True |
| 1531 | + timeout = httpx.Timeout(self._timeout if self._timeout else 30.0) |
| 1532 | + |
| 1533 | + async with httpx.AsyncClient(verify=verify, timeout=timeout) as client: |
| 1534 | + response = await client.post( |
| 1535 | + url, |
| 1536 | + json={"query": query}, |
| 1537 | + headers=headers, |
| 1538 | + auth=auth, |
| 1539 | + ) |
| 1540 | + response.raise_for_status() |
| 1541 | + return response.text |
| 1542 | + except ImportError: |
| 1543 | + # httpx not available, fall through to opensearchpy |
| 1544 | + pass |
| 1545 | + except Exception as e: |
| 1546 | + # If httpx fails (e.g., AWS auth), fall back to opensearchpy |
| 1547 | + # which will raise SerializationError that we can handle |
| 1548 | + pass |
| 1549 | + |
| 1550 | + try: |
| 1551 | + body = {"query": query} |
| 1552 | + params = {"format": response_format} |
| 1553 | + |
| 1554 | + response_data = await self._async_client.transport.perform_request( |
| 1555 | + method="POST", |
| 1556 | + url="/_plugins/_sql", |
| 1557 | + params=params, |
| 1558 | + body=body, |
| 1559 | + ) |
| 1560 | + |
| 1561 | + if response_format == "json": |
| 1562 | + # extract only the query results |
| 1563 | + if isinstance(response_data, dict) and "hits" in response_data: |
| 1564 | + hits = response_data.get("hits", {}).get("hits", []) |
| 1565 | + # extract _source from each hit, which contains the actual document data |
| 1566 | + return [hit.get("_source", {}) for hit in hits] |
| 1567 | + return response_data |
| 1568 | + else: |
| 1569 | + return response_data if isinstance(response_data, str) else str(response_data) |
| 1570 | + except SerializationError: |
| 1571 | + # If we get here, it means httpx failed above (likely AWS auth or not installed) |
| 1572 | + # and opensearchpy can't deserialize the response |
| 1573 | + # Re-raise as DocumentStoreError with a helpful message |
| 1574 | + msg = f"Failed to execute SQL query in OpenSearch: Unable to deserialize {response_format} response. This format may not be supported with the current authentication method. Consider installing httpx for better support." |
| 1575 | + raise DocumentStoreError(msg) from None |
| 1576 | + except Exception as e: |
| 1577 | + msg = f"Failed to execute SQL query in OpenSearch: {e!s}" |
| 1578 | + raise DocumentStoreError(msg) from e |
0 commit comments