Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit d759d81

Browse files
authored
feat(ibis): optimize Databricks connector and support service principal connection (#1373)
1 parent 180af86 commit d759d81

8 files changed

Lines changed: 177 additions & 21 deletions

File tree

ibis-server/app/model/__init__.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class QuerySnowflakeDTO(QueryDTO):
6868

6969

7070
class QueryDatabricksDTO(QueryDTO):
71-
connection_info: DatabricksConnectionInfo = connection_info_field
71+
connection_info: DatabricksConnectionUnion = connection_info_field
7272

7373

7474
class QueryTrinoDTO(QueryDTO):
@@ -400,7 +400,8 @@ class SnowflakeConnectionInfo(BaseConnectionInfo):
400400
)
401401

402402

403-
class DatabricksConnectionInfo(BaseConnectionInfo):
403+
class DatabricksTokenConnectionInfo(BaseConnectionInfo):
404+
databricks_type: Literal["token"] = "token"
404405
server_hostname: SecretStr = Field(
405406
alias="serverHostname",
406407
description="the server hostname of your Databricks instance",
@@ -418,6 +419,43 @@ class DatabricksConnectionInfo(BaseConnectionInfo):
418419
)
419420

420421

422+
# https://docs.databricks.com/aws/en/dev-tools/python-sql-connector#oauth-machine-to-machine-m2m-authentication
423+
class DatabricksServicePrincipalConnectionInfo(BaseConnectionInfo):
    # Connection settings for Databricks OAuth machine-to-machine (M2M)
    # authentication with a service principal.

    # Tag value identifying this connection flavour.
    databricks_type: Literal["service_principal"] = "service_principal"

    # Location of the Databricks workspace / SQL warehouse.
    server_hostname: SecretStr = Field(
        description="the server hostname of your Databricks instance",
        examples=["dbc-xxxxxxxx-xxxx.cloud.databricks.com"],
        alias="serverHostname",
    )
    http_path: SecretStr = Field(
        description="the HTTP path of your Databricks SQL warehouse",
        examples=["/sql/1.0/warehouses/xxxxxxxx"],
        alias="httpPath",
    )

    # OAuth client credentials of the service principal.
    client_id: SecretStr = Field(
        description="the client ID for OAuth M2M authentication",
        examples=["your-client-id"],
        alias="clientId",
    )
    client_secret: SecretStr = Field(
        description="the client secret for OAuth M2M authentication",
        examples=["your-client-secret"],
        alias="clientSecret",
    )

    # Optional; defaults to None when the caller does not supply it.
    azure_tenant_id: SecretStr | None = Field(
        default=None,
        description="the Azure tenant ID for OAuth M2M authentication",
        examples=["your-tenant-id"],
        alias="azureTenantId",
    )
451+
452+
453+
# Union of the supported Databricks connection models; the value of the
# `databricks_type` field selects which model is validated.
DatabricksConnectionUnion = Annotated[
    DatabricksTokenConnectionInfo | DatabricksServicePrincipalConnectionInfo,
    Field(discriminator="databricks_type"),
]
457+
458+
421459
class TrinoConnectionInfo(BaseConnectionInfo):
422460
host: SecretStr = Field(
423461
description="the hostname of your database", examples=["localhost"]
@@ -543,7 +581,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
543581
| RedshiftConnectionInfo
544582
| RedshiftIAMConnectionInfo
545583
| SnowflakeConnectionInfo
546-
| DatabricksConnectionInfo
584+
| DatabricksTokenConnectionInfo
547585
| TrinoConnectionInfo
548586
| LocalFileConnectionInfo
549587
| S3FileConnectionInfo

ibis-server/app/model/connector.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ class ClickHouseDbError(Exception):
2727
import pyarrow as pa
2828
import sqlglot.expressions as sge
2929
import trino
30+
from databricks import sql as dbsql
31+
from databricks.sdk.core import Config as DbConfig
32+
from databricks.sdk.core import oauth_service_principal
3033
from duckdb import HTTPException, IOException
3134
from google.cloud import bigquery
3235
from google.oauth2 import service_account
@@ -40,6 +43,9 @@ class ClickHouseDbError(Exception):
4043

4144
from app.model import (
4245
ConnectionInfo,
46+
DatabricksConnectionUnion,
47+
DatabricksServicePrincipalConnectionInfo,
48+
DatabricksTokenConnectionInfo,
4349
GcsFileConnectionInfo,
4450
MinioFileConnectionInfo,
4551
RedshiftConnectionInfo,
@@ -88,6 +94,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
8894
self._connector = RedshiftConnector(connection_info)
8995
elif data_source == DataSource.postgres:
9096
self._connector = PostgresConnector(connection_info)
97+
elif data_source == DataSource.databricks:
98+
self._connector = DatabricksConnector(connection_info)
9199
else:
92100
self._connector = SimpleConnector(data_source, connection_info)
93101

@@ -584,3 +592,54 @@ def close(self) -> None:
584592
self.connection.close()
585593
except Exception as e:
586594
logger.warning(f"Error closing Redshift connection: {e}")
595+
596+
597+
class DatabricksConnector(SimpleConnector):
    """Connector that runs SQL on Databricks via ``databricks.sql``.

    Two kinds of connection info are accepted: access-token based
    (``DatabricksTokenConnectionInfo``) and OAuth service-principal based
    (``DatabricksServicePrincipalConnectionInfo``).
    """

    def __init__(self, connection_info: DatabricksConnectionUnion):
        """Open a Databricks SQL connection for ``connection_info``.

        Args:
            connection_info: token or service-principal connection settings.

        Raises:
            TypeError: if ``connection_info`` is neither of the supported
                Databricks connection models.
        """
        # NOTE(review): SimpleConnector.__init__ is not called here, exactly
        # as in the original code; presumably deliberate because this class
        # manages its own `databricks.sql` connection -- confirm.
        if isinstance(connection_info, DatabricksTokenConnectionInfo):
            auth_kwargs = {
                "access_token": connection_info.access_token.get_secret_value(),
            }
        elif isinstance(connection_info, DatabricksServicePrincipalConnectionInfo):
            auth_kwargs = {
                "credentials_provider": self._service_principal_provider(
                    connection_info
                ),
            }
        else:
            # Fail fast: without this branch `self.connection` would never be
            # assigned and the first query()/dry_run()/close() call would die
            # with an unrelated-looking AttributeError. Only the type name is
            # reported so no secret values can leak into the message.
            raise TypeError(
                "Unsupported Databricks connection info: "
                f"{type(connection_info).__name__}"
            )

        self.connection = dbsql.connect(
            server_hostname=connection_info.server_hostname.get_secret_value(),
            http_path=connection_info.http_path.get_secret_value(),
            **auth_kwargs,
        )

    @staticmethod
    def _service_principal_provider(
        connection_info: DatabricksServicePrincipalConnectionInfo,
    ):
        """Return the zero-argument credentials provider for OAuth M2M auth."""
        config_kwargs = {
            "host": connection_info.server_hostname.get_secret_value(),
            "client_id": connection_info.client_id.get_secret_value(),
            "client_secret": connection_info.client_secret.get_secret_value(),
        }
        # The tenant id is only forwarded when the caller supplied one.
        if connection_info.azure_tenant_id is not None:
            config_kwargs["azure_tenant_id"] = (
                connection_info.azure_tenant_id.get_secret_value()
            )

        def credential_provider():
            return oauth_service_principal(DbConfig(**config_kwargs))

        return credential_provider

    def query(self, sql: str, limit: int | None = None):
        """Execute ``sql`` and return the result fetched in Arrow form.

        Args:
            sql: the statement to execute.
            limit: when not ``None``, fetch at most this many rows with
                ``fetchmany_arrow``; otherwise fetch everything with
                ``fetchall_arrow``.
        """
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(sql)
            if limit is None:
                return cursor.fetchall_arrow()
            return cursor.fetchmany_arrow(limit)

    def dry_run(self, sql: str) -> None:
        """Check ``sql`` by executing it wrapped in a ``LIMIT 0`` sub-select."""
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(f"SELECT * FROM ({sql}) AS sub LIMIT 0")

    def close(self) -> None:
        """Close the Databricks connection."""
        # Best-effort: a failure while closing is logged, never raised.
        try:
            self.connection.close()
        except Exception as e:
            logger.warning(f"Error closing Databricks connection: {e}")

ibis-server/app/model/data_source.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
ClickHouseConnectionInfo,
2222
ConnectionInfo,
2323
ConnectionUrl,
24-
DatabricksConnectionInfo,
24+
DatabricksServicePrincipalConnectionInfo,
25+
DatabricksTokenConnectionInfo,
2526
GcsFileConnectionInfo,
2627
LocalFileConnectionInfo,
2728
MinioFileConnectionInfo,
@@ -181,7 +182,12 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo:
181182
case DataSource.gcs_file:
182183
return GcsFileConnectionInfo.model_validate(data)
183184
case DataSource.databricks:
184-
return DatabricksConnectionInfo.model_validate(data)
185+
if (
186+
"databricks_type" in data
187+
and data["databricks_type"] == "service_principal"
188+
):
189+
return DatabricksServicePrincipalConnectionInfo.model_validate(data)
190+
return DatabricksTokenConnectionInfo.model_validate(data)
185191
case _:
186192
raise NotImplementedError(f"Unsupported data source: {self}")
187193

@@ -458,7 +464,7 @@ def get_trino_connection(info: TrinoConnectionInfo) -> BaseBackend:
458464
)
459465

460466
@staticmethod
461-
def get_databricks_connection(info: DatabricksConnectionInfo) -> BaseBackend:
467+
def get_databricks_connection(info: DatabricksTokenConnectionInfo) -> BaseBackend:
462468
return ibis.databricks.connect(
463469
server_hostname=info.server_hostname.get_secret_value(),
464470
http_path=info.http_path.get_secret_value(),

ibis-server/app/model/metadata/databricks.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from loguru import logger
22

3-
from app.model import DatabricksConnectionInfo
4-
from app.model.data_source import DataSource
3+
from app.model import DatabricksTokenConnectionInfo
4+
from app.model.connector import DatabricksConnector
55
from app.model.metadata.dto import (
66
Column,
77
Constraint,
@@ -33,9 +33,9 @@
3333

3434

3535
class DatabricksMetadata(Metadata):
36-
def __init__(self, connection_info: DatabricksConnectionInfo):
36+
def __init__(self, connection_info: DatabricksTokenConnectionInfo):
3737
super().__init__(connection_info)
38-
self.connection = DataSource.databricks.get_connection(connection_info)
38+
self.connection = DatabricksConnector(connection_info)
3939

4040
def get_table_list(self) -> list[Table]:
4141
sql = """
@@ -58,7 +58,7 @@ def get_table_list(self) -> list[Table]:
5858
WHERE
5959
c.TABLE_SCHEMA NOT IN ('information_schema')
6060
"""
61-
response = self.connection.sql(sql).to_pandas().to_dict(orient="records")
61+
response = self.connection.query(sql).to_pandas().to_dict(orient="records")
6262

6363
unique_tables = {}
6464
for row in response:
@@ -122,7 +122,7 @@ def get_constraints(self) -> list[Constraint]:
122122
AND ccu.constraint_schema = tc.constraint_schema
123123
WHERE tc.constraint_type = 'FOREIGN KEY'
124124
"""
125-
res = self.connection.sql(sql).to_pandas().to_dict(orient="records")
125+
res = self.connection.query(sql).to_pandas().to_dict(orient="records")
126126
constraints = []
127127
for row in res:
128128
constraints.append(
@@ -150,7 +150,7 @@ def get_constraints(self) -> list[Constraint]:
150150

151151
def get_version(self) -> str:
152152
return (
153-
self.connection.sql("SELECT current_version().dbsql_version")
153+
self.connection.query("SELECT current_version().dbsql_version")
154154
.to_pandas()
155155
.iloc[0, 0]
156156
)

ibis-server/poetry.lock

Lines changed: 23 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ibis-server/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ redshift_connector = "2.1.7"
5050
datafusion = "^47.0.0, <49.0.0"
5151
starlette = "^0.49.1"
5252
databricks-sql-connector = { version = "^4.0.1", extras = ["pyarrow"] }
53+
databricks-sdk = "^0.73.0"
5354

5455
[tool.poetry.group.jupyter]
5556
optional = true

ibis-server/tests/routers/v3/connector/databricks/conftest.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,19 @@ def init_databricks(connection_info):
5454
@pytest.fixture(scope="module")
5555
def connection_info() -> dict[str, str]:
5656
return {
57-
"serverHostname": os.getenv("DATABRICKS_SERVER_HOSTNAME"),
58-
"httpPath": os.getenv("DATABRICKS_HTTP_PATH"),
59-
"accessToken": os.getenv("DATABRICKS_TOKEN"),
57+
"databricks_type": "token",
58+
"serverHostname": os.getenv("TEST_DATABRICKS_SERVER_HOSTNAME"),
59+
"httpPath": os.getenv("TEST_DATABRICKS_HTTP_PATH"),
60+
"accessToken": os.getenv("TEST_DATABRICKS_TOKEN"),
61+
}
62+
63+
64+
@pytest.fixture(scope="module")
65+
def service_principal_connection_info() -> dict[str, str]:
66+
return {
67+
"databricks_type": "service_principal",
68+
"serverHostname": os.getenv("TEST_DATABRICKS_SERVER_HOSTNAME"),
69+
"httpPath": os.getenv("TEST_DATABRICKS_HTTP_PATH"),
70+
"clientId": os.getenv("TEST_DATABRICKS_CLIENT_ID"),
71+
"clientSecret": os.getenv("TEST_DATABRICKS_CLIENT_SECRET"),
6072
}

ibis-server/tests/routers/v3/connector/databricks/test_query.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
],
5050
},
5151
],
52+
"dataSource": "databricks",
5253
}
5354

5455

@@ -86,15 +87,32 @@ async def test_query(client, manifest_str, connection_info):
8687
"orderkey": "int64",
8788
"custkey": "int64",
8889
"orderstatus": "string",
89-
"totalprice": "decimal128(38, 9)",
90+
"totalprice": "decimal128(18, 2)",
9091
"orderdate": "date32[day]",
9192
"order_cust_key": "string",
92-
"timestamp": "timestamp[us, tz=UTC]",
93-
"timestamptz": "timestamp[us, tz=UTC]",
94-
"test_null_time": "timestamp[us, tz=UTC]",
93+
"timestamp": "timestamp[us, tz=Etc/UTC]",
94+
"timestamptz": "timestamp[us, tz=Etc/UTC]",
95+
"test_null_time": "timestamp[us, tz=Etc/UTC]",
9596
}
9697

9798

99+
async def test_query_with_service_principal(
    client, manifest_str, service_principal_connection_info
):
    """Query a single order row using service-principal connection info."""
    request_body = {
        "connectionInfo": service_principal_connection_info,
        "manifestStr": manifest_str,
        "sql": "SELECT * FROM wren.public.orders ORDER BY orderkey LIMIT 1",
    }
    response = await client.post(url=f"{base_url}/query", json=request_body)
    assert response.status_code == 200

    body = response.json()
    # One result column is expected per column of the first manifest model.
    expected_columns = manifest["models"][0]["columns"]
    assert len(body["columns"]) == len(expected_columns)
    assert len(body["data"]) == 1
115+
98116
async def test_query_with_limit(client, manifest_str, connection_info):
99117
response = await client.post(
100118
url=f"{base_url}/query",

0 commit comments

Comments
 (0)