Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/databricks/sql/backend/kernel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ def __init__(
self._auth_provider = auth_provider
self._catalog = catalog
self._schema = schema
# ``_use_arrow_native_complex_types`` is the connector-side
# toggle for whether complex columns (ARRAY / MAP / STRUCT)
# are surfaced as native Arrow shapes or as compact JSON
# strings. The Thrift backend forwards it server-side
# (``complexTypesAsArrow``); the kernel doesn't have a wire
# equivalent, so we flip the kernel's client-side
# ``complex_types_as_json`` post-processor to match. Default
# ``True`` mirrors the connector's existing default.
self._use_arrow_native_complex_types = kwargs.get(
"_use_arrow_native_complex_types", True
)
# NB: don't call ``kernel_auth_kwargs`` here. That call
# materialises the bearer token in-process; keeping a
# cleartext copy on a long-lived connector object that may
Expand Down Expand Up @@ -155,6 +166,7 @@ def open_session(
catalog=catalog or self._catalog,
schema=schema or self._schema,
session_conf=session_conf,
complex_types_as_json=not self._use_arrow_native_complex_types,
**auth_kwargs,
)
except Exception as exc:
Expand Down
30 changes: 25 additions & 5 deletions src/databricks/sql/backend/kernel/type_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,35 @@ def _databricks_type_for_field(field: pyarrow.Field) -> str:
Consults the field's Arrow metadata under
``databricks.type_name`` (written by the kernel from the SEA
response's column type) so types that collapse onto a generic
Arrow shape can still be distinguished. Today only ``VARIANT``
is mapped; everything else delegates to
``_arrow_type_to_dbapi_string``.
Arrow shape can still be distinguished. This matters in two
cases:

- ``VARIANT`` (always ``Utf8`` on the wire — no Arrow shape
distinguishes it from ``STRING``).
- The ``complex_types_as_json`` post-processor rewrites
``ARRAY`` / ``MAP`` / ``STRUCT`` columns to ``Utf8`` carrying
compact JSON text. The Thrift backend reports the original
SQL type in ``description`` even when ``complexTypesAsArrow``
is off and the wire payload is a JSON string; we match that
by recovering the type name from manifest metadata.
"""
md = field.metadata or {}
# `databricks.type_name` is bytes (Arrow metadata is always
# bytes); compare against bytes to avoid one encode per field.
if md.get(b"databricks.type_name") == b"VARIANT":
return "variant"
type_name = md.get(b"databricks.type_name")
if type_name is not None:
# Lowercase to match the canonical SqlType slugs the Thrift
# backend produces (``"array"`` / ``"map"`` / ``"struct"`` /
# ``"variant"``). Other server-reported names (``"INT"`` etc.)
# would also pass through this branch but we deliberately
# don't honour them — the Arrow shape is the authoritative
# source for primitives, and the kernel's own type-name
# mapping (`map_databricks_type`) is conservative on some
# types (e.g. ``DECIMAL`` arrives as ``decimal`` on the
# Arrow side, which matches Thrift).
decoded = type_name.decode("ascii", errors="replace").lower()
if decoded in {"variant", "array", "map", "struct"}:
return decoded
return _arrow_type_to_dbapi_string(field.type)


Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _create_backend(
http_client=self.http_client,
catalog=kwargs.get("catalog"),
schema=kwargs.get("schema"),
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
)

databricks_client_class: Type[DatabricksClient]
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_kernel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,44 @@ def test_open_session_rejects_double_open(monkeypatch):
c.open_session(session_configuration=None, catalog=None, schema=None)


@pytest.mark.parametrize(
"kwargs, expected_flag",
[
({}, False), # default → arrow-native → kernel JSON off
({"_use_arrow_native_complex_types": True}, False),
({"_use_arrow_native_complex_types": False}, True),
],
)
def test_open_session_passes_complex_types_as_json_to_kernel(
monkeypatch, kwargs, expected_flag
):
"""``_use_arrow_native_complex_types=False`` flips the kernel's
``complex_types_as_json`` post-processor on; the default and
explicit ``True`` both leave it off. The flag is inverted at the
boundary because the connector's option is "native Arrow"-shaped
and the kernel's is "rewrite to JSON strings"-shaped."""
captured = {}

def fake_session(**kw):
captured.update(kw)
sess = MagicMock()
sess.session_id = "sess-id"
return sess

monkeypatch.setattr(kernel_client._kernel, "Session", fake_session)

c = kernel_client.KernelDatabricksClient(
server_hostname="example.cloud.databricks.com",
http_path="/sql/1.0/warehouses/abc",
auth_provider=AccessTokenAuthProvider("dapi-test"),
ssl_options=None,
**kwargs,
)
c.open_session(session_configuration=None, catalog=None, schema=None)

assert captured.get("complex_types_as_json") is expected_flag


def test_execute_command_forwards_parameters_to_bind_param():
"""``execute_command(parameters=[...])`` routes each parameter
through ``bind_tspark_params`` onto the kernel statement before
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/test_kernel_type_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,58 @@ def test_description_uses_databricks_type_name_for_variant():
assert desc[1][1] == "string"


@pytest.mark.parametrize(
"metadata_value, expected",
[
(b"ARRAY", "array"),
(b"MAP", "map"),
(b"STRUCT", "struct"),
# Lowercase / mixed case both fine — server may report either.
(b"array", "array"),
(b"Struct", "struct"),
],
)
def test_description_recovers_complex_type_name_from_metadata(metadata_value, expected):
"""When ``complex_types_as_json`` rewrites a complex column to
``Utf8``, the kernel preserves the original SQL type name under
``databricks.type_name``. ``description`` must report that name
(matching the Thrift backend's behaviour with
``complexTypesAsArrow=False``), not the post-processed ``string``.
"""
schema = pa.schema(
[
pa.field(
"c",
pa.string(),
metadata={b"databricks.type_name": metadata_value},
),
]
)
desc = description_from_arrow_schema(schema)
assert desc[0][1] == expected


def test_description_passes_through_unknown_databricks_type_name():
"""Server-reported names other than the handful we explicitly
recognise (VARIANT / ARRAY / MAP / STRUCT) defer to the Arrow
shape — the Arrow type is the authoritative source for primitives
and the kernel's own type mapping is conservative there. Confirms
we don't accidentally claim ``int`` from metadata when the Arrow
column is something concrete like ``int64``."""
schema = pa.schema(
[
pa.field(
"n",
pa.int64(),
metadata={b"databricks.type_name": b"INT"},
),
]
)
desc = description_from_arrow_schema(schema)
# `int64` Arrow → "bigint" via the existing arrow-type mapper.
assert desc[0][1] == "bigint"


# ─── bind_tspark_params ──────────────────────────────────────────────────


Expand Down
Loading