Skip to content

Commit 8c9f540

Browse files
committed
Surface actionable hint when Azure AD auth fails for mssql
DefaultAzureCredential failures bubbled up as the ODBC driver's generic 'Login failed for user' — the real cause (e.g. 'Please run az login') was buried in the credential-chain dump. Pre-flight the token before opening the connection and raise an AzureAdAuthError with the most actionable line + a one-line hint, matching sqlcmd's clarity. The azure-identity chain dump itself is suppressed for the duration of the call so users only see our message.
1 parent 4cdaeda commit 8c9f540

2 files changed

Lines changed: 196 additions & 0 deletions

File tree

sqlit/domains/connections/providers/mssql/adapter.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,52 @@
2222
from sqlit.domains.connections.domain.config import AuthType, ConnectionConfig
2323

2424

25+
class AzureAdAuthError(Exception):
26+
"""Raised when the Azure AD credential chain cannot produce a SQL token.
27+
28+
Surfaces the actionable hint (e.g. "run 'az login'") that would otherwise
29+
be buried in the ODBC driver's generic "Login failed for user ''" error.
30+
"""
31+
32+
33+
def _format_azure_ad_hint(exc: Exception) -> str:
34+
"""Build a short, actionable error message from a credential-chain failure.
35+
36+
Picks out the most useful sub-error (typically the AzureCliCredential
37+
line saying "Please run 'az login'") and prepends a one-line hint.
38+
Falls back to the full message if no specific line stands out.
39+
"""
40+
text = str(exc)
41+
primary = _first_actionable_line(text)
42+
hint = "Run 'az login' (or set AZURE_CLIENT_ID/SECRET/TENANT environment variables)."
43+
if primary:
44+
return f"Azure AD authentication failed.\n {primary}\n{hint}"
45+
return f"Azure AD authentication failed.\n{hint}\n\nDetails:\n{text}"
46+
47+
48+
def _first_actionable_line(text: str) -> str:
49+
"""Return the most actionable line from the credential-chain dump.
50+
51+
Walks the needles in priority order — "Please run 'az login'" beats
52+
everything else because it tells the user exactly what to do. Lower-signal
53+
lines like "Environment variables are not fully configured" are passive
54+
and don't make this list; if no high-signal needle matches, the caller
55+
falls back to dumping the full chain.
56+
"""
57+
needles = (
58+
"Please run 'az login'",
59+
"Please run `az login`",
60+
"azd auth login",
61+
"Az.Account module",
62+
)
63+
lines = text.splitlines()
64+
for needle in needles:
65+
for line in lines:
66+
if needle in line:
67+
return line.strip()
68+
return ""
69+
70+
2571
class SQLServerAdapter(DatabaseAdapter):
2672
"""Adapter for Microsoft SQL Server using the mssql-python driver."""
2773

@@ -179,6 +225,8 @@ def connect(self, config: ConnectionConfig) -> Any:
179225
package_name=self.install_package,
180226
)
181227

228+
self._preflight_azure_credentials(config)
229+
182230
conn_str = self._build_connection_string(config)
183231
# Append extra_options to connection string
184232
for key, value in config.extra_options.items():
@@ -188,6 +236,42 @@ def connect(self, config: ConnectionConfig) -> Any:
188236
conn.autocommit = True
189237
return conn
190238

239+
def _preflight_azure_credentials(self, config: ConnectionConfig) -> None:
240+
"""Try to acquire a SQL Entra token before opening the SQL connection.
241+
242+
Without this, an expired/missing `az login` session surfaces as the
243+
ODBC driver's generic "Login failed for user ''" — the real cause
244+
("Please run 'az login'") is buried in stderr. Acquiring the token
245+
ourselves lets us raise an actionable AzureAdAuthError.
246+
247+
Silently no-ops if azure-identity isn't installed (the driver still
248+
handles auth — we just lose the nicer error message).
249+
"""
250+
import logging
251+
252+
from sqlit.domains.connections.domain.config import AuthType
253+
254+
if self.get_auth_type(config) != AuthType.AD_DEFAULT:
255+
return
256+
try:
257+
from azure.core.exceptions import ClientAuthenticationError
258+
from azure.identity import DefaultAzureCredential
259+
except ImportError:
260+
return
261+
262+
# azure-identity logs the full credential-chain dump to stderr at
263+
# WARNING level on failure. Our own error already names the actionable
264+
# cause, so silence the library's noise for the duration of this call.
265+
azure_logger = logging.getLogger("azure.identity")
266+
prior_level = azure_logger.level
267+
azure_logger.setLevel(logging.ERROR)
268+
try:
269+
DefaultAzureCredential().get_token("https://database.windows.net/.default")
270+
except ClientAuthenticationError as exc:
271+
raise AzureAdAuthError(_format_azure_ad_hint(exc)) from exc
272+
finally:
273+
azure_logger.setLevel(prior_level)
274+
191275
def get_databases(self, conn: Any) -> list[str]:
192276
"""Get list of databases from SQL Server."""
193277
cursor = conn.cursor()

tests/unit/test_mssql_adapter.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,115 @@ def test_get_columns_returns_primary_keys(self, adapter):
235235
assert result[0].is_primary_key is True
236236
assert result[1].name == "name"
237237
assert result[1].is_primary_key is False
238+
239+
240+
class TestMSSQLAdapterAzureAdPreflight:
241+
"""Pre-flight Entra-token check for ad_default auth converts the
242+
DefaultAzureCredential chain failure into an actionable AzureAdAuthError
243+
*before* the ODBC driver gets to emit its generic "Login failed for user ''".
244+
"""
245+
246+
@pytest.fixture
247+
def ad_default_config(self):
248+
"""Minimal ConnectionConfig stub with auth_type=ad_default."""
249+
from sqlit.domains.connections.domain.config import (
250+
ConnectionConfig,
251+
TcpEndpoint,
252+
)
253+
254+
endpoint = TcpEndpoint(host="example.database.windows.net", port="1433", database="mydb")
255+
return ConnectionConfig(
256+
name="t",
257+
db_type="mssql",
258+
endpoint=endpoint,
259+
options={"auth_type": "ad_default"},
260+
)
261+
262+
def test_preflight_surfaces_az_login_hint_on_credential_failure(self, ad_default_config):
263+
"""When DefaultAzureCredential.get_token raises, we raise AzureAdAuthError
264+
with the actionable 'Please run az login' line and a one-line hint."""
265+
from sqlit.domains.connections.providers.mssql.adapter import (
266+
AzureAdAuthError,
267+
SQLServerAdapter,
268+
)
269+
270+
chain_message = (
271+
"DefaultAzureCredential failed to retrieve a token from the included credentials.\n"
272+
"Attempted credentials:\n"
273+
"\tEnvironmentCredential: EnvironmentCredential authentication unavailable.\n"
274+
"\tAzureCliCredential: Please run 'az login' to set up an account\n"
275+
"\tAzurePowerShellCredential: Az.Account module >= 2.2.0 is not installed\n"
276+
)
277+
278+
fake_azure_core = MagicMock()
279+
fake_azure_core_exceptions = MagicMock()
280+
281+
class _ClientAuthError(Exception):
282+
pass
283+
284+
fake_azure_core_exceptions.ClientAuthenticationError = _ClientAuthError
285+
286+
fake_azure_identity = MagicMock()
287+
fake_credential = MagicMock()
288+
fake_credential.get_token.side_effect = _ClientAuthError(chain_message)
289+
fake_azure_identity.DefaultAzureCredential.return_value = fake_credential
290+
291+
with patch.dict(
292+
"sys.modules",
293+
{
294+
"azure": MagicMock(),
295+
"azure.core": fake_azure_core,
296+
"azure.core.exceptions": fake_azure_core_exceptions,
297+
"azure.identity": fake_azure_identity,
298+
},
299+
):
300+
adapter = SQLServerAdapter()
301+
with pytest.raises(AzureAdAuthError) as exc_info:
302+
adapter._preflight_azure_credentials(ad_default_config)
303+
304+
message = str(exc_info.value)
305+
assert "Please run 'az login'" in message
306+
assert "Azure AD authentication failed" in message
307+
# The verbose chain dump should NOT be included when we extracted a
308+
# specific actionable line — keep the error tight, like sqlcmd does.
309+
assert "DefaultAzureCredential failed to retrieve" not in message
310+
311+
def test_preflight_noop_when_azure_identity_missing(self, ad_default_config):
312+
"""If azure-identity is not installed, we silently skip and let the
313+
ODBC driver handle auth (preserving the existing behavior)."""
314+
from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter
315+
316+
import builtins
317+
318+
real_import = builtins.__import__
319+
320+
def _block_azure(name, *args, **kwargs):
321+
if name.startswith("azure."):
322+
raise ImportError(f"No module named {name!r}")
323+
return real_import(name, *args, **kwargs)
324+
325+
adapter = SQLServerAdapter()
326+
with patch("builtins.__import__", side_effect=_block_azure):
327+
# Should not raise
328+
adapter._preflight_azure_credentials(ad_default_config)
329+
330+
def test_preflight_skipped_for_non_ad_default_auth(self):
331+
"""Pre-flight only runs for ad_default; sql/ad_password etc. must be
332+
untouched even if azure-identity would fail."""
333+
from sqlit.domains.connections.domain.config import (
334+
ConnectionConfig,
335+
TcpEndpoint,
336+
)
337+
from sqlit.domains.connections.providers.mssql.adapter import SQLServerAdapter
338+
339+
endpoint = TcpEndpoint(host="h", port="1433", database="d", username="u", password="p")
340+
config = ConnectionConfig(
341+
name="t",
342+
db_type="mssql",
343+
endpoint=endpoint,
344+
options={"auth_type": "sql"},
345+
)
346+
347+
adapter = SQLServerAdapter()
348+
# Even if azure-identity would explode, this never touches it.
349+
adapter._preflight_azure_credentials(config)

0 commit comments

Comments
 (0)