@@ -235,3 +235,115 @@ def test_get_columns_returns_primary_keys(self, adapter):
235235 assert result [0 ].is_primary_key is True
236236 assert result [1 ].name == "name"
237237 assert result [1 ].is_primary_key is False
238+
239+
240+ class TestMSSQLAdapterAzureAdPreflight :
241+ """Pre-flight Entra-token check for ad_default auth converts the
242+ DefaultAzureCredential chain failure into an actionable AzureAdAuthError
243+ *before* the ODBC driver gets to emit its generic "Login failed for user ''".
244+ """
245+
246+ @pytest .fixture
247+ def ad_default_config (self ):
248+ """Minimal ConnectionConfig stub with auth_type=ad_default."""
249+ from sqlit .domains .connections .domain .config import (
250+ ConnectionConfig ,
251+ TcpEndpoint ,
252+ )
253+
254+ endpoint = TcpEndpoint (host = "example.database.windows.net" , port = "1433" , database = "mydb" )
255+ return ConnectionConfig (
256+ name = "t" ,
257+ db_type = "mssql" ,
258+ endpoint = endpoint ,
259+ options = {"auth_type" : "ad_default" },
260+ )
261+
262+ def test_preflight_surfaces_az_login_hint_on_credential_failure (self , ad_default_config ):
263+ """When DefaultAzureCredential.get_token raises, we raise AzureAdAuthError
264+ with the actionable 'Please run az login' line and a one-line hint."""
265+ from sqlit .domains .connections .providers .mssql .adapter import (
266+ AzureAdAuthError ,
267+ SQLServerAdapter ,
268+ )
269+
270+ chain_message = (
271+ "DefaultAzureCredential failed to retrieve a token from the included credentials.\n "
272+ "Attempted credentials:\n "
273+ "\t EnvironmentCredential: EnvironmentCredential authentication unavailable.\n "
274+ "\t AzureCliCredential: Please run 'az login' to set up an account\n "
275+ "\t AzurePowerShellCredential: Az.Account module >= 2.2.0 is not installed\n "
276+ )
277+
278+ fake_azure_core = MagicMock ()
279+ fake_azure_core_exceptions = MagicMock ()
280+
281+ class _ClientAuthError (Exception ):
282+ pass
283+
284+ fake_azure_core_exceptions .ClientAuthenticationError = _ClientAuthError
285+
286+ fake_azure_identity = MagicMock ()
287+ fake_credential = MagicMock ()
288+ fake_credential .get_token .side_effect = _ClientAuthError (chain_message )
289+ fake_azure_identity .DefaultAzureCredential .return_value = fake_credential
290+
291+ with patch .dict (
292+ "sys.modules" ,
293+ {
294+ "azure" : MagicMock (),
295+ "azure.core" : fake_azure_core ,
296+ "azure.core.exceptions" : fake_azure_core_exceptions ,
297+ "azure.identity" : fake_azure_identity ,
298+ },
299+ ):
300+ adapter = SQLServerAdapter ()
301+ with pytest .raises (AzureAdAuthError ) as exc_info :
302+ adapter ._preflight_azure_credentials (ad_default_config )
303+
304+ message = str (exc_info .value )
305+ assert "Please run 'az login'" in message
306+ assert "Azure AD authentication failed" in message
307+ # The verbose chain dump should NOT be included when we extracted a
308+ # specific actionable line — keep the error tight, like sqlcmd does.
309+ assert "DefaultAzureCredential failed to retrieve" not in message
310+
311+ def test_preflight_noop_when_azure_identity_missing (self , ad_default_config ):
312+ """If azure-identity is not installed, we silently skip and let the
313+ ODBC driver handle auth (preserving the existing behavior)."""
314+ from sqlit .domains .connections .providers .mssql .adapter import SQLServerAdapter
315+
316+ import builtins
317+
318+ real_import = builtins .__import__
319+
320+ def _block_azure (name , * args , ** kwargs ):
321+ if name .startswith ("azure." ):
322+ raise ImportError (f"No module named { name !r} " )
323+ return real_import (name , * args , ** kwargs )
324+
325+ adapter = SQLServerAdapter ()
326+ with patch ("builtins.__import__" , side_effect = _block_azure ):
327+ # Should not raise
328+ adapter ._preflight_azure_credentials (ad_default_config )
329+
330+ def test_preflight_skipped_for_non_ad_default_auth (self ):
331+ """Pre-flight only runs for ad_default; sql/ad_password etc. must be
332+ untouched even if azure-identity would fail."""
333+ from sqlit .domains .connections .domain .config import (
334+ ConnectionConfig ,
335+ TcpEndpoint ,
336+ )
337+ from sqlit .domains .connections .providers .mssql .adapter import SQLServerAdapter
338+
339+ endpoint = TcpEndpoint (host = "h" , port = "1433" , database = "d" , username = "u" , password = "p" )
340+ config = ConnectionConfig (
341+ name = "t" ,
342+ db_type = "mssql" ,
343+ endpoint = endpoint ,
344+ options = {"auth_type" : "sql" },
345+ )
346+
347+ adapter = SQLServerAdapter ()
348+ # Even if azure-identity would explode, this never touches it.
349+ adapter ._preflight_azure_credentials (config )
0 commit comments