Commit 6bb13c5

Add service principal authentication support
Signed-off-by: David <dr00b@users.noreply.github.com>
1 parent 6b80531 commit 6bb13c5

10 files changed

Lines changed: 630 additions & 6 deletions

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -158,6 +158,8 @@ celerybeat.pid
 
 # Environments
 .env
+.env.local
+test.env
 .venv
 env/
 venv/
@@ -206,4 +208,4 @@ poetry.toml
 # LSP config files
 pyrightconfig.json
 
-# End of https://www.toptal.com/developers/gitignore/api/python,macos
+# End of https://www.toptal.com/developers/gitignore/api/python,macos

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
 # Release History
 
+# Unreleased
+
+- Feature: Added first-class Databricks service principal (machine-to-machine OAuth) authentication support for SQLAlchemy connections, working across AWS, Azure, and GCP workspaces (databricks/databricks-sqlalchemy#29)
+
 # 2.0.8 (2025-09-08)
 
 - Feature: Added support for variant datatype (databricks/databricks-sqlalchemy#42 by @msrathore-db)
@@ -19,4 +23,4 @@
 # 2.0.4 (2025-01-27)
 
 - All the SQLAlchemy features from `databricks-sql-connector>=4.0.0` have been moved to this `databricks-sqlalchemy` library
-- Support for SQLAlchemy v2 dialect is provided
+- Support for SQLAlchemy v2 dialect is provided

README.md

Lines changed: 28 additions & 1 deletion
@@ -23,7 +23,7 @@ Every SQLAlchemy application that connects to a database needs to use an [Engine
 
 1. Host
 2. HTTP Path for a compute resource
-3. API access token
+3. API access token, or service principal connection parameters
 4. Initial catalog for the connection
 5. Initial schema for the connection
 
@@ -46,6 +46,33 @@ engine = create_engine(
 )
 ```
 
+### Service principal authentication
+
+Workspaces that prohibit Personal Access Tokens can now use Databricks service principals (see the [Databricks documentation](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m) for how to create one). Supply the service principal credentials directly in the Databricks SQLAlchemy URL and set `authentication=service_principal`.
+
+```python
+import os
+from sqlalchemy import create_engine
+
+client_id = os.getenv("DATABRICKS_SP_CLIENT_ID")
+client_secret = os.getenv("DATABRICKS_SP_CLIENT_SECRET")
+host = os.getenv("DATABRICKS_SERVER_HOSTNAME")
+http_path = os.getenv("DATABRICKS_HTTP_PATH")
+catalog = os.getenv("DATABRICKS_CATALOG")
+schema = os.getenv("DATABRICKS_SCHEMA")
+
+engine = create_engine(
+    "databricks://"
+    f"{client_id}:{client_secret}"
+    f"@{host}?http_path={http_path}&catalog={catalog}&schema={schema}"
+    "&authentication=service_principal"
+)
+```
+
+`client_id` and `client_secret` are read from the username and password components of the URL. If you prefer to keep the username as `token`, you can pass them in the query string via `client_id` and `client_secret`. By default the dialect requests the Databricks `sql` OAuth scope from the workspace's `/oidc` endpoint. You can override the scopes by providing `sp_scopes` (comma separated) in the query string if you have custom scopes configured.
+
+For local development, copy `test.env.example` to `test.env`, populate it with your workspace and service principal values, and keep `test.env` untracked (it's listed in `.gitignore`). `pytest` automatically reads this file because it is referenced in `pyproject.toml` via `env_files`.
+
 ## Types
 
 The [SQLAlchemy type hierarchy](https://docs.sqlalchemy.org/en/20/core/type_basics.html) contains backend-agnostic type implementations (represented in CamelCase) and backend-specific types (represented in UPPERCASE). The majority of SQLAlchemy's [CamelCase](https://docs.sqlalchemy.org/en/20/core/type_basics.html#the-camelcase-datatypes) types are supported. This means that a SQLAlchemy application using these types should "just work" with Databricks.
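
Reviewer note on the README example in this hunk: it interpolates the client secret into the URL with an f-string, so a secret containing `@`, `/`, or `:` could change how the URL is parsed. Percent-encoding the credentials first may be safer. The sketch below uses only the standard library; `build_service_principal_url` is a hypothetical helper, not part of `databricks-sqlalchemy`, and the query keys are the ones shown in the README text.

```python
from urllib.parse import quote, urlencode


def build_service_principal_url(
    host: str,
    http_path: str,
    client_id: str,
    client_secret: str,
    catalog: str,
    schema: str,
) -> str:
    """Assemble a databricks:// URL with percent-encoded credentials.

    Hypothetical helper for illustration; not part of databricks-sqlalchemy.
    """
    # Encode every reserved character so the secret cannot be mistaken
    # for a URL delimiter such as "@" or ":".
    user = quote(client_id, safe="")
    password = quote(client_secret, safe="")
    query = urlencode(
        {
            "http_path": http_path,
            "catalog": catalog,
            "schema": schema,
            "authentication": "service_principal",
        },
        safe="/",  # keep the HTTP path readable
    )
    return f"databricks://{user}:{password}@{host}?{query}"
```

The returned string can be passed to `sqlalchemy.create_engine`. SQLAlchemy's own `URL.create` constructor is another way to avoid manual escaping.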

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ python = "^3.8.0"
 databricks_sql_connector = { version = ">=4.0.0"}
 pyarrow = { version = ">=14.0.1"}
 sqlalchemy = { version = ">=2.0.21" }
+requests = { version = ">=2.31.0,<3.0.0" }
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.1.2"

src/databricks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)

src/databricks/sqlalchemy/_service_principal.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+import threading
+import time
+from typing import Callable, Dict, Iterable, List, Optional
+
+import requests
+from databricks.sql.auth.authenticators import CredentialsProvider
+from databricks.sql.auth.endpoint import get_oauth_endpoints
+
+
+class ServicePrincipalConfigurationError(ValueError):
+    """Raised when the service principal configuration is incomplete."""
+
+
+class ServicePrincipalAuthenticationError(RuntimeError):
+    """Raised when fetching an OAuth token fails."""
+
+
+def _normalize_hostname(hostname: str) -> str:
+    maybe_scheme = "" if hostname.startswith("https://") else "https://"
+    trimmed = (
+        hostname[len("https://") :] if hostname.startswith("https://") else hostname
+    )
+    return f"{maybe_scheme}{trimmed}".rstrip("/")
+
+
+class ServicePrincipalCredentialsProvider(CredentialsProvider):
+    """CredentialsProvider that performs the Databricks OAuth client credentials flow."""
+
+    DEFAULT_SCOPES = ("sql",)
+
+    def __init__(
+        self,
+        server_hostname: str,
+        client_id: str,
+        client_secret: str,
+        *,
+        scopes: Optional[Iterable[str]] = None,
+        refresh_margin: int = 60,
+        request_timeout: int = 10,
+    ):
+        if not server_hostname:
+            raise ServicePrincipalConfigurationError("server_hostname is required")
+        if not client_id:
+            raise ServicePrincipalConfigurationError("client_id is required")
+        if not client_secret:
+            raise ServicePrincipalConfigurationError("client_secret is required")
+
+        self._hostname = _normalize_hostname(server_hostname)
+        oauth_endpoints = get_oauth_endpoints(self._hostname, use_azure_auth=False)
+        if not oauth_endpoints:
+            raise ServicePrincipalConfigurationError(
+                f"Unable to determine OAuth endpoints for host {server_hostname}"
+            )
+
+        scope_tuple = tuple(scopes) if scopes else self.DEFAULT_SCOPES
+        mapped_scopes = oauth_endpoints.get_scopes_mapping(list(scope_tuple))
+
+        self._client_id = client_id
+        self._client_secret = client_secret
+        self._scopes: List[str] = mapped_scopes
+        self._refresh_margin = refresh_margin
+        self._request_timeout = request_timeout
+        self._access_token: Optional[str] = None
+        self._expires_at: float = 0
+        self._lock = threading.Lock()
+        self._token_endpoint = self._discover_token_endpoint(oauth_endpoints)
+
+    def auth_type(self) -> str:
+        return "databricks-service-principal"
+
+    def __call__(self) -> Callable[[], Dict[str, str]]:
+        def header_factory() -> Dict[str, str]:
+            access_token = self._get_token()
+            return {"Authorization": f"Bearer {access_token}"}
+
+        return header_factory
+
+    def _discover_token_endpoint(self, oauth_endpoints) -> str:
+        openid_config_url = oauth_endpoints.get_openid_config_url(self._hostname)
+        try:
+            response = requests.get(openid_config_url, timeout=self._request_timeout)
+            response.raise_for_status()
+            config = response.json()
+        except Exception as exc:
+            raise ServicePrincipalAuthenticationError(
+                "Failed to load Databricks OAuth configuration"
+            ) from exc
+
+        token_endpoint = config.get("token_endpoint")
+        if not token_endpoint:
+            raise ServicePrincipalAuthenticationError(
+                "OAuth configuration did not include a token endpoint"
+            )
+        return token_endpoint
+
+    def _needs_refresh(self) -> bool:
+        if not self._access_token:
+            return True
+        now = time.time()
+        return now >= (self._expires_at - self._refresh_margin)
+
+    def _get_token(self) -> str:
+        with self._lock:
+            if self._needs_refresh():
+                self._refresh_token()
+            assert self._access_token
+            return self._access_token
+
+    def _refresh_token(self) -> None:
+
+        payload = {
+            "grant_type": "client_credentials",
+            "client_id": self._client_id,
+            "client_secret": self._client_secret,
+            "scope": " ".join(self._scopes),
+        }
+
+        response = requests.post(
+            self._token_endpoint, data=payload, timeout=self._request_timeout
+        )
+        try:
+            response.raise_for_status()
+        except Exception as exc:
+            raise ServicePrincipalAuthenticationError(
+                "Failed to retrieve OAuth token for service principal"
+            ) from exc
+
+        try:
+            parsed = response.json()
+            access_token = parsed["access_token"]
+        except Exception as exc:  # pragma: no cover - defensive
+            raise ServicePrincipalAuthenticationError(
+                "OAuth response did not include an access token"
+            ) from exc
+
+        expires_in_raw = parsed.get("expires_in", 3600)
+        try:
+            expires_in = int(expires_in_raw)
+        except (TypeError, ValueError):
+            expires_in = 3600
+
+        self._access_token = access_token
+        self._expires_at = time.time() + max(expires_in, self._refresh_margin + 1)
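
Reviewer note: `_get_token`, `_needs_refresh`, and `_refresh_token` together form a lock-guarded cache that refreshes the token shortly before it expires. The sketch below isolates that pattern so it can be exercised without network access. `CachedToken`, `fetch`, and `clock` are hypothetical names introduced here; in the commit the fetch step is the `requests.post` call and the clock is `time.time()`.

```python
import threading
import time
from typing import Callable, Optional, Tuple


class CachedToken:
    """Thread-safe token cache that refreshes shortly before expiry.

    Illustrative only: `fetch` and `clock` are injected stand-ins for the
    HTTP token request and time.time() used in the commit.
    """

    def __init__(
        self,
        fetch: Callable[[], Tuple[str, int]],
        refresh_margin: int = 60,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self._fetch = fetch  # returns (access_token, expires_in_seconds)
        self._refresh_margin = refresh_margin
        self._clock = clock
        self._token: Optional[str] = None
        self._expires_at = 0.0
        self._lock = threading.Lock()

    def _needs_refresh(self) -> bool:
        if self._token is None:
            return True
        return self._clock() >= self._expires_at - self._refresh_margin

    def get(self) -> str:
        # The lock ensures only one thread performs a refresh at a time.
        with self._lock:
            if self._needs_refresh():
                token, expires_in = self._fetch()
                self._token = token
                # Clamp short lifetimes so a fresh token is never already stale.
                lifetime = max(expires_in, self._refresh_margin + 1)
                self._expires_at = self._clock() + lifetime
            assert self._token is not None
            return self._token
```

Injecting the clock makes the expiry boundary testable: with a 3600-second lifetime and a 60-second margin, the cached token is reused until 3540 seconds have elapsed, after which the next call fetches a new one.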

src/databricks/sqlalchemy/base.py

Lines changed: 69 additions & 3 deletions
@@ -13,6 +13,10 @@
     get_comment_from_dte_output,
     parse_column_info_from_tgetcolumnsresponse,
 )
+from databricks.sqlalchemy._service_principal import (
+    ServicePrincipalConfigurationError,
+    ServicePrincipalCredentialsProvider,
+)
 
 import sqlalchemy
 from sqlalchemy import DDL, event
@@ -24,7 +28,7 @@
     ReflectedTableComment,
 )
 from sqlalchemy.engine.reflection import ReflectionDefaults
-from sqlalchemy.exc import DatabaseError, SQLAlchemyError
+from sqlalchemy.exc import ArgumentError, DatabaseError, SQLAlchemyError
 
 try:
     import alembic
@@ -45,6 +49,14 @@ class DatabricksImpl(DefaultImpl):
 class DatabricksDialect(default.DefaultDialect):
     """This dialect implements only those methods required to pass our e2e tests"""
 
+    _SERVICE_PRINCIPAL_ALIASES = {
+        "serviceprincipal",
+        "service_principal",
+        "service-principal",
+        "serviceprincipal-auth",
+        "sp",
+    }
+
     # See sqlalchemy.engine.interfaces for descriptions of each of these properties
     name: str = "databricks"
     driver: str = "databricks"
@@ -105,22 +117,76 @@ def create_connect_args(self, url):
         # TODO: can schema be provided after HOST?
         # Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/***
 
-        kwargs = {
+        credentials_provider = self._build_service_principal_provider(url)
+
+        kwargs: Dict[str, Any] = {
             "server_hostname": url.host,
-            "access_token": url.password,
             "http_path": url.query.get("http_path"),
             "catalog": url.query.get("catalog"),
             "schema": url.query.get("schema"),
             "use_inline_params": False,
         }
 
+        if credentials_provider:
+            kwargs["credentials_provider"] = credentials_provider
+        else:
+            kwargs["access_token"] = url.password
+
         self.schema = kwargs["schema"]
         self.catalog = kwargs["catalog"]
 
         self._force_paramstyle_to_native_mode()
 
         return [], kwargs
 
+    def _build_service_principal_provider(
+        self, url
+    ) -> Optional[ServicePrincipalCredentialsProvider]:
+        auth_value = (
+            url.query.get("authentication")
+            or url.query.get("auth")
+            or url.query.get("auth_type")
+        )
+        username_hint = (url.username or "").lower() if url.username else ""
+        is_service_principal = False
+
+        if auth_value and auth_value.lower() in self._SERVICE_PRINCIPAL_ALIASES:
+            is_service_principal = True
+        elif username_hint in self._SERVICE_PRINCIPAL_ALIASES:
+            is_service_principal = True
+
+        if not is_service_principal:
+            return None
+
+        client_id = url.query.get("client_id") or url.username
+        client_secret = url.password or url.query.get("client_secret")
+        if not client_id:
+            raise ArgumentError("Service principal connections require a client_id")
+        if not client_secret:
+            raise ArgumentError("Service principal connections require a client_secret")
+
+        scopes_raw = (
+            url.query.get("sp_scopes")
+            or url.query.get("sp_scope")
+            or url.query.get("scope")
+        )
+        scopes = (
+            [scope.strip() for scope in scopes_raw.split(",") if scope.strip()]
+            if scopes_raw
+            else None
+        )
+        try:
+            provider = ServicePrincipalCredentialsProvider(
+                server_hostname=url.host or "",
+                client_id=client_id,
+                client_secret=client_secret,
+                scopes=scopes,
+            )
+        except ServicePrincipalConfigurationError as exc:
+            raise ArgumentError(str(exc)) from exc
+
+        return provider
+
     def get_columns(
         self, connection, table_name, schema=None, **kwargs
     ) -> List[ReflectedColumn]:
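
Reviewer note: `_build_service_principal_provider` switches to service principal mode when either the `authentication`/`auth`/`auth_type` query value or the URL username matches an alias. A `client_id` in the query string takes precedence over the username, whereas the URL password takes precedence over a `client_secret` in the query string. The sketch below approximates that resolution with the standard library so the precedence can be checked without SQLAlchemy or a workspace. `resolve_service_principal` and `ServicePrincipalSettings` are hypothetical names, and it raises a plain `ValueError` where the dialect raises `ArgumentError`.

```python
from typing import List, NamedTuple, Optional
from urllib.parse import parse_qs, unquote, urlsplit

# Alias set copied from the diff above.
SERVICE_PRINCIPAL_ALIASES = {
    "serviceprincipal",
    "service_principal",
    "service-principal",
    "serviceprincipal-auth",
    "sp",
}


class ServicePrincipalSettings(NamedTuple):
    client_id: str
    client_secret: str
    scopes: Optional[List[str]]


def resolve_service_principal(url: str) -> Optional[ServicePrincipalSettings]:
    """Approximate the dialect's detection logic (illustrative only)."""
    parts = urlsplit(url)
    # Keep the first value for each query key.
    query = {key: values[0] for key, values in parse_qs(parts.query).items()}
    username = unquote(parts.username or "")
    password = unquote(parts.password or "")

    auth_value = (
        query.get("authentication") or query.get("auth") or query.get("auth_type") or ""
    )
    if (
        auth_value.lower() not in SERVICE_PRINCIPAL_ALIASES
        and username.lower() not in SERVICE_PRINCIPAL_ALIASES
    ):
        return None  # caller falls back to access-token authentication

    # Same precedence as the diff: query client_id first, URL password first.
    client_id = query.get("client_id") or username
    client_secret = password or query.get("client_secret")
    if not client_id or not client_secret:
        raise ValueError("service principal connections need client_id and client_secret")

    scopes_raw = query.get("sp_scopes") or query.get("sp_scope") or query.get("scope")
    scopes = (
        [scope.strip() for scope in scopes_raw.split(",") if scope.strip()]
        if scopes_raw
        else None
    )
    return ServicePrincipalSettings(client_id, client_secret, scopes)
```

A URL without any alias returns `None`, which corresponds to the `access_token` branch in `create_connect_args`.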

test.env.example

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# Copy this file to `test.env` (which is .gitignored) and fill in the values
+# with real credentials before running local tests.
+DATABRICKS_SERVER_HOSTNAME=<workspace-hostname>
+DATABRICKS_HTTP_PATH=<sql-http-path>
+DATABRICKS_CATALOG=<catalog>
+DATABRICKS_SCHEMA=<schema>
+DATABRICKS_SP_CLIENT_ID=<service-principal-client-id>
+DATABRICKS_SP_CLIENT_SECRET=<service-principal-secret>
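
Reviewer note: the template above ships with `<placeholder>` values, so a copied `test.env` that was never filled in would only fail later, at connection time. A small pre-flight check could catch that earlier. The sketch below is one possible approach using only the standard library; `parse_env_text` and `unfilled_keys` are hypothetical helpers, not part of this commit, and the parser handles only plain `KEY=VALUE` lines.

```python
from typing import Dict, List

# Keys taken from test.env.example in this commit.
REQUIRED_KEYS = (
    "DATABRICKS_SERVER_HOSTNAME",
    "DATABRICKS_HTTP_PATH",
    "DATABRICKS_CATALOG",
    "DATABRICKS_SCHEMA",
    "DATABRICKS_SP_CLIENT_ID",
    "DATABRICKS_SP_CLIENT_SECRET",
)


def parse_env_text(text: str) -> Dict[str, str]:
    """Parse simple KEY=VALUE lines, skipping blank lines and # comments."""
    values: Dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def unfilled_keys(values: Dict[str, str]) -> List[str]:
    """Return required keys that are absent or still hold a <placeholder>."""
    missing = []
    for key in REQUIRED_KEYS:
        value = values.get(key, "")
        if not value or (value.startswith("<") and value.endswith(">")):
            missing.append(key)
    return missing
```

Running `unfilled_keys(parse_env_text(open("test.env").read()))` before the test session would list any values that still need to be supplied.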
