Skip to content

Commit 2c1e58c

Browse files
authored
Merge pull request #210 from Maxteabag/mssql-direct-token-attach
Skip duplicate auth and cache the Entra token for ad_default mssql
2 parents ba1c275 + cbbedd2 commit 2c1e58c

3 files changed

Lines changed: 388 additions & 26 deletions

File tree

sqlit/domains/connections/providers/mssql/adapter.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import struct
56
from typing import TYPE_CHECKING, Any
67

78
from sqlit.domains.connections.providers.adapters.base import (
@@ -21,6 +22,22 @@
2122
if TYPE_CHECKING:
2223
from sqlit.domains.connections.domain.config import AuthType, ConnectionConfig
2324

25+
# ODBC connection attribute that lets us hand SQL Server a pre-acquired
26+
# Entra access token instead of having the driver acquire one itself.
27+
SQL_COPT_SS_ACCESS_TOKEN = 1256
28+
29+
30+
def _build_access_token_struct(token: str) -> bytes:
31+
"""Pack a JWT into the layout SQL_COPT_SS_ACCESS_TOKEN expects.
32+
33+
SQL Server's ODBC driver wants a 4-byte little-endian length prefix
34+
followed by the token encoded as UTF-16-LE bytes. Same layout the
35+
mssql-python driver builds internally; we just produce it ourselves
36+
so we can skip the driver's redundant token acquisition.
37+
"""
38+
token_bytes = token.encode("UTF-16-LE")
39+
return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)
40+
2441

2542
class AzureAdAuthError(Exception):
2643
"""Raised when the Azure AD credential chain cannot produce a SQL token.
@@ -166,11 +183,15 @@ def detect_capabilities(self, conn: Any, config: ConnectionConfig) -> None:
166183
except Exception:
167184
pass
168185

169-
def _build_connection_string(self, config: ConnectionConfig) -> str:
186+
def _build_connection_string(self, config: ConnectionConfig, *, attach_token: bool = False) -> str:
170187
"""Build mssql-python connection string from config.
171188
172189
Args:
173190
config: Connection configuration.
191+
attach_token: True when we'll be supplying SQL_COPT_SS_ACCESS_TOKEN
192+
ourselves. In that case omit the `Authentication=` directive —
193+
the two paths conflict, and the directive would make the driver
194+
spawn `az` to acquire its own token, defeating the optimization.
174195
175196
Returns:
176197
semicolon-delimited key=value connection string.
@@ -200,6 +221,9 @@ def _build_connection_string(self, config: ConnectionConfig) -> str:
200221

201222
auth = self.get_auth_type(config)
202223

224+
if attach_token and auth == AuthType.AD_DEFAULT:
225+
return base
226+
203227
if auth == AuthType.WINDOWS:
204228
return base + "Trusted_Connection=yes;"
205229
elif auth == AuthType.SQL_SERVER:
@@ -225,39 +249,56 @@ def connect(self, config: ConnectionConfig) -> Any:
225249
package_name=self.install_package,
226250
)
227251

228-
self._preflight_azure_credentials(config)
252+
token = self._preflight_azure_credentials(config)
229253

230-
conn_str = self._build_connection_string(config)
254+
conn_str = self._build_connection_string(config, attach_token=token is not None)
231255
# Append extra_options to connection string
232256
for key, value in config.extra_options.items():
233257
conn_str += f"{key}={value};"
234-
conn = mssql_python.connect(conn_str)
258+
259+
attrs_before: dict[int, bytes] | None = None
260+
if token is not None:
261+
attrs_before = {SQL_COPT_SS_ACCESS_TOKEN: _build_access_token_struct(token)}
262+
263+
conn = mssql_python.connect(conn_str, attrs_before=attrs_before)
235264
# Enable autocommit to allow DDL statements like CREATE DATABASE
236265
conn.autocommit = True
237266
return conn
238267

239-
def _preflight_azure_credentials(self, config: ConnectionConfig) -> None:
240-
"""Try to acquire a SQL Entra token before opening the SQL connection.
268+
def _preflight_azure_credentials(self, config: ConnectionConfig) -> str | None:
269+
"""Acquire a SQL Entra token and return it for direct ODBC attach.
241270
242-
Without this, an expired/missing `az login` session surfaces as the
243-
ODBC driver's generic "Login failed for user ''" — the real cause
244-
("Please run 'az login'") is buried in stderr. Acquiring the token
245-
ourselves lets us raise an actionable AzureAdAuthError.
271+
Returning the JWT lets `connect()` hand it to the driver via
272+
SQL_COPT_SS_ACCESS_TOKEN, eliminating the duplicate token acquisition
273+
the driver would otherwise do when it sees `Authentication=...` in the
274+
connection string. Returns None if azure-identity isn't installed or
275+
the config isn't ad_default — falls back to driver-side auth.
246276
247-
Silently no-ops if azure-identity isn't installed (the driver still
248-
handles auth — we just lose the nicer error message).
277+
A persistent file cache (~5 minute refresh-before-expiry buffer)
278+
avoids spawning `az account get-access-token` on every invocation,
279+
which dominates cold-start cost for one-shot `sqlit query` runs.
280+
281+
Failures surface as AzureAdAuthError with an actionable hint
282+
("Please run 'az login'", etc.) instead of the driver's generic
283+
"Login failed for user ''".
249284
"""
250285
import logging
251286

252287
from sqlit.domains.connections.domain.config import AuthType
253288

254289
if self.get_auth_type(config) != AuthType.AD_DEFAULT:
255-
return
290+
return None
256291
try:
257292
from azure.core.exceptions import ClientAuthenticationError
258293
from azure.identity import DefaultAzureCredential
259294
except ImportError:
260-
return
295+
return None
296+
297+
from . import token_cache
298+
299+
cached = token_cache.load()
300+
if cached is not None:
301+
return cached.token
261302

262303
# azure-identity logs the full credential-chain dump to stderr at
263304
# WARNING level on failure. Our own error already names the actionable
@@ -266,12 +307,22 @@ def _preflight_azure_credentials(self, config: ConnectionConfig) -> None:
266307
prior_level = azure_logger.level
267308
azure_logger.setLevel(logging.ERROR)
268309
try:
269-
DefaultAzureCredential().get_token("https://database.windows.net/.default")
310+
access_token = DefaultAzureCredential().get_token(
311+
"https://database.windows.net/.default"
312+
)
270313
except ClientAuthenticationError as exc:
271314
raise AzureAdAuthError(_format_azure_ad_hint(exc)) from exc
272315
finally:
273316
azure_logger.setLevel(prior_level)
274317

318+
try:
319+
token_cache.save(access_token.token, access_token.expires_on)
320+
except OSError:
321+
# Cache write failures are non-fatal — we still have the token.
322+
pass
323+
324+
return access_token.token
325+
275326
def get_databases(self, conn: Any) -> list[str]:
276327
"""Get list of databases from SQL Server."""
277328
cursor = conn.cursor()
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""Persistent file cache for Azure SQL access tokens.
2+
3+
Each `sqlit query` invocation otherwise spawns `az account get-access-token`
4+
via azure-identity's AzureCliCredential, which costs ~300ms-1s of CLI
5+
startup. Caching the JWT on disk between invocations makes one-shot queries
6+
roughly as fast as the SQL roundtrip itself.
7+
8+
The token is stored at 0600 under the user's sqlit config dir. Anyone with
9+
read access to that file can impersonate the user against Azure SQL for the
10+
token's remaining lifetime (default 1h) — same tradeoff as caching `az`'s
11+
own MSAL cache, and the same threat surface as `~/.azure/`.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import json
17+
import os
18+
import time
19+
from dataclasses import dataclass
20+
21+
from sqlit.shared.core.store import CONFIG_DIR
22+
23+
CACHE_FILE = CONFIG_DIR / "azure_sql_token.json"
24+
25+
# Tokens are treated as expired this many seconds before their real expiry,
26+
# so a token acquired here is still good when the ODBC handshake runs.
27+
_REFRESH_BEFORE_EXPIRY = 300
28+
29+
30+
@dataclass(frozen=True)
31+
class CachedToken:
32+
token: str
33+
expires_on: int
34+
35+
36+
def load() -> CachedToken | None:
37+
"""Return a cached token if it exists and is comfortably non-expired."""
38+
try:
39+
data = json.loads(CACHE_FILE.read_text(encoding="utf-8"))
40+
except (FileNotFoundError, json.JSONDecodeError):
41+
return None
42+
expires_on = int(data.get("expires_on", 0))
43+
if expires_on <= time.time() + _REFRESH_BEFORE_EXPIRY:
44+
return None
45+
token = data.get("token")
46+
if not isinstance(token, str) or not token:
47+
return None
48+
return CachedToken(token=token, expires_on=expires_on)
49+
50+
51+
def save(token: str, expires_on: int) -> None:
52+
"""Persist a token atomically with 0600 perms."""
53+
CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
54+
payload = json.dumps({"token": token, "expires_on": int(expires_on)})
55+
tmp = CACHE_FILE.with_suffix(".tmp")
56+
tmp.write_text(payload, encoding="utf-8")
57+
os.chmod(tmp, 0o600)
58+
os.replace(tmp, CACHE_FILE)
59+
60+
61+
def clear() -> None:
62+
"""Remove the cached token if present."""
63+
try:
64+
CACHE_FILE.unlink()
65+
except FileNotFoundError:
66+
pass

0 commit comments

Comments
 (0)