Skip to content

Commit ba1c275

Browse files
authored
Merge pull request #209 from Maxteabag/mssql-ad-default-preflight-hint
Surface actionable hint when Azure AD auth fails for mssql
2 parents eb286c8 + 8c9f540 commit ba1c275

2 files changed

Lines changed: 196 additions & 0 deletions

File tree

sqlit/domains/connections/providers/mssql/adapter.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,52 @@
2222
from sqlit.domains.connections.domain.config import AuthType, ConnectionConfig
2323

2424

25+
class AzureAdAuthError(Exception):
26+
"""Raised when the Azure AD credential chain cannot produce a SQL token.
27+
28+
Surfaces the actionable hint (e.g. "run 'az login'") that would otherwise
29+
be buried in the ODBC driver's generic "Login failed for user ''" error.
30+
"""
31+
32+
33+
def _format_azure_ad_hint(exc: Exception) -> str:
34+
"""Build a short, actionable error message from a credential-chain failure.
35+
36+
Picks out the most useful sub-error (typically the AzureCliCredential
37+
line saying "Please run 'az login'") and prepends a one-line hint.
38+
Falls back to the full message if no specific line stands out.
39+
"""
40+
text = str(exc)
41+
primary = _first_actionable_line(text)
42+
hint = "Run 'az login' (or set AZURE_CLIENT_ID/SECRET/TENANT environment variables)."
43+
if primary:
44+
return f"Azure AD authentication failed.\n {primary}\n{hint}"
45+
return f"Azure AD authentication failed.\n{hint}\n\nDetails:\n{text}"
46+
47+
48+
def _first_actionable_line(text: str) -> str:
49+
"""Return the most actionable line from the credential-chain dump.
50+
51+
Walks the needles in priority order — "Please run 'az login'" beats
52+
everything else because it tells the user exactly what to do. Lower-signal
53+
lines like "Environment variables are not fully configured" are passive
54+
and don't make this list; if no high-signal needle matches, the caller
55+
falls back to dumping the full chain.
56+
"""
57+
needles = (
58+
"Please run 'az login'",
59+
"Please run `az login`",
60+
"azd auth login",
61+
"Az.Account module",
62+
)
63+
lines = text.splitlines()
64+
for needle in needles:
65+
for line in lines:
66+
if needle in line:
67+
return line.strip()
68+
return ""
69+
70+
2571
class SQLServerAdapter(DatabaseAdapter):
2672
"""Adapter for Microsoft SQL Server using the mssql-python driver."""
2773

@@ -179,6 +225,8 @@ def connect(self, config: ConnectionConfig) -> Any:
179225
package_name=self.install_package,
180226
)
181227

228+
self._preflight_azure_credentials(config)
229+
182230
conn_str = self._build_connection_string(config)
183231
# Append extra_options to connection string
184232
for key, value in config.extra_options.items():
@@ -188,6 +236,42 @@ def connect(self, config: ConnectionConfig) -> Any:
188236
conn.autocommit = True
189237
return conn
190238

239+
def _preflight_azure_credentials(self, config: ConnectionConfig) -> None:
240+
"""Try to acquire a SQL Entra token before opening the SQL connection.
241+
242+
Without this, an expired/missing `az login` session surfaces as the
243+
ODBC driver's generic "Login failed for user ''" — the real cause
244+
("Please run 'az login'") is buried in stderr. Acquiring the token
245+
ourselves lets us raise an actionable AzureAdAuthError.
246+
247+
Silently no-ops if azure-identity isn't installed (the driver still
248+
handles auth — we just lose the nicer error message).
249+
"""
250+
import logging
251+
252+
from sqlit.domains.connections.domain.config import AuthType
253+
254+
if self.get_auth_type(config) != AuthType.AD_DEFAULT:
255+
return
256+
try:
257+
from azure.core.exceptions import ClientAuthenticationError
258+
from azure.identity import DefaultAzureCredential
259+
except ImportError:
260+
return
261+
262+
# azure-identity logs the full credential-chain dump to stderr at
263+
# WARNING level on failure. Our own error already names the actionable
264+
# cause, so silence the library's noise for the duration of this call.
265+
azure_logger = logging.getLogger("azure.identity")
266+
prior_level = azure_logger.level
267+
azure_logger.setLevel(logging.ERROR)
268+
try:
269+
DefaultAzureCredential().get_token("https://database.windows.net/.default")
270+
except ClientAuthenticationError as exc:
271+
raise AzureAdAuthError(_format_azure_ad_hint(exc)) from exc
272+
finally:
273+
azure_logger.setLevel(prior_level)
274+
191275
def get_databases(self, conn: Any) -> list[str]:
192276
"""Get list of databases from SQL Server."""
193277
cursor = conn.cursor()

tests/unit/test_mssql_adapter.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,115 @@ def test_get_columns_returns_primary_keys(self, adapter):
235235
assert result[0].is_primary_key is True
236236
assert result[1].name == "name"
237237
assert result[1].is_primary_key is False
238+
239+
240+
class TestMSSQLAdapterAzureAdPreflight:
241+
"""Pre-flight Entra-token check for ad_default auth converts the
242+
DefaultAzureCredential chain failure into an actionable AzureAdAuthError
243+
*before* the ODBC driver gets to emit its generic "Login failed for user ''".
244+
"""
245+
246+
@pytest.fixture
247+
def ad_default_config(self):
248+
"""Minimal ConnectionConfig stub with auth_type=ad_default."""
249+
from sqlit.domains.connections.domain.config import (
250+
ConnectionConfig,
251+
TcpEndpoint,
252+
)
253+
254+
endpoint = TcpEndpoint(host="example.database.windows.net", port="1433", database="mydb")
255+
return ConnectionConfig(
256+
name="t",
257+
db_type="mssql",
258+
endpoint=endpoint,
259+
options={"auth_type": "ad_default"},
260+
)
261+
262+
def test_preflight_surfaces_az_login_hint_on_credential_failure(self, ad_default_config):
263+
"""When DefaultAzureCredential.get_token raises, we raise AzureAdAuthError
264+
with the actionable 'Please run az login' line and a one-line hint."""
265+
from sqlit.domains.connections.providers.mssql.adapter import (
266+
AzureAdAuthError,
267+
SQLServerAdapter,
268+
)
269+
270+
chain_message = (
271+
"DefaultAzureCredential failed to retrieve a token from the included credentials.\n"
272+
"Attempted credentials:\n"
273+
"\tEnvironmentCredential: EnvironmentCredential authentication unavailable.\n"
274+
"\tAzureCliCredential: Please run 'az login' to set up an account\n"
275+
"\tAzurePowerShellCredential: Az.Account module >= 2.2.0 is not installed\n"
276+
)
277+
278+
fake_azure_core = MagicMock()
279+
fake_azure_core_exceptions = MagicMock()
280+
281+
class _ClientAuthError(Exception):
282+
pass
283+
284+
fake_azure_core_exceptions.ClientAuthenticationError = _ClientAuthError
285+
286+
fake_azure_identity = MagicMock()
287+
fake_credential = MagicMock()
288+
fake_credential.get_token.side_effect = _ClientAuthError(chain_message)
289+
fake_azure_identity.DefaultAzureCredential.return_value = fake_credential
290+
291+
with patch.dict(
292+
"sys.modules",
293+
{
294+
"azure": MagicMock(),
295+
"azure.core": fake_azure_core,
296+
"azure.core.exceptions": fake_azure_core_exceptions,
297+
"azure.identity": fake_azure_identity,
298+
},
299+
):
300+
adapter = SQLServerAdapter()
301+
with pytest.raises(AzureAdAuthError) as exc_info:
302+
adapter._preflight_azure_credentials(ad_default_config)
303+
304+
message = str(exc_info.value)
305+
assert "Please run 'az login'" in message
306+
assert "Azure AD authentication failed" in message
307+
# The verbose chain dump should NOT be included when we extracted a
308+
# specific actionable line — keep the error tight, like sqlcmd does.
309+
assert "DefaultAzureCredential failed to retrieve" not in message
310+
311+
def test_preflight_noop_when_azure_identity_missing(self, ad_default_config):
312+
"""If azure-identity is not installed, we silently skip and let the
313+
ODBC driver handle auth (preserving the existing behavior)."""
314+
from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter
315+
316+
import builtins
317+
318+
real_import = builtins.__import__
319+
320+
def _block_azure(name, *args, **kwargs):
321+
if name.startswith("azure."):
322+
raise ImportError(f"No module named {name!r}")
323+
return real_import(name, *args, **kwargs)
324+
325+
adapter = SQLServerAdapter()
326+
with patch("builtins.__import__", side_effect=_block_azure):
327+
# Should not raise
328+
adapter._preflight_azure_credentials(ad_default_config)
329+
330+
def test_preflight_skipped_for_non_ad_default_auth(self):
331+
"""Pre-flight only runs for ad_default; sql/ad_password etc. must be
332+
untouched even if azure-identity would fail."""
333+
from sqlit.domains.connections.domain.config import (
334+
ConnectionConfig,
335+
TcpEndpoint,
336+
)
337+
from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter
338+
339+
endpoint = TcpEndpoint(host="h", port="1433", database="d", username="u", password="p")
340+
config = ConnectionConfig(
341+
name="t",
342+
db_type="mssql",
343+
endpoint=endpoint,
344+
options={"auth_type": "sql"},
345+
)
346+
347+
adapter = SQLServerAdapter()
348+
# Even if azure-identity would explode, this never touches it.
349+
adapter._preflight_azure_credentials(config)

0 commit comments

Comments
 (0)