Skip to content

Commit 84ce1ce

Browse files
committed
add entra auth manager impl
1 parent c23db8d commit 84ce1ce

File tree

2 files changed

+141
-1
lines changed

2 files changed

+141
-1
lines changed

pyiceberg/catalog/rest/auth.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,68 @@ def auth_header(self) -> str:
249249
return f"Bearer {self.credentials.token}"
250250

251251

252+
class EntraAuthManager(AuthManager):
253+
"""Auth Manager implementation that supports Microsoft Entra ID (Azure AD) authentication.
254+
255+
This manager uses the Azure Identity library's DefaultAzureCredential which automatically
256+
tries multiple authentication methods including environment variables, managed identity,
257+
and Azure CLI.
258+
259+
See https://learn.microsoft.com/en-us/azure/developer/python/sdk/authentication/credential-chains
260+
for more details on DefaultAzureCredential.
261+
"""
262+
263+
DEFAULT_SCOPE = "https://storage.azure.com/.default"
264+
265+
def __init__(
266+
self,
267+
scopes: list[str] | None = None,
268+
**credential_kwargs: Any,
269+
):
270+
"""
271+
Initialize EntraAuthManager.
272+
273+
Args:
274+
scopes: List of OAuth2 scopes. Defaults to ["https://storage.azure.com/.default"].
275+
**credential_kwargs: Arguments passed to DefaultAzureCredential.
276+
Supported authentication methods:
277+
- Environment Variables: Set AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET
278+
- Managed Identity: Works automatically on Azure; for user-assigned, pass managed_identity_client_id
279+
- Azure CLI: Works automatically if logged in via `az login`
280+
- Workload Identity: Works automatically in AKS with workload identity configured # codespell:ignore aks
281+
"""
282+
try:
283+
from azure.identity import DefaultAzureCredential
284+
except ImportError as e:
285+
raise ImportError("Azure Identity library not found. Please install with: pip install pyiceberg[entra-auth]") from e
286+
287+
self._scopes = scopes or [self.DEFAULT_SCOPE]
288+
self._lock = threading.Lock()
289+
self._token: str | None = None
290+
self._expires_at: float = 0
291+
self._credential = DefaultAzureCredential(**credential_kwargs)
292+
293+
def _refresh_token(self) -> None:
294+
"""Refresh the access token from Azure."""
295+
token = self._credential.get_token(*self._scopes)
296+
self._token = token.token
297+
# expires_on is a Unix timestamp; add a 60-second margin for safety
298+
self._expires_at = token.expires_on - 60
299+
300+
def _get_token(self) -> str:
301+
"""Get a valid access token, refreshing if necessary."""
302+
with self._lock:
303+
if not self._token or time.time() >= self._expires_at:
304+
self._refresh_token()
305+
if self._token is None:
306+
raise ValueError("Failed to obtain Entra access token")
307+
return self._token
308+
309+
def auth_header(self) -> str:
310+
"""Return the Authorization header value with a valid Bearer token."""
311+
return f"Bearer {self._get_token()}"
312+
313+
252314
class AuthManagerAdapter(AuthBase):
253315
"""A `requests.auth.AuthBase` adapter for integrating an `AuthManager` into a `requests.Session`.
254316
@@ -330,3 +392,4 @@ def create(cls, class_or_name: str, config: dict[str, Any]) -> AuthManager:
330392
AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager)
331393
AuthManagerFactory.register("oauth2", OAuth2AuthManager)
332394
AuthManagerFactory.register("google", GoogleAuthManager)
395+
AuthManagerFactory.register("entra", EntraAuthManager)

tests/catalog/test_rest_auth.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import requests
2323
from requests_mock import Mocker
2424

25-
from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, GoogleAuthManager, NoopAuthManager
25+
from pyiceberg.catalog.rest.auth import AuthManagerAdapter, BasicAuthManager, EntraAuthManager, GoogleAuthManager, NoopAuthManager
2626

2727
TEST_URI = "https://iceberg-test-catalog/"
2828
GOOGLE_CREDS_URI = "https://oauth2.googleapis.com/token"
@@ -153,3 +153,80 @@ def test_google_auth_manager_import_error() -> None:
153153
with patch.dict("sys.modules", {"google.auth": None, "google.auth.transport.requests": None}):
154154
with pytest.raises(ImportError, match="Google Auth libraries not found. Please install 'google-auth'."):
155155
GoogleAuthManager()
156+
157+
158+
@patch("azure.identity.DefaultAzureCredential")
159+
def test_entra_auth_manager_default_credential(mock_default_cred: MagicMock, rest_mock: Mocker) -> None:
160+
"""Test EntraAuthManager with DefaultAzureCredential."""
161+
mock_credential_instance = MagicMock()
162+
mock_token = MagicMock()
163+
mock_token.token = "entra_default_token"
164+
mock_token.expires_on = 9999999999 # Far future timestamp
165+
mock_credential_instance.get_token.return_value = mock_token
166+
mock_default_cred.return_value = mock_credential_instance
167+
168+
auth_manager = EntraAuthManager()
169+
session = requests.Session()
170+
session.auth = AuthManagerAdapter(auth_manager)
171+
session.get(TEST_URI)
172+
173+
mock_default_cred.assert_called_once_with()
174+
mock_credential_instance.get_token.assert_called_once_with("https://storage.azure.com/.default")
175+
history = rest_mock.request_history
176+
assert len(history) == 1
177+
actual_headers = history[0].headers
178+
assert actual_headers["Authorization"] == "Bearer entra_default_token"
179+
180+
181+
@patch("azure.identity.DefaultAzureCredential")
182+
def test_entra_auth_manager_with_managed_identity_client_id(mock_default_cred: MagicMock, rest_mock: Mocker) -> None:
183+
"""Test EntraAuthManager with managed_identity_client_id passed to DefaultAzureCredential."""
184+
mock_credential_instance = MagicMock()
185+
mock_token = MagicMock()
186+
mock_token.token = "entra_mi_token"
187+
mock_token.expires_on = 9999999999
188+
mock_credential_instance.get_token.return_value = mock_token
189+
mock_default_cred.return_value = mock_credential_instance
190+
191+
auth_manager = EntraAuthManager(managed_identity_client_id="user-assigned-client-id")
192+
session = requests.Session()
193+
session.auth = AuthManagerAdapter(auth_manager)
194+
session.get(TEST_URI)
195+
196+
mock_default_cred.assert_called_once_with(managed_identity_client_id="user-assigned-client-id")
197+
mock_credential_instance.get_token.assert_called_once_with("https://storage.azure.com/.default")
198+
history = rest_mock.request_history
199+
assert len(history) == 1
200+
actual_headers = history[0].headers
201+
assert actual_headers["Authorization"] == "Bearer entra_mi_token"
202+
203+
204+
@patch("azure.identity.DefaultAzureCredential")
205+
def test_entra_auth_manager_custom_scopes(mock_default_cred: MagicMock, rest_mock: Mocker) -> None:
206+
"""Test EntraAuthManager with custom scopes."""
207+
mock_credential_instance = MagicMock()
208+
mock_token = MagicMock()
209+
mock_token.token = "entra_custom_scope_token"
210+
mock_token.expires_on = 9999999999
211+
mock_credential_instance.get_token.return_value = mock_token
212+
mock_default_cred.return_value = mock_credential_instance
213+
214+
custom_scopes = ["https://datalake.azure.net/.default", "https://storage.azure.com/.default"]
215+
auth_manager = EntraAuthManager(scopes=custom_scopes)
216+
session = requests.Session()
217+
session.auth = AuthManagerAdapter(auth_manager)
218+
session.get(TEST_URI)
219+
220+
mock_default_cred.assert_called_once_with()
221+
mock_credential_instance.get_token.assert_called_once_with(*custom_scopes)
222+
history = rest_mock.request_history
223+
assert len(history) == 1
224+
actual_headers = history[0].headers
225+
assert actual_headers["Authorization"] == "Bearer entra_custom_scope_token"
226+
227+
228+
def test_entra_auth_manager_import_error() -> None:
229+
"""Test EntraAuthManager raises ImportError if azure-identity is not installed."""
230+
with patch.dict("sys.modules", {"azure.identity": None}):
231+
with pytest.raises(ImportError, match="Azure Identity library not found"):
232+
EntraAuthManager()

0 commit comments

Comments
 (0)