Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _convert_json_types(self, row: List[str]) -> List[Any]:
column_name=column_name,
precision=precision,
scale=scale,
timestamp_format=self.connection.non_arrow_timestamp_format,
)
converted_row.append(converted_value)

Expand Down
5 changes: 5 additions & 0 deletions src/databricks/sql/backend/sea/utils/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from dateutil import parser
from typing import Callable, Dict, Optional

from databricks.sql.utils import parse_timestamp

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -162,6 +164,9 @@ def convert_value(
precision = kwargs.get("precision", None)
scale = kwargs.get("scale", None)
return converter_func(value, precision, scale)
elif sql_type == SqlType.TIMESTAMP:
timestamp_format = kwargs.get("timestamp_format", None)
return parse_timestamp(value, timestamp_format)
else:
return converter_func(value)
Comment on lines +167 to 171
except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ def fetch_results(
description,
chunk_id: int,
use_cloud_fetch=True,
timestamp_format=None,
):
thrift_handle = command_id.to_thrift_handle()
if not thrift_handle:
Expand Down Expand Up @@ -1336,6 +1337,7 @@ def fetch_results(
statement_id=command_id.to_hex_guid(),
chunk_id=chunk_id,
http_client=self._http_client,
timestamp_format=timestamp_format,
)

return (
Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def read(self) -> Optional[OAuthToken]:
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self.non_arrow_timestamp_format = kwargs.get("non_arrow_timestamp_format", None)
self._cursors = [] # type: List[Cursor]
Comment on lines 295 to 299
self.telemetry_batch_size = kwargs.get(
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sql/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
statement_id=execute_response.command_id.to_hex_guid(),
chunk_id=self.num_chunks,
http_client=connection.http_client,
timestamp_format=connection.non_arrow_timestamp_format,
)
if t_row_set.resultLinks:
self.num_chunks += len(t_row_set.resultLinks)
Expand Down Expand Up @@ -281,6 +282,7 @@ def _fill_results_buffer(self):
description=self.description,
use_cloud_fetch=self._use_cloud_fetch,
chunk_id=self.num_chunks,
timestamp_format=self.connection.non_arrow_timestamp_format,
)
self.results = results
self.has_more_rows = has_more_rows
Expand Down
40 changes: 37 additions & 3 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from decimal import Decimal
from enum import Enum
import re
import pytz

import lz4.frame

Expand Down Expand Up @@ -53,6 +54,32 @@ def get_session_config_value(
return None


def parse_timestamp(
value: str, timestamp_format: Optional[str] = None
) -> datetime.datetime:
"""Parse a timestamp string into a datetime object.

If timestamp_format is provided, tries strptime first and falls back to
dateutil.parser.parse on ValueError. If timestamp_format is None, uses
dateutil.parser.parse directly.

Args:
value: The timestamp string to parse.
timestamp_format: An optional strptime-compatible format string.

Returns:
A datetime.datetime object.
"""
if timestamp_format is not None:
try:
return datetime.datetime.strptime(value, timestamp_format).replace(
tzinfo=pytz.UTC
)
Comment on lines +75 to +77
except ValueError:
return parser.parse(value)
return parser.parse(value)


class ResultSetQueue(ABC):
@abstractmethod
def next_n_rows(self, num_rows: int):
Expand Down Expand Up @@ -81,6 +108,7 @@ def build_queue(
http_client,
lz4_compressed: bool = True,
description: List[Tuple] = [],
timestamp_format: Optional[str] = None,
) -> ResultSetQueue:
"""
Factory method to build a result set queue for Thrift backend.
Expand All @@ -93,6 +121,7 @@ def build_queue(
description (List[List[Any]]): Hive table schema description.
max_download_threads (int): Maximum number of downloader thread pool threads.
ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
timestamp_format: Optional strptime-compatible format for timestamp parsing.

Returns:
ResultSetQueue
Expand All @@ -112,7 +141,7 @@ def build_queue(
)

converted_column_table = convert_to_assigned_datatypes_in_column_table(
column_table, description
column_table, description, timestamp_format=timestamp_format
)

return ColumnQueue(ColumnTable(converted_column_table, column_names))
Expand Down Expand Up @@ -760,7 +789,9 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
return pyarrow.Table.from_arrays(new_columns, schema=new_schema)


def convert_to_assigned_datatypes_in_column_table(column_table, description):
def convert_to_assigned_datatypes_in_column_table(
column_table, description, timestamp_format=None
):

converted_column_table = []
for i, col in enumerate(column_table):
Expand All @@ -774,7 +805,10 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description):
)
elif description[i][1] == "timestamp":
converted_column_table.append(
tuple((v if v is None else parser.parse(v)) for v in col)
tuple(
(v if v is None else parse_timestamp(v, timestamp_format))
for v in col
)
)
else:
converted_column_table.append(col)
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/test_sea_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,37 @@ def test_convert_unsupported_type(self):
SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT, None)
== "complex_value"
)

def test_convert_timestamp_with_format(self):
"""Test converting timestamp with an explicit strptime format."""
fmt = "%Y-%m-%d %H:%M:%S.%f"
result = SqlTypeConverter.convert_value(
"2023-12-31 12:30:00.123000",
SqlType.TIMESTAMP,
None,
timestamp_format=fmt,
)
assert isinstance(result, datetime.datetime)
assert result == datetime.datetime(2023, 12, 31, 12, 30, 0, 123000)

def test_convert_timestamp_with_format_fallback(self):
"""Test that non-matching format falls back to dateutil."""
fmt = "%Y-%m-%d %H:%M:%S.%f"
result = SqlTypeConverter.convert_value(
"08-Mar-2024 14:30:15",
SqlType.TIMESTAMP,
None,
timestamp_format=fmt,
)
assert isinstance(result, datetime.datetime)
assert result == datetime.datetime(2024, 3, 8, 14, 30, 15)

def test_convert_timestamp_without_format(self):
"""Test converting timestamp without explicit format uses dateutil."""
result = SqlTypeConverter.convert_value(
"2023-01-15T12:30:45",
SqlType.TIMESTAMP,
None,
)
assert isinstance(result, datetime.datetime)
assert result == datetime.datetime(2023, 1, 15, 12, 30, 45)
39 changes: 39 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from databricks.sql.utils import (
convert_to_assigned_datatypes_in_column_table,
parse_timestamp,
ColumnTable,
concat_table_chunks,
serialize_query_tags,
Expand Down Expand Up @@ -224,3 +225,41 @@ def test_serialize_query_tags_all_none_values(self):
query_tags = {"key1": None, "key2": None, "key3": None}
result = serialize_query_tags(query_tags)
assert result == "key1,key2,key3"


class TestParseTimestamp:
def test_no_format_uses_dateutil(self):
result = parse_timestamp("2023-12-31 12:30:00")
assert result == datetime.datetime(2023, 12, 31, 12, 30, 0)

def test_matching_format_uses_strptime(self):
fmt = "%Y-%m-%d %H:%M:%S.%f"
result = parse_timestamp("2023-12-31 12:30:00.123000", fmt)
assert result == datetime.datetime(2023, 12, 31, 12, 30, 0, 123000)

def test_non_matching_format_falls_back_to_dateutil(self):
fmt = "%Y-%m-%d %H:%M:%S.%f"
# This doesn't match the format, so should fall back to dateutil
result = parse_timestamp("08-Mar-2024 14:30:15", fmt)
assert result == datetime.datetime(2024, 3, 8, 14, 30, 15)

def test_convert_column_table_with_timestamp_format(self):
description = [
("ts_col", "timestamp", None, None, None, None, None),
]
column_table = [("2023-12-31 12:30:00.000000",)]
fmt = "%Y-%m-%d %H:%M:%S.%f"
result = convert_to_assigned_datatypes_in_column_table(
column_table, description, timestamp_format=fmt
)
assert result[0][0] == datetime.datetime(2023, 12, 31, 12, 30, 0)

def test_convert_column_table_without_timestamp_format(self):
description = [
("ts_col", "timestamp", None, None, None, None, None),
]
column_table = [("2023-12-31 12:30:00",)]
result = convert_to_assigned_datatypes_in_column_table(
column_table, description
)
assert result[0][0] == datetime.datetime(2023, 12, 31, 12, 30, 0)
Loading