Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import time
import threading
from typing import List, Optional, Union, Any, TYPE_CHECKING
from uuid import UUID

from databricks.sql.result_set import ThriftResultSet

from databricks.sql.telemetry.models.event import StatementType

if TYPE_CHECKING:
from databricks.sql.client import Cursor
Expand Down Expand Up @@ -887,6 +888,7 @@ def get_execution_result(
arrow_schema_bytes=schema_bytes,
result_format=t_result_set_metadata_resp.resultFormat,
)
execute_response.command_id.set_statement_type(StatementType.QUERY)

return ThriftResultSet(
connection=cursor.connection,
Expand All @@ -899,6 +901,7 @@ def get_execution_result(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
)

def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
Expand Down Expand Up @@ -1025,6 +1028,8 @@ def execute_command(
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

execute_response.command_id.set_statement_type(StatementType.QUERY)

return ThriftResultSet(
connection=cursor.connection,
execute_response=execute_response,
Expand All @@ -1036,6 +1041,7 @@ def execute_command(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
)

def get_catalogs(
Expand Down Expand Up @@ -1065,6 +1071,8 @@ def get_catalogs(
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

execute_response.command_id.set_statement_type(StatementType.METADATA)

return ThriftResultSet(
connection=cursor.connection,
execute_response=execute_response,
Expand All @@ -1076,6 +1084,7 @@ def get_catalogs(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
)

def get_schemas(
Expand Down Expand Up @@ -1111,6 +1120,8 @@ def get_schemas(
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

execute_response.command_id.set_statement_type(StatementType.METADATA)

return ThriftResultSet(
connection=cursor.connection,
execute_response=execute_response,
Expand All @@ -1122,6 +1133,7 @@ def get_schemas(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
)

def get_tables(
Expand Down Expand Up @@ -1161,6 +1173,8 @@ def get_tables(
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

execute_response.command_id.set_statement_type(StatementType.METADATA)

return ThriftResultSet(
connection=cursor.connection,
execute_response=execute_response,
Expand All @@ -1172,6 +1186,7 @@ def get_tables(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
)

def get_columns(
Expand Down Expand Up @@ -1211,6 +1226,8 @@ def get_columns(
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

execute_response.command_id.set_statement_type(StatementType.METADATA)

return ThriftResultSet(
connection=cursor.connection,
execute_response=execute_response,
Expand All @@ -1222,6 +1239,7 @@ def get_columns(
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
session_id_hex=self._session_id_hex,
)

def _handle_execute_response(self, resp, cursor):
Expand Down Expand Up @@ -1256,6 +1274,7 @@ def fetch_results(
lz4_compressed: bool,
arrow_schema_bytes,
description,
chunk_id: int,
use_cloud_fetch=True,
):
thrift_handle = command_id.to_thrift_handle()
Expand Down Expand Up @@ -1293,9 +1312,17 @@ def fetch_results(
lz4_compressed=lz4_compressed,
description=description,
ssl_options=self._ssl_options,
session_id_hex=self._session_id_hex,
statement_id=command_id.to_hex_guid(),
statement_type=command_id.statement_type,
chunk_id=chunk_id,
)

return queue, resp.hasMoreRows
return (
queue,
resp.hasMoreRows,
len(resp.results.resultLinks) if resp.results.resultLinks else 0,
)

def cancel_command(self, command_id: CommandId) -> None:
thrift_handle = command_id.to_thrift_handle()
Expand Down
15 changes: 15 additions & 0 deletions src/databricks/sql/backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from databricks.sql.backend.utils.guid_utils import guid_to_hex_id
from databricks.sql.telemetry.models.enums import StatementType
from databricks.sql.thrift_api.TCLIService import ttypes

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -300,6 +301,7 @@ def __init__(
self.operation_type = operation_type
self.has_result_set = has_result_set
self.modified_row_count = modified_row_count
self._statement_type = StatementType.NONE

def __str__(self) -> str:
"""
Expand Down Expand Up @@ -411,6 +413,19 @@ def to_hex_guid(self) -> str:
else:
return str(self.guid)

def set_statement_type(self, statement_type: StatementType):
"""
Set the statement type for this command.
"""
self._statement_type = statement_type

@property
def statement_type(self) -> StatementType:
"""
Get the statement type for this command.
"""
return self._statement_type


@dataclass
class ExecuteResponse:
Expand Down
31 changes: 21 additions & 10 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,9 @@ def read(self) -> Optional[OAuthToken]:

driver_connection_params = DriverConnectionParameters(
http_path=http_path,
mode=DatabricksClientType.THRIFT,
mode=DatabricksClientType.SEA
if self.session.use_sea
else DatabricksClientType.THRIFT,
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
Expand Down Expand Up @@ -706,7 +708,7 @@ def _handle_staging_operation(
session_id_hex=self.connection.get_session_id_hex(),
)

@log_latency(StatementType.SQL)
@log_latency()
def _handle_staging_put(
self, presigned_url: str, local_file: str, headers: Optional[dict] = None
):
Expand All @@ -715,6 +717,7 @@ def _handle_staging_put(
Raise an exception if request fails. Returns no data.
"""

self.statement_type = StatementType.SQL
if local_file is None:
raise ProgrammingError(
"Cannot perform PUT without specifying a local_file",
Expand Down Expand Up @@ -746,7 +749,7 @@ def _handle_staging_put(
+ "but not yet applied on the server. It's possible this command may fail later."
)

@log_latency(StatementType.SQL)
@log_latency()
def _handle_staging_get(
self, local_file: str, presigned_url: str, headers: Optional[dict] = None
):
Expand All @@ -755,6 +758,7 @@ def _handle_staging_get(
Raise an exception if request fails. Returns no data.
"""

self.statement_type = StatementType.SQL
if local_file is None:
raise ProgrammingError(
"Cannot perform GET without specifying a local_file",
Expand All @@ -774,12 +778,13 @@ def _handle_staging_get(
with open(local_file, "wb") as fp:
fp.write(r.content)

@log_latency(StatementType.SQL)
@log_latency()
def _handle_staging_remove(
self, presigned_url: str, headers: Optional[dict] = None
):
"""Make an HTTP DELETE request to the presigned_url"""

self.statement_type = StatementType.SQL
r = requests.delete(url=presigned_url, headers=headers)

if not r.ok:
Expand All @@ -788,7 +793,7 @@ def _handle_staging_remove(
session_id_hex=self.connection.get_session_id_hex(),
)

@log_latency(StatementType.QUERY)
@log_latency()
def execute(
self,
operation: str,
Expand Down Expand Up @@ -827,6 +832,7 @@ def execute(
:returns self
"""

self.statement_type = StatementType.QUERY
logger.debug(
"Cursor.execute(operation=%s, parameters=%s)", operation, parameters
)
Expand Down Expand Up @@ -873,7 +879,7 @@ def execute(

return self

@log_latency(StatementType.QUERY)
@log_latency()
def execute_async(
self,
operation: str,
Expand All @@ -889,6 +895,7 @@ def execute_async(
:return:
"""

self.statement_type = StatementType.QUERY
param_approach = self._determine_parameter_approach(parameters)
if param_approach == ParameterApproach.NONE:
prepared_params = NO_NATIVE_PARAMS
Expand Down Expand Up @@ -992,13 +999,14 @@ def executemany(self, operation, seq_of_parameters):
self.execute(operation, parameters)
return self

@log_latency(StatementType.METADATA)
@log_latency()
def catalogs(self) -> "Cursor":
"""
Get all available catalogs.

:returns self
"""
self.statement_type = StatementType.METADATA
self._check_not_closed()
self._close_and_clear_active_result_set()
self.active_result_set = self.backend.get_catalogs(
Expand All @@ -1009,7 +1017,7 @@ def catalogs(self) -> "Cursor":
)
return self

@log_latency(StatementType.METADATA)
@log_latency()
def schemas(
self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
) -> "Cursor":
Expand All @@ -1019,6 +1027,7 @@ def schemas(
Names can contain % wildcards.
:returns self
"""
self.statement_type = StatementType.METADATA
self._check_not_closed()
self._close_and_clear_active_result_set()
self.active_result_set = self.backend.get_schemas(
Expand All @@ -1031,7 +1040,7 @@ def schemas(
)
return self

@log_latency(StatementType.METADATA)
@log_latency()
def tables(
self,
catalog_name: Optional[str] = None,
Expand All @@ -1045,6 +1054,7 @@ def tables(
Names can contain % wildcards.
:returns self
"""
self.statement_type = StatementType.METADATA
self._check_not_closed()
self._close_and_clear_active_result_set()

Expand All @@ -1060,7 +1070,7 @@ def tables(
)
return self

@log_latency(StatementType.METADATA)
@log_latency()
def columns(
self,
catalog_name: Optional[str] = None,
Expand All @@ -1074,6 +1084,7 @@ def columns(
Names can contain % wildcards.
:returns self
"""
self.statement_type = StatementType.METADATA
self._check_not_closed()
self._close_and_clear_active_result_set()

Expand Down
Loading
Loading