Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 48 additions & 16 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import threading
from typing import List, Optional, Union, Any, TYPE_CHECKING
from uuid import UUID

from databricks.sql.result_set import ThriftResultSet

Expand Down Expand Up @@ -1017,9 +1018,11 @@ def execute_command(
self._handle_execute_response_async(resp, cursor)
return None
else:
execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)
(
execute_response,
is_direct_results,
statement_id,
) = self._handle_execute_response(resp, cursor)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
Expand All @@ -1036,6 +1039,8 @@ def execute_command(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
statement_id=statement_id,
)

def get_catalogs(
Expand All @@ -1057,9 +1062,11 @@ def get_catalogs(
)
resp = self.make_request(self._client.GetCatalogs, req)

execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)
(
execute_response,
is_direct_results,
statement_id,
) = self._handle_execute_response(resp, cursor)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
Expand All @@ -1076,6 +1083,8 @@ def get_catalogs(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
statement_id=statement_id,
)

def get_schemas(
Expand Down Expand Up @@ -1103,9 +1112,11 @@ def get_schemas(
)
resp = self.make_request(self._client.GetSchemas, req)

execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)
(
execute_response,
is_direct_results,
statement_id,
) = self._handle_execute_response(resp, cursor)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
Expand All @@ -1122,6 +1133,8 @@ def get_schemas(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
statement_id=statement_id,
)

def get_tables(
Expand Down Expand Up @@ -1153,9 +1166,11 @@ def get_tables(
)
resp = self.make_request(self._client.GetTables, req)

execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)
(
execute_response,
is_direct_results,
statement_id,
) = self._handle_execute_response(resp, cursor)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
Expand All @@ -1172,6 +1187,8 @@ def get_tables(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
statement_id=statement_id,
)

def get_columns(
Expand Down Expand Up @@ -1203,9 +1220,11 @@ def get_columns(
)
resp = self.make_request(self._client.GetColumns, req)

execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)
(
execute_response,
is_direct_results,
statement_id,
) = self._handle_execute_response(resp, cursor)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
Expand All @@ -1222,6 +1241,8 @@ def get_columns(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
statement_id=statement_id,
)

def _handle_execute_response(self, resp, cursor):
Expand All @@ -1237,7 +1258,15 @@ def _handle_execute_response(self, resp, cursor):
resp.directResults and resp.directResults.operationStatus,
)

return self._results_message_to_execute_response(resp, final_operation_state)
execute_response, is_direct_results = self._results_message_to_execute_response(
resp, final_operation_state
)

return (
execute_response,
is_direct_results,
cursor.active_command_id.to_hex_guid(),
)

def _handle_execute_response_async(self, resp, cursor):
command_id = CommandId.from_thrift_handle(resp.operationHandle)
Expand All @@ -1257,6 +1286,7 @@ def fetch_results(
arrow_schema_bytes,
description,
use_cloud_fetch=True,
statement_id=None,
):
thrift_handle = command_id.to_thrift_handle()
if not thrift_handle:
Expand Down Expand Up @@ -1293,6 +1323,8 @@ def fetch_results(
lz4_compressed=lz4_compressed,
description=description,
ssl_options=self._ssl_options,
session_id_hex=self._session_id_hex,
statement_id=statement_id,
)

return queue, resp.hasMoreRows
Expand Down
25 changes: 17 additions & 8 deletions src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from concurrent.futures import ThreadPoolExecutor, Future
from typing import List, Union
from typing import List, Union, Optional, Tuple

from databricks.sql.cloudfetch.downloader import (
ResultSetDownloadHandler,
Expand All @@ -22,24 +22,28 @@ def __init__(
max_download_threads: int,
lz4_compressed: bool,
ssl_options: SSLOptions,
session_id_hex: Optional[str] = None,
statement_id: Optional[str] = None,
):
self._pending_links: List[TSparkArrowResultLink] = []
for link in links:
self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
for i, link in enumerate(links):
if link.rowCount <= 0:
continue
logger.debug(
"ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
link.startRowOffset, link.rowCount
"ResultFileDownloadManager: adding file link, chunk id {}, start offset {}, row count: {}".format(
i, link.startRowOffset, link.rowCount
)
)
self._pending_links.append(link)
self._pending_links.append((i, link))

self._download_tasks: List[Future[DownloadedFile]] = []
self._max_download_threads: int = max_download_threads
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self._ssl_options = ssl_options
self.session_id_hex = session_id_hex
self.statement_id = statement_id

def get_next_downloaded_file(
self, next_row_offset: int
Expand Down Expand Up @@ -89,14 +93,19 @@ def _schedule_downloads(self):
while (len(self._download_tasks) < self._max_download_threads) and (
len(self._pending_links) > 0
):
link = self._pending_links.pop(0)
chunk_id, link = self._pending_links.pop(0)
logger.debug(
"- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
"- chunk: {}, start: {}, row count: {}".format(
chunk_id, link.startRowOffset, link.rowCount
)
)
handler = ResultSetDownloadHandler(
settings=self._downloadable_result_settings,
link=link,
ssl_options=self._ssl_options,
chunk_id=chunk_id,
session_id_hex=self.session_id_hex,
statement_id=self.statement_id,
)
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)
Expand Down
13 changes: 11 additions & 2 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Optional

import requests
from requests.adapters import HTTPAdapter, Retry
Expand All @@ -9,6 +10,7 @@
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
from databricks.sql.exc import Error
from databricks.sql.types import SSLOptions
from databricks.sql.telemetry.latency_logger import log_latency

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,11 +68,18 @@ def __init__(
settings: DownloadableResultSettings,
link: TSparkArrowResultLink,
ssl_options: SSLOptions,
chunk_id: int,
session_id_hex: Optional[str] = None,
statement_id: Optional[str] = None,
):
self.settings = settings
self.link = link
self._ssl_options = ssl_options
self.chunk_id = chunk_id
self.session_id_hex = session_id_hex
self.statement_id = statement_id

@log_latency()
def run(self) -> DownloadedFile:
"""
Download the file described in the cloud fetch link.
Expand All @@ -80,8 +89,8 @@ def run(self) -> DownloadedFile:
"""

logger.debug(
"ResultSetDownloadHandler: starting file download, offset {}, row count {}".format(
self.link.startRowOffset, self.link.rowCount
"ResultSetDownloadHandler: starting file download, chunk id {}, offset {}, row count {}".format(
self.chunk_id, self.link.startRowOffset, self.link.rowCount
)
)

Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def __init__(
max_download_threads: int = 10,
ssl_options=None,
is_direct_results: bool = True,
session_id_hex: Optional[str] = None,
statement_id: Optional[str] = None,
):
"""
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
Expand Down Expand Up @@ -233,6 +235,8 @@ def __init__(
lz4_compressed=execute_response.lz4_compressed,
description=execute_response.description,
ssl_options=ssl_options,
session_id_hex=session_id_hex,
statement_id=statement_id,
)

# Call parent constructor with common attributes
Expand Down
Loading
Loading