@@ -27,6 +27,9 @@ class ClickHouseDbError(Exception):
2727import pyarrow as pa
2828import sqlglot .expressions as sge
2929import trino
30+ from databricks import sql as dbsql
31+ from databricks .sdk .core import Config as DbConfig
32+ from databricks .sdk .core import oauth_service_principal
3033from duckdb import HTTPException , IOException
3134from google .cloud import bigquery
3235from google .oauth2 import service_account
@@ -40,6 +43,9 @@ class ClickHouseDbError(Exception):
4043
4144from app .model import (
4245 ConnectionInfo ,
46+ DatabricksConnectionUnion ,
47+ DatabricksServicePrincipalConnectionInfo ,
48+ DatabricksTokenConnectionInfo ,
4349 GcsFileConnectionInfo ,
4450 MinioFileConnectionInfo ,
4551 RedshiftConnectionInfo ,
@@ -88,6 +94,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
8894 self ._connector = RedshiftConnector (connection_info )
8995 elif data_source == DataSource .postgres :
9096 self ._connector = PostgresConnector (connection_info )
97+ elif data_source == DataSource .databricks :
98+ self ._connector = DatabricksConnector (connection_info )
9199 else :
92100 self ._connector = SimpleConnector (data_source , connection_info )
93101
@@ -584,3 +592,54 @@ def close(self) -> None:
584592 self .connection .close ()
585593 except Exception as e :
586594 logger .warning (f"Error closing Redshift connection: { e } " )
595+
596+
class DatabricksConnector(SimpleConnector):
    """Connector that runs SQL on Databricks through ``databricks.sql``.

    Two credential styles are supported, selected by the concrete type of the
    connection info: a personal access token, or an OAuth service principal
    (client id / client secret, optionally with an Azure tenant id).
    """

    def __init__(self, connection_info: DatabricksConnectionUnion):
        """Open a Databricks SQL connection for ``connection_info``.

        Args:
            connection_info: Either a ``DatabricksTokenConnectionInfo`` or a
                ``DatabricksServicePrincipalConnectionInfo``.

        Raises:
            TypeError: If ``connection_info`` is neither supported variant.
        """
        # NOTE(review): SimpleConnector.__init__ is not called here; this class
        # only sets up its own ``self.connection``. Confirm that is intended.
        if not isinstance(
            connection_info,
            (DatabricksTokenConnectionInfo, DatabricksServicePrincipalConnectionInfo),
        ):
            # Fail at construction time instead of leaving ``self.connection``
            # undefined and surfacing an AttributeError on first use. Only the
            # type name is reported so that no secret value leaks into the error.
            raise TypeError(
                "Unsupported Databricks connection info type: "
                f"{type(connection_info).__name__}"
            )

        # Both variants need the same endpoint coordinates; unwrap them once.
        server_hostname = connection_info.server_hostname.get_secret_value()
        http_path = connection_info.http_path.get_secret_value()

        if isinstance(connection_info, DatabricksTokenConnectionInfo):
            self.connection = dbsql.connect(
                server_hostname=server_hostname,
                http_path=http_path,
                access_token=connection_info.access_token.get_secret_value(),
            )
            return

        # Service principal: build the SDK config used by the OAuth provider.
        kwargs = {
            "host": server_hostname,
            "client_id": connection_info.client_id.get_secret_value(),
            "client_secret": connection_info.client_secret.get_secret_value(),
        }
        if connection_info.azure_tenant_id is not None:
            kwargs["azure_tenant_id"] = (
                connection_info.azure_tenant_id.get_secret_value()
            )

        def credential_provider():
            # Handed to ``dbsql.connect`` as a callable, so the SDK config is
            # built when the connector asks for credentials.
            return oauth_service_principal(DbConfig(**kwargs))

        self.connection = dbsql.connect(
            server_hostname=server_hostname,
            http_path=http_path,
            credentials_provider=credential_provider,
        )

    def query(self, sql: str, limit=None):
        """Execute ``sql`` and return the result from the cursor's Arrow fetch.

        Args:
            sql: SQL text passed unchanged to ``cursor.execute``.
            limit: When not ``None``, passed to ``cursor.fetchmany_arrow``;
                otherwise ``cursor.fetchall_arrow`` is used.

        Returns:
            Whatever the cursor's ``fetchmany_arrow`` / ``fetchall_arrow``
            returns (presumably a ``pyarrow.Table`` - verify against the
            connector documentation).
        """
        # ``closing`` guarantees the cursor is released even if execute/fetch raises.
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(sql)

            if limit is not None:
                return cursor.fetchmany_arrow(limit)
            return cursor.fetchall_arrow()

    def dry_run(self, sql: str) -> None:
        """Validate ``sql`` by executing it wrapped in a ``LIMIT 0`` subquery.

        Any error raised by ``cursor.execute`` propagates to the caller.
        """
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(f"SELECT * FROM ({sql}) AS sub LIMIT 0")

    def close(self) -> None:
        """Close the Databricks connection."""
        try:
            self.connection.close()
        except Exception as e:
            # Best-effort shutdown: a failed close is logged, never raised.
            logger.warning(f"Error closing Databricks connection: {e}")
0 commit comments