Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/dependency-wheel-promotion-gate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const { data: files } = await github.rest.pulls.listFiles({
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.payload.pull_request.number,
Expand Down
24 changes: 22 additions & 2 deletions postgres/assets/configuration/spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1189,17 +1189,26 @@ files:
description: |
Configuration section used for Azure AD Authentication.

This supports using System or User assigned managed identities.
This supports using System, User assigned managed identities or
workload identity federation (e.g. on AKS).

For more information on configuration, see
https://docs.datadoghq.com/database_monitoring/guide/managed_authentication

For more information on Managed Identities, see the Azure docs
https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview

For more information on Workload Identity, see the Azure docs
https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview

To enable Azure AD Authentication, set `azure.managed_authentication.enabled` to `true`.
Additionally set `azure.managed_authentication.auth_type` to `managed_identity` (default)
or `workload_identity`.
For managed identity, `client_id` is required.
For workload identity, `client_id` and `tenant_id` are optional overrides. By default they are
read from the `AZURE_CLIENT_ID` and `AZURE_TENANT_ID` environment variables, which are
automatically injected by the AKS workload identity webhook along with `AZURE_FEDERATED_TOKEN_FILE`.
If `azure.managed_authentication.enabled` is set, then the `password` fields will be ignored.
`azure.managed_authentication.client_id` is required to enable Azure AD Authentication.

For more information on scopes, see the Azure docs
https://learn.microsoft.com/en-us/azure/active-directory/develop/scopes-oidc
Expand All @@ -1209,7 +1218,18 @@ files:
- name: enabled
type: boolean
example: false
- name: auth_type
type: string
example: managed_identity
default: managed_identity
- name: client_id
description: |
Required for `managed_identity` auth. Optional for `workload_identity`,
where it defaults to the `AZURE_CLIENT_ID` environment variable.
type: string
- name: tenant_id
description: |
Only used for `workload_identity` auth.
type: string
- name: identity_scope
type: string
Expand Down
1 change: 1 addition & 0 deletions postgres/changelog.d/23436.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Azure Workload Identity authentication support.
23 changes: 15 additions & 8 deletions postgres/datadog_checks/postgres/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
# Licensed under a 3-clause BSD style license (see LICENSE)

from azure.core.credentials import AccessToken
from azure.identity import ManagedIdentityCredential
from azure.identity import ManagedIdentityCredential, WorkloadIdentityCredential

DEFAULT_PERMISSION_SCOPE = "https://ossrdbms-aad.database.windows.net/.default"

AZURE_AUTH_TYPES = ('managed_identity', 'workload_identity')
AZURE_DEFAULT_AUTH_TYPE = 'managed_identity'

# Use the azure identity API to generate a token that will be used
# authenticate with either a system or user assigned managed identity
def generate_managed_identity_token(client_id: str, identity_scope: str = None) -> AccessToken:
credential = ManagedIdentityCredential(client_id=client_id)
if not identity_scope:
identity_scope = DEFAULT_PERMISSION_SCOPE
return credential.get_token(identity_scope)

def generate_azure_token(
auth_type: str,
client_id: str | None = None,
tenant_id: str | None = None,
identity_scope: str | None = None,
) -> AccessToken:
if auth_type == 'workload_identity':
credential = WorkloadIdentityCredential(client_id=client_id, tenant_id=tenant_id)
else:
credential = ManagedIdentityCredential(client_id=client_id)
return credential.get_token(identity_scope or DEFAULT_PERMISSION_SCOPE)
15 changes: 11 additions & 4 deletions postgres/datadog_checks/postgres/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
from typing import TYPE_CHECKING, Optional, Tuple

from datadog_checks.postgres.azure import AZURE_AUTH_TYPES, AZURE_DEFAULT_AUTH_TYPE
from datadog_checks.postgres.config_models import InstanceConfig, defaults, dict_defaults
from datadog_checks.postgres.config_models.instance import (
Aws,
Expand Down Expand Up @@ -360,10 +361,16 @@ def apply_cloud_defaults(args: dict, instance: dict, validation_result: Validati
'managed_authentication': {**args.get('azure', {}).get('managed_authentication', {}), 'enabled': True},
}

if args.get('azure', {}).get('managed_authentication', {}).get('enabled') and not args.get('azure', {}).get(
'managed_authentication', {}
).get('client_id'):
validation_result.add_error('Azure client_id must be set when using Azure managed authentication')
azure_managed_auth = args.get('azure', {}).get('managed_authentication', {})
if azure_managed_auth.get('enabled'):
azure_auth_type = azure_managed_auth.get('auth_type') or AZURE_DEFAULT_AUTH_TYPE
if azure_auth_type not in AZURE_AUTH_TYPES:
validation_result.add_error(
f"Invalid azure.managed_authentication.auth_type '{azure_auth_type}'. "
f"Must be one of {AZURE_AUTH_TYPES}."
)
if azure_auth_type == 'managed_identity' and not azure_managed_auth.get('client_id'):
validation_result.add_error('Azure client_id must be set when using managed_identity authentication')


def deprecation_warning(option: str, replacement: str):
Expand Down
7 changes: 6 additions & 1 deletion postgres/datadog_checks/postgres/config_models/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@ class ManagedAuthentication1(BaseModel):
arbitrary_types_allowed=True,
frozen=True,
)
client_id: Optional[str] = None
auth_type: Optional[str] = Field('managed_identity', examples=['managed_identity'])
client_id: Optional[str] = Field(
None,
description='Required for `managed_identity` auth. Optional for `workload_identity`,\nwhere it defaults to the `AZURE_CLIENT_ID` environment variable.\n',
)
enabled: Optional[bool] = Field(None, examples=[False])
identity_scope: Optional[str] = Field(None, examples=['https://ossrdbms-aad.database.windows.net/.default'])
tenant_id: Optional[str] = Field(None, description='Only used for `workload_identity` auth.\n')


class Azure(BaseModel):
Expand Down
25 changes: 18 additions & 7 deletions postgres/datadog_checks/postgres/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,31 @@ def _fetch_token(self) -> Tuple[str, float]:


class AzureTokenProvider(TokenProvider):
"""
Token provider for Azure Managed Identity.
"""
"""Token provider for Azure managed authentication."""

def __init__(self, client_id: str, identity_scope: str = None, skew_seconds: int = 60):
def __init__(
self,
auth_type: str,
client_id: str | None = None,
tenant_id: str | None = None,
identity_scope: str | None = None,
skew_seconds: int = 60,
):
super().__init__(skew_seconds=skew_seconds)
self.auth_type = auth_type
self.client_id = client_id
self.tenant_id = tenant_id
self.identity_scope = identity_scope

def _fetch_token(self) -> Tuple[str, float]:
# Import azure only when this method is called
from .azure import generate_managed_identity_token
from .azure import generate_azure_token

token = generate_managed_identity_token(client_id=self.client_id, identity_scope=self.identity_scope)
token = generate_azure_token(
auth_type=self.auth_type,
client_id=self.client_id,
tenant_id=self.tenant_id,
identity_scope=self.identity_scope,
)
return token.token, float(token.expires_on)


Expand Down
13 changes: 11 additions & 2 deletions postgres/datadog_checks/postgres/data/conf.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -771,17 +771,26 @@ instances:
## @param managed_authentication - mapping - optional
## Configuration section used for Azure AD Authentication.
##
## This supports using System or User assigned managed identities.
## This supports using System, User assigned managed identities or
## workload identity federation (e.g. on AKS).
##
## For more information on configuration, see
## https://docs.datadoghq.com/database_monitoring/guide/managed_authentication
##
## For more information on Managed Identities, see the Azure docs
## https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
##
## For more information on Workload Identity, see the Azure docs
## https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
##
## To enable Azure AD Authentication, set `azure.managed_authentication.enabled` to `true`.
## Additionally set `azure.managed_authentication.auth_type` to `managed_identity` (default)
## or `workload_identity`.
## For managed identity, `client_id` is required.
## For workload identity, `client_id` and `tenant_id` are optional overrides. By default they are
## read from the `AZURE_CLIENT_ID` and `AZURE_TENANT_ID` environment variables, which are
## automatically injected by the AKS workload identity webhook along with `AZURE_FEDERATED_TOKEN_FILE`.
## If `azure.managed_authentication.enabled` is set, then the `password` fields will be ignored.
## `azure.managed_authentication.client_id` is required to enable Azure AD Authentication.
##
## For more information on scopes, see the Azure docs
## https://learn.microsoft.com/en-us/azure/active-directory/develop/scopes-oidc
Expand Down
3 changes: 3 additions & 0 deletions postgres/datadog_checks/postgres/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,11 @@ def build_token_provider(self) -> TokenProvider:
role_arn=self._config.aws.managed_authentication.role_arn,
)
elif self._config.azure.managed_authentication.enabled:
auth_type = self._config.azure.managed_authentication.auth_type
return AzureTokenProvider(
auth_type=auth_type,
client_id=self._config.azure.managed_authentication.client_id,
tenant_id=self._config.azure.managed_authentication.tenant_id,
identity_scope=self._config.azure.managed_authentication.identity_scope,
)
else:
Expand Down
35 changes: 35 additions & 0 deletions postgres/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,41 @@ def test_cloud_validations(mock_check, minimal_instance):
assert config.azure.managed_authentication.enabled


def test_cloud_validations_azure_workload_identity(mock_check, minimal_instance):
"""Test that workload_identity auth_type is valid without client_id."""
instance = minimal_instance.copy()
instance['azure'] = {'managed_authentication': {'enabled': True, 'auth_type': 'workload_identity'}}
instance['password'] = None
mock_check.instance = instance
mock_check.init_config = {}
config, result = build_config(check=mock_check)
assert result.valid
assert config.azure.managed_authentication.auth_type == 'workload_identity'


@pytest.mark.parametrize(
'auth_type, has_client_id, expect_valid',
[
('workload_identity', False, True),
('managed_identity', True, True),
('managed_identity', False, False),
('bad_type', True, False),
],
ids=['workload_no_client_id', 'managed_with_client_id', 'managed_missing_client_id', 'invalid_auth_type'],
)
def test_azure_auth_type_validation(mock_check, minimal_instance, auth_type, has_client_id, expect_valid):
instance = minimal_instance.copy()
auth_cfg = {'enabled': True, 'auth_type': auth_type}
if has_client_id:
auth_cfg['client_id'] = 'some-id'
instance['azure'] = {'managed_authentication': auth_cfg}
instance['password'] = None
mock_check.instance = instance
mock_check.init_config = {}
_, result = build_config(check=mock_check)
assert result.valid == expect_valid


@pytest.mark.parametrize(
'rds_host, expected_rds_tag, expected_instance_endpoint',
[
Expand Down
63 changes: 58 additions & 5 deletions postgres/tests/test_token_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
from unittest.mock import MagicMock, Mock, patch

import pytest

from datadog_checks.postgres.connection_pool import (
AWSTokenProvider,
AzureTokenProvider,
Expand Down Expand Up @@ -183,15 +185,18 @@ def test_aws_token_provider_integration():

def test_azure_token_provider_initialization():
"""Test AzureTokenProvider initialization."""
provider = AzureTokenProvider(client_id="test-client-id", identity_scope="https://test.scope/.default")
provider = AzureTokenProvider(
auth_type="managed_identity", client_id="test-client-id", identity_scope="https://test.scope/.default"
)

assert provider.auth_type == "managed_identity"
assert provider.client_id == "test-client-id"
assert provider.identity_scope == "https://test.scope/.default"


def test_azure_token_provider_initialization_without_scope():
"""Test AzureTokenProvider initialization without identity_scope."""
provider = AzureTokenProvider(client_id="test-client-id")
provider = AzureTokenProvider(auth_type="managed_identity", client_id="test-client-id")

assert provider.identity_scope is None

Expand All @@ -207,7 +212,9 @@ def test_azure_fetch_token_with_scope(mock_credential_class):
mock_credential.get_token.return_value = mock_token
mock_credential_class.return_value = mock_credential

provider = AzureTokenProvider(client_id="test-client-id", identity_scope="https://custom.scope/.default")
provider = AzureTokenProvider(
auth_type="managed_identity", client_id="test-client-id", identity_scope="https://custom.scope/.default"
)

token, expires_at = provider._fetch_token()

Expand All @@ -228,7 +235,7 @@ def test_azure_fetch_token_without_scope(mock_credential_class):
mock_credential.get_token.return_value = mock_token
mock_credential_class.return_value = mock_credential

provider = AzureTokenProvider(client_id="test-client-id")
provider = AzureTokenProvider(auth_type="managed_identity", client_id="test-client-id")

token, expires_at = provider._fetch_token()

Expand All @@ -249,7 +256,7 @@ def test_azure_token_provider_integration():
mock_credential.get_token.return_value = mock_token
mock_credential_class.return_value = mock_credential

provider = AzureTokenProvider(client_id="test-client-id")
provider = AzureTokenProvider(auth_type="managed_identity", client_id="test-client-id")

# First call should fetch token
token1 = provider.get_token()
Expand Down Expand Up @@ -365,6 +372,52 @@ def mock_ConnectionPool(*args, **pool_kwargs):
pool_manager_2.close_all()


@pytest.mark.parametrize(
'client_id, tenant_id, expected_kwargs',
[
(None, None, {'client_id': None, 'tenant_id': None}),
('c', None, {'client_id': 'c', 'tenant_id': None}),
(None, 't', {'client_id': None, 'tenant_id': 't'}),
('c', 't', {'client_id': 'c', 'tenant_id': 't'}),
],
)
@patch('datadog_checks.postgres.azure.WorkloadIdentityCredential')
def test_workload_identity_credential_kwargs(mock_cred, client_id, tenant_id, expected_kwargs):
mock_token = Mock()
mock_token.token = "wi_token"
mock_token.expires_on = time.time() + 3600
mock_cred.return_value.get_token.return_value = mock_token

provider = AzureTokenProvider(auth_type='workload_identity', client_id=client_id, tenant_id=tenant_id)
provider.get_token()

mock_cred.assert_called_once_with(**expected_kwargs)


def test_azure_workload_identity_token_provider_integration():
"""Test AzureTokenProvider with workload_identity auth_type."""
with patch('datadog_checks.postgres.azure.WorkloadIdentityCredential') as mock_credential_class:
mock_token = Mock()
mock_token.token = "integration_wi_token"
mock_token.expires_on = time.time() + 3600

mock_credential = Mock()
mock_credential.get_token.return_value = mock_token
mock_credential_class.return_value = mock_credential

provider = AzureTokenProvider(auth_type='workload_identity', client_id="test-client", tenant_id="test-tenant")

# First call should fetch token
token1 = provider.get_token()
assert token1 == "integration_wi_token"
assert mock_credential.get_token.call_count == 1

# Second call should use cached token
token2 = provider.get_token()
assert token2 == "integration_wi_token"
assert mock_credential.get_token.call_count == 1


class MockTokenProvider(TokenProvider):
"""Mock implementation of TokenProvider for testing."""

Expand Down
Loading
Loading