Skip to content

Commit 8ee83e6

Browse files
authored
[Identity] Allow policy override (#46072)
Similar to SDK clients, this allows credential pipelines to be customized at the policy level. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 0ebd192 commit 8ee83e6

6 files changed

Lines changed: 178 additions & 18 deletions

File tree

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Credential HTTP pipeline policies can now be overridden via the `headers_policy`, `logging_policy`, `http_logging_policy`, `proxy_policy`, `user_agent_policy`, `custom_hook_policy`, and `retry_policy` keyword arguments when constructing credentials. The `per_retry_policies` and `per_call_policies` are also now supported. This allows users to inject custom policies or override settings of built-in policies. ([#46072](https://github.com/Azure/azure-sdk-for-python/pull/46072))
8+
79
### Breaking Changes
810

911
### Bugs Fixed

sdk/identity/azure-identity/azure/identity/_credentials/imds.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def __init__(self, **kwargs: Any) -> None:
8484
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
8585
# the credential probes only if it's part of a ChainedTokenCredential chain.
8686
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
87-
super().__init__(retry_policy_class=ImdsRetryPolicy, **dict(PIPELINE_SETTINGS, **kwargs))
87+
merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs)
88+
retry_policy = merged_kwargs.pop("retry_policy", None) or ImdsRetryPolicy(**merged_kwargs)
89+
super().__init__(retry_policy=retry_policy, **merged_kwargs)
8890
self._config = kwargs
8991

9092
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:

sdk/identity/azure-identity/azure/identity/_internal/pipeline.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
from collections.abc import Iterable
6+
57
from azure.core.configuration import Configuration
68
from azure.core.pipeline import Pipeline
79
from azure.core.pipeline.policies import (
@@ -28,27 +30,38 @@ def _get_config(**kwargs) -> Configuration:
2830
:rtype: ~azure.core.configuration.Configuration
2931
"""
3032
config: Configuration = Configuration(**kwargs)
31-
config.custom_hook_policy = CustomHookPolicy(**kwargs)
32-
config.headers_policy = HeadersPolicy(**kwargs)
33-
config.http_logging_policy = HttpLoggingPolicy(**kwargs)
34-
config.logging_policy = NetworkTraceLoggingPolicy(**kwargs)
35-
config.proxy_policy = ProxyPolicy(**kwargs)
36-
config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
33+
config.custom_hook_policy = kwargs.get("custom_hook_policy") or CustomHookPolicy(**kwargs)
34+
config.headers_policy = kwargs.get("headers_policy") or HeadersPolicy(**kwargs)
35+
config.http_logging_policy = kwargs.get("http_logging_policy") or HttpLoggingPolicy(**kwargs)
36+
config.logging_policy = kwargs.get("logging_policy") or NetworkTraceLoggingPolicy(**kwargs)
37+
config.proxy_policy = kwargs.get("proxy_policy") or ProxyPolicy(**kwargs)
38+
config.user_agent_policy = kwargs.get("user_agent_policy") or UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs)
3739
return config
3840

3941

40-
def _get_policies(config, _per_retry_policies=None, **kwargs):
42+
def _get_policies(config, **kwargs):
43+
per_call_policies = kwargs.get("per_call_policies", None) or []
44+
per_retry_policies = kwargs.get("per_retry_policies", None) or []
45+
4146
policies = [
4247
RequestIdPolicy(**kwargs),
4348
config.headers_policy,
4449
config.user_agent_policy,
4550
config.proxy_policy,
4651
ContentDecodePolicy(**kwargs),
47-
config.retry_policy,
4852
]
4953

50-
if _per_retry_policies:
51-
policies.extend(_per_retry_policies)
54+
if isinstance(per_call_policies, Iterable):
55+
policies.extend(per_call_policies)
56+
elif per_call_policies is not None:
57+
policies.append(per_call_policies)
58+
59+
policies.append(config.retry_policy)
60+
61+
if isinstance(per_retry_policies, Iterable):
62+
policies.extend(per_retry_policies)
63+
elif per_retry_policies is not None:
64+
policies.append(per_retry_policies)
5265

5366
policies.extend(
5467
[
@@ -65,8 +78,7 @@ def _get_policies(config, _per_retry_policies=None, **kwargs):
6578
def build_pipeline(transport=None, policies=None, **kwargs):
6679
if not policies:
6780
config = _get_config(**kwargs)
68-
retry_policy_class = kwargs.pop("retry_policy_class", None)
69-
config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else RetryPolicy(**kwargs)
81+
config.retry_policy = kwargs.pop("retry_policy", None) or RetryPolicy(**kwargs)
7082
policies = _get_policies(config, **kwargs)
7183
if not transport:
7284
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module
@@ -85,8 +97,7 @@ def build_async_pipeline(transport=None, policies=None, **kwargs):
8597
from azure.core.pipeline.policies import AsyncRetryPolicy
8698

8799
config = _get_config(**kwargs)
88-
retry_policy_class = kwargs.pop("retry_policy_class", None)
89-
config.retry_policy = retry_policy_class(**kwargs) if retry_policy_class else AsyncRetryPolicy(**kwargs)
100+
config.retry_policy = kwargs.pop("retry_policy", None) or AsyncRetryPolicy(**kwargs)
90101
policies = _get_policies(config, **kwargs)
91102
if not transport:
92103
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module

sdk/identity/azure-identity/azure/identity/aio/_credentials/azure_arc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_client(self, **kwargs: Any) -> Optional[AsyncManagedIdentityClient]:
2020
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
2121
if url and imds:
2222
return AsyncManagedIdentityClient(
23-
_per_retry_policies=[ArcChallengeAuthPolicy()],
23+
per_retry_policies=[ArcChallengeAuthPolicy()],
2424
request_factory=functools.partial(_get_request, url),
2525
**kwargs,
2626
)

sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ def __init__(self, **kwargs: Any) -> None:
4949
# probes for the IMDS endpoint before attempting to get a token. If None (the default),
5050
# the credential probes only if it's part of a ChainedTokenCredential chain.
5151
self._enable_imds_probe = kwargs.pop("_enable_imds_probe", None)
52-
kwargs["retry_policy_class"] = AsyncImdsRetryPolicy
53-
self._client = AsyncManagedIdentityClient(_get_request, **dict(PIPELINE_SETTINGS, **kwargs))
52+
merged_kwargs = dict(PIPELINE_SETTINGS, **kwargs)
53+
retry_policy = merged_kwargs.pop("retry_policy", None) or AsyncImdsRetryPolicy(**merged_kwargs)
54+
self._client = AsyncManagedIdentityClient(_get_request, retry_policy=retry_policy, **merged_kwargs)
5455
if EnvironmentVariables.AZURE_POD_IDENTITY_AUTHORITY_HOST in os.environ:
5556
self._endpoint_available: Optional[bool] = True
5657
else:
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
"""Tests for policy override support in azure-identity pipelines."""
6+
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
11+
from azure.core.pipeline.policies import (
12+
ContentDecodePolicy,
13+
CustomHookPolicy,
14+
DistributedTracingPolicy,
15+
HeadersPolicy,
16+
HttpLoggingPolicy,
17+
NetworkTraceLoggingPolicy,
18+
ProxyPolicy,
19+
RequestIdPolicy,
20+
RetryPolicy,
21+
SansIOHTTPPolicy,
22+
UserAgentPolicy,
23+
)
24+
25+
from azure.identity._internal.pipeline import (
26+
_get_config,
27+
_get_policies,
28+
build_pipeline,
29+
build_async_pipeline,
30+
)
31+
32+
CONFIG_POLICIES = [
33+
("custom_hook_policy", CustomHookPolicy),
34+
("headers_policy", HeadersPolicy),
35+
("http_logging_policy", HttpLoggingPolicy),
36+
("logging_policy", NetworkTraceLoggingPolicy),
37+
("proxy_policy", ProxyPolicy),
38+
("user_agent_policy", UserAgentPolicy),
39+
]
40+
41+
42+
class TestGetConfigPolicyOverrides:
43+
"""Tests that _get_config respects policy override kwargs."""
44+
45+
def test_default_policies_created_when_no_overrides(self):
46+
config = _get_config()
47+
for attr, cls in CONFIG_POLICIES:
48+
assert isinstance(getattr(config, attr), cls)
49+
50+
@pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES)
51+
def test_single_policy_override(self, kwarg, cls):
52+
custom = Mock(spec=cls)
53+
config = _get_config(**{kwarg: custom})
54+
assert getattr(config, kwarg) is custom
55+
56+
@pytest.mark.parametrize("kwarg,cls", CONFIG_POLICIES)
57+
def test_non_overridden_policies_unaffected(self, kwarg, cls):
58+
"""Overriding one policy should not affect others."""
59+
custom = Mock(spec=cls)
60+
config = _get_config(**{kwarg: custom})
61+
for other_attr, other_cls in CONFIG_POLICIES:
62+
if other_attr == kwarg:
63+
assert getattr(config, other_attr) is custom
64+
else:
65+
assert isinstance(getattr(config, other_attr), other_cls)
66+
67+
68+
class TestGetPoliciesOverrides:
69+
"""Tests for per_call_policies and per_retry_policies in _get_policies."""
70+
71+
@staticmethod
72+
def _make_config():
73+
config = _get_config()
74+
config.retry_policy = RetryPolicy()
75+
return config
76+
77+
def test_default_policy_order(self):
78+
policies = _get_policies(self._make_config())
79+
80+
assert [type(p) for p in policies] == [
81+
RequestIdPolicy,
82+
HeadersPolicy,
83+
UserAgentPolicy,
84+
ProxyPolicy,
85+
ContentDecodePolicy,
86+
RetryPolicy,
87+
CustomHookPolicy,
88+
NetworkTraceLoggingPolicy,
89+
DistributedTracingPolicy,
90+
HttpLoggingPolicy,
91+
]
92+
93+
@pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"])
94+
def test_per_call_policies_inserted_before_retry(self, as_list):
95+
custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)]
96+
arg = custom_policies if as_list else custom_policies[0]
97+
98+
policies = _get_policies(self._make_config(), per_call_policies=arg)
99+
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
100+
for custom in custom_policies:
101+
assert policies.index(custom) < retry_idx
102+
103+
@pytest.mark.parametrize("as_list", [False, True], ids=["single", "list"])
104+
def test_per_retry_policies_inserted_after_retry(self, as_list):
105+
custom_policies = [Mock(spec=SansIOHTTPPolicy) for _ in range(2 if as_list else 1)]
106+
arg = custom_policies if as_list else custom_policies[0]
107+
108+
policies = _get_policies(self._make_config(), per_retry_policies=arg)
109+
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
110+
for custom in custom_policies:
111+
assert policies.index(custom) > retry_idx
112+
113+
def test_both_per_call_and_per_retry(self):
114+
per_call = Mock(spec=SansIOHTTPPolicy)
115+
per_retry = Mock(spec=SansIOHTTPPolicy)
116+
117+
policies = _get_policies(self._make_config(), per_call_policies=per_call, per_retry_policies=per_retry)
118+
retry_idx = next(i for i, p in enumerate(policies) if isinstance(p, RetryPolicy))
119+
assert policies.index(per_call) < retry_idx
120+
assert policies.index(per_retry) > retry_idx
121+
122+
123+
class TestBuildPipelineOverrides:
124+
"""Tests for policy overrides in build_pipeline and build_async_pipeline."""
125+
126+
@pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline])
127+
def test_retry_policy_override(self, builder):
128+
custom_retry = Mock(spec=RetryPolicy)
129+
pipeline = builder(retry_policy=custom_retry, transport=Mock())
130+
effective_policies = [getattr(policy, "_policy", policy) for policy in pipeline._impl_policies]
131+
assert custom_retry in effective_policies
132+
133+
def test_default_retry_policy_when_no_override(self):
134+
pipeline = build_pipeline(transport=Mock())
135+
retry_policies = [p for p in pipeline._impl_policies if isinstance(p, RetryPolicy)]
136+
assert len(retry_policies) == 1
137+
138+
@pytest.mark.parametrize("builder", [build_pipeline, build_async_pipeline])
139+
def test_policy_override_flows_through(self, builder):
140+
"""Verify that config policy overrides reach the pipeline."""
141+
custom_headers = Mock(spec=HeadersPolicy)
142+
pipeline = builder(headers_policy=custom_headers, transport=Mock())
143+
effective_policies = [getattr(policy, "_policy", policy) for policy in pipeline._impl_policies]
144+
assert custom_headers in effective_policies

0 commit comments

Comments
 (0)