22
33from __future__ import annotations
44
5+ import struct
56from typing import TYPE_CHECKING , Any
67
78from sqlit .domains .connections .providers .adapters .base import (
2122if TYPE_CHECKING :
2223 from sqlit .domains .connections .domain .config import AuthType , ConnectionConfig
2324
25+ # ODBC connection attribute that lets us hand SQL Server a pre-acquired
26+ # Entra access token instead of having the driver acquire one itself.
27+ SQL_COPT_SS_ACCESS_TOKEN = 1256
28+
29+
30+ def _build_access_token_struct (token : str ) -> bytes :
31+ """Pack a JWT into the layout SQL_COPT_SS_ACCESS_TOKEN expects.
32+
33+ SQL Server's ODBC driver wants a 4-byte little-endian length prefix
34+ followed by the token encoded as UTF-16-LE bytes. Same layout the
35+ mssql-python driver builds internally; we just produce it ourselves
36+ so we can skip the driver's redundant token acquisition.
37+ """
38+ token_bytes = token .encode ("UTF-16-LE" )
39+ return struct .pack (f"<I{ len (token_bytes )} s" , len (token_bytes ), token_bytes )
40+
2441
2542class AzureAdAuthError (Exception ):
2643 """Raised when the Azure AD credential chain cannot produce a SQL token.
@@ -166,11 +183,15 @@ def detect_capabilities(self, conn: Any, config: ConnectionConfig) -> None:
166183 except Exception :
167184 pass
168185
169- def _build_connection_string (self , config : ConnectionConfig ) -> str :
186+ def _build_connection_string (self , config : ConnectionConfig , * , attach_token : bool = False ) -> str :
170187 """Build mssql-python connection string from config.
171188
172189 Args:
173190 config: Connection configuration.
191+ attach_token: True when we'll be supplying SQL_COPT_SS_ACCESS_TOKEN
192+ ourselves. In that case omit the `Authentication=` directive —
193+ the two paths conflict, and the directive would make the driver
194+ spawn `az` to acquire its own token, defeating the optimization.
174195
175196 Returns:
176197 semicolon-delimited key=value connection string.
@@ -200,6 +221,9 @@ def _build_connection_string(self, config: ConnectionConfig) -> str:
200221
201222 auth = self .get_auth_type (config )
202223
224+ if attach_token and auth == AuthType .AD_DEFAULT :
225+ return base
226+
203227 if auth == AuthType .WINDOWS :
204228 return base + "Trusted_Connection=yes;"
205229 elif auth == AuthType .SQL_SERVER :
@@ -225,39 +249,56 @@ def connect(self, config: ConnectionConfig) -> Any:
225249 package_name = self .install_package ,
226250 )
227251
228- self ._preflight_azure_credentials (config )
252+ token = self ._preflight_azure_credentials (config )
229253
230- conn_str = self ._build_connection_string (config )
254+ conn_str = self ._build_connection_string (config , attach_token = token is not None )
231255 # Append extra_options to connection string
232256 for key , value in config .extra_options .items ():
233257 conn_str += f"{ key } ={ value } ;"
234- conn = mssql_python .connect (conn_str )
258+
259+ attrs_before : dict [int , bytes ] | None = None
260+ if token is not None :
261+ attrs_before = {SQL_COPT_SS_ACCESS_TOKEN : _build_access_token_struct (token )}
262+
263+ conn = mssql_python .connect (conn_str , attrs_before = attrs_before )
235264 # Enable autocommit to allow DDL statements like CREATE DATABASE
236265 conn .autocommit = True
237266 return conn
238267
239- def _preflight_azure_credentials (self , config : ConnectionConfig ) -> None :
240- """Try to acquire a SQL Entra token before opening the SQL connection .
268+ def _preflight_azure_credentials (self , config : ConnectionConfig ) -> str | None :
269+ """Acquire a SQL Entra token and return it for direct ODBC attach .
241270
242- Without this, an expired/missing `az login` session surfaces as the
243- ODBC driver's generic "Login failed for user ''" — the real cause
244- ("Please run 'az login'") is buried in stderr. Acquiring the token
245- ourselves lets us raise an actionable AzureAdAuthError.
271+ Returning the JWT lets `connect()` hand it to the driver via
272+ SQL_COPT_SS_ACCESS_TOKEN, eliminating the duplicate token acquisition
273+ the driver would otherwise do when it sees `Authentication=...` in the
274+ connection string. Returns None if azure-identity isn't installed or
275+ the config isn't ad_default — falls back to driver-side auth.
246276
247- Silently no-ops if azure-identity isn't installed (the driver still
248- handles auth — we just lose the nicer error message).
277+ A persistent file cache (~5 minute refresh-before-expiry buffer)
278+ avoids spawning `az account get-access-token` on every invocation,
279+ which dominates cold-start cost for one-shot `sqlit query` runs.
280+
281+ Failures surface as AzureAdAuthError with an actionable hint
282+ ("Please run 'az login'", etc.) instead of the driver's generic
283+ "Login failed for user ''".
249284 """
250285 import logging
251286
252287 from sqlit .domains .connections .domain .config import AuthType
253288
254289 if self .get_auth_type (config ) != AuthType .AD_DEFAULT :
255- return
290+ return None
256291 try :
257292 from azure .core .exceptions import ClientAuthenticationError
258293 from azure .identity import DefaultAzureCredential
259294 except ImportError :
260- return
295+ return None
296+
297+ from . import token_cache
298+
299+ cached = token_cache .load ()
300+ if cached is not None :
301+ return cached .token
261302
262303 # azure-identity logs the full credential-chain dump to stderr at
263304 # WARNING level on failure. Our own error already names the actionable
@@ -266,12 +307,22 @@ def _preflight_azure_credentials(self, config: ConnectionConfig) -> None:
266307 prior_level = azure_logger .level
267308 azure_logger .setLevel (logging .ERROR )
268309 try :
269- DefaultAzureCredential ().get_token ("https://database.windows.net/.default" )
310+ access_token = DefaultAzureCredential ().get_token (
311+ "https://database.windows.net/.default"
312+ )
270313 except ClientAuthenticationError as exc :
271314 raise AzureAdAuthError (_format_azure_ad_hint (exc )) from exc
272315 finally :
273316 azure_logger .setLevel (prior_level )
274317
318+ try :
319+ token_cache .save (access_token .token , access_token .expires_on )
320+ except OSError :
321+ # Cache write failures are non-fatal — we still have the token.
322+ pass
323+
324+ return access_token .token
325+
275326 def get_databases (self , conn : Any ) -> list [str ]:
276327 """Get list of databases from SQL Server."""
277328 cursor = conn .cursor ()
0 commit comments