Skip to content

Commit c1e031c

Browse files
committed
More changes and address PR comments
1 parent 984899b commit c1e031c

File tree

5 files changed

+260
-33
lines changed

5 files changed

+260
-33
lines changed

msal/application.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from .oauth2cli import Client, JwtAssertionCreator
1313
from .oauth2cli.oidc import decode_part
14-
from .authority import Authority, WORLD_WIDE
14+
from .authority import (
15+
Authority,
16+
WORLD_WIDE,
17+
_get_instance_discovery_endpoint,
18+
_get_instance_discovery_host,
19+
)
1520
from .mex import send_request as mex_send_request
1621
from .wstrust_request import send_request as wst_send_request
1722
from .wstrust_response import *
@@ -671,7 +676,7 @@ def __init__(
671676
self._region_detected = None
672677
self.client, self._regional_client = self._build_client(
673678
client_credential, self.authority)
674-
self.authority_groups = None
679+
self.authority_groups = {}
675680
self._telemetry_buffer = {}
676681
self._telemetry_lock = Lock()
677682
_msal_extension_check()
@@ -1304,9 +1309,16 @@ def _find_msal_accounts(self, environment):
13041309
}
13051310
return list(grouped_accounts.values())
13061311

1307-
def _get_instance_metadata(self): # This exists so it can be mocked in unit test
1312+
def _get_instance_metadata(self, instance): # This exists so it can be mocked in unit test
1313+
instance_discovery_host = _get_instance_discovery_host(instance)
13081314
resp = self.http_client.get(
1309-
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", # TBD: We may extend this to use self._instance_discovery endpoint
1315+
_get_instance_discovery_endpoint(instance),
1316+
params={
1317+
'api-version': '1.1',
1318+
'authorization_endpoint': (
1319+
"https://{}/common/oauth2/authorize".format(instance_discovery_host)
1320+
),
1321+
},
13101322
headers={'Accept': 'application/json'})
13111323
resp.raise_for_status()
13121324
return json.loads(resp.text)['metadata']
@@ -1318,10 +1330,10 @@ def _get_authority_aliases(self, instance):
13181330
# Then it is an ADFS/B2C/known_authority_hosts situation
13191331
# which may not reach the central endpoint, so we skip it.
13201332
return []
1321-
if not self.authority_groups:
1322-
self.authority_groups = [
1323-
set(group['aliases']) for group in self._get_instance_metadata()]
1324-
for group in self.authority_groups:
1333+
if instance not in self.authority_groups:
1334+
self.authority_groups[instance] = [
1335+
set(group['aliases']) for group in self._get_instance_metadata(instance)]
1336+
for group in self.authority_groups[instance]:
13251337
if instance in group:
13261338
return [alias for alias in group if alias != instance]
13271339
return []

msal/authority.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@
4444
_CIAM_DOMAIN_SUFFIX = ".ciamlogin.com"
4545

4646

47+
def _get_instance_discovery_host(instance):
48+
return instance if instance in WELL_KNOWN_AUTHORITY_HOSTS else WORLD_WIDE
49+
50+
51+
def _get_instance_discovery_endpoint(instance):
52+
return 'https://{}/common/discovery/instance'.format(
53+
_get_instance_discovery_host(instance))
54+
55+
4756
class AuthorityBuilder(object):
4857
def __init__(self, instance, tenant):
4958
"""A helper to save caller from doing string concatenation.
@@ -152,10 +161,8 @@ def _initialize_entra_authority(
152161
) or (len(parts) == 3 and parts[2].lower().startswith("b2c_"))
153162
self._is_known_to_developer = self.is_adfs or self._is_b2c or not validate_authority
154163
is_known_to_microsoft = self.instance in WELL_KNOWN_AUTHORITY_HOSTS
155-
instance_discovery_host = (
156-
self.instance if self.instance in WELL_KNOWN_AUTHORITY_HOSTS else WORLD_WIDE)
157-
instance_discovery_endpoint = 'https://{}/common/discovery/instance'.format( # Note: This URL seemingly returns V1 endpoint only
158-
instance_discovery_host
164+
instance_discovery_endpoint = _get_instance_discovery_endpoint( # Note: This URL seemingly returns V1 endpoint only
165+
self.instance
159166
) if instance_discovery in (None, True) else instance_discovery
160167
if instance_discovery_endpoint and not (
161168
is_known_to_microsoft or self._is_known_to_developer):

tests/http_client.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,44 @@ def raise_for_status(self):
4040
if self._raw_resp is not None: # Turns out `if requests.response` won't work
4141
# cause it would be True when 200<=status<400
4242
self._raw_resp.raise_for_status()
43+
44+
45+
class RecordingHttpClient(object):
46+
def __init__(self):
47+
self.get_calls = []
48+
self.post_calls = []
49+
self._get_routes = []
50+
self._post_routes = []
51+
52+
def add_get_route(self, matcher, responder):
53+
self._get_routes.append((matcher, responder))
54+
55+
def add_post_route(self, matcher, responder):
56+
self._post_routes.append((matcher, responder))
57+
58+
def get(self, url, params=None, headers=None, **kwargs):
59+
call = {
60+
"url": url,
61+
"params": params,
62+
"headers": headers,
63+
"kwargs": kwargs,
64+
}
65+
self.get_calls.append(call)
66+
for matcher, responder in self._get_routes:
67+
if matcher(call):
68+
return responder(call)
69+
return MinimalResponse(status_code=404, text="")
70+
71+
def post(self, url, params=None, data=None, headers=None, **kwargs):
72+
call = {
73+
"url": url,
74+
"params": params,
75+
"data": data,
76+
"headers": headers,
77+
"kwargs": kwargs,
78+
}
79+
self.post_calls.append(call)
80+
for matcher, responder in self._post_routes:
81+
if matcher(call):
82+
return responder(call)
83+
return MinimalResponse(status_code=404, text="")

tests/test_authority.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,11 @@ def test_wellknown_host_and_tenant(self):
5656
continue
5757
self._test_given_host_and_tenant(host, "common")
5858

59-
def test_new_sovereign_hosts_should_be_known_authorities(self):
60-
self.assertIn(AZURE_GOV_FR, WELL_KNOWN_AUTHORITY_HOSTS)
61-
self.assertIn(AZURE_GOV_DE, WELL_KNOWN_AUTHORITY_HOSTS)
62-
self.assertIn(AZURE_GOV_SG, WELL_KNOWN_AUTHORITY_HOSTS)
63-
6459
@patch("msal.authority._instance_discovery")
6560
@patch("msal.authority.tenant_discovery")
6661
def test_new_sovereign_hosts_should_build_authority_endpoints(
6762
self, tenant_discovery_mock, instance_discovery_mock):
68-
for host in (AZURE_GOV_FR, AZURE_GOV_DE, AZURE_GOV_SG):
63+
for host in WELL_KNOWN_AUTHORITY_HOSTS:
6964
tenant_discovery_mock.return_value = {
7065
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
7166
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
@@ -90,21 +85,21 @@ def test_new_sovereign_hosts_should_build_authority_endpoints(
9085
@patch("msal.authority.tenant_discovery")
9186
def test_known_authority_should_use_same_host_and_skip_instance_discovery(
9287
self, tenant_discovery_mock, instance_discovery_mock):
93-
host = AZURE_US_GOVERNMENT
94-
tenant_discovery_mock.return_value = {
95-
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
96-
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
97-
"issuer": "https://{}/common/v2.0".format(host),
98-
}
99-
c = MinimalHttpClient()
100-
Authority("https://{}/common".format(host), c)
101-
c.close()
88+
for host in WELL_KNOWN_AUTHORITY_HOSTS:
89+
tenant_discovery_mock.return_value = {
90+
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
91+
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
92+
"issuer": "https://{}/common/v2.0".format(host),
93+
}
94+
c = MinimalHttpClient()
95+
Authority("https://{}/common".format(host), c)
96+
c.close()
10297

103-
instance_discovery_mock.assert_not_called()
104-
tenant_discovery_endpoint = tenant_discovery_mock.call_args[0][0]
105-
self.assertTrue(
106-
tenant_discovery_endpoint.startswith(
107-
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)))
98+
instance_discovery_mock.assert_not_called()
99+
tenant_discovery_endpoint = tenant_discovery_mock.call_args[0][0]
100+
self.assertTrue(
101+
tenant_discovery_endpoint.startswith(
102+
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)))
108103

109104
@patch("msal.authority._instance_discovery")
110105
@patch("msal.authority.tenant_discovery")
@@ -361,7 +356,24 @@ def test_by_default_a_known_to_microsoft_authority_should_skip_validation_but_st
361356
app = msal.ClientApplication("id", authority="https://login.microsoftonline.com/common")
362357
known_to_microsoft_validation.assert_not_called()
363358
app.get_accounts() # This could make an instance metadata call for authority aliases
364-
instance_metadata.assert_called_once_with()
359+
instance_metadata.assert_called_once_with("login.microsoftonline.com")
360+
361+
def test_by_default_a_sovereign_known_authority_should_use_cloud_local_instance_metadata(
362+
self, instance_metadata, known_to_microsoft_validation, _):
363+
app = msal.ClientApplication("id", authority="https://login.microsoftonline.us/common")
364+
known_to_microsoft_validation.assert_not_called()
365+
app.get_accounts() # This could make an instance metadata call for authority aliases
366+
instance_metadata.assert_called_once_with("login.microsoftonline.us")
367+
368+
def test_fr_known_authority_should_still_work_when_instance_metadata_has_no_alias_entry(
369+
self, instance_metadata, known_to_microsoft_validation, _):
370+
app = msal.ClientApplication("id", authority="https://{}/common".format(AZURE_GOV_FR))
371+
known_to_microsoft_validation.assert_not_called()
372+
373+
accounts = app.get_accounts()
374+
375+
self.assertEqual([], accounts)
376+
instance_metadata.assert_called_once_with(AZURE_GOV_FR)
365377

366378
def test_validate_authority_boolean_should_skip_validation_and_instance_metadata(
367379
self, instance_metadata, known_to_microsoft_validation, _):
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import json
2+
3+
from tests import unittest
4+
from tests.http_client import RecordingHttpClient, MinimalResponse
5+
from msal.application import ConfidentialClientApplication
6+
7+
8+
class TestSovereignAuthorityForClientCredentialWithRecordingHttpClient(unittest.TestCase):
9+
def test_acquire_token_for_client_on_gov_fr_should_keep_calls_on_same_host(self):
10+
host = "login.sovcloud-identity.fr"
11+
expected_instance_discovery_url = "https://{}/common/discovery/instance".format(host)
12+
expected_instance_discovery_params = {
13+
"api-version": "1.1",
14+
"authorization_endpoint": (
15+
"https://{}/common/oauth2/authorize".format(host)
16+
),
17+
}
18+
19+
http_client = RecordingHttpClient()
20+
21+
def is_oidc_discovery(call):
22+
return call["url"].startswith(
23+
"https://{}/common/v2.0/.well-known/openid-configuration".format(host))
24+
25+
def oidc_discovery_response(_call):
26+
return MinimalResponse(status_code=200, text=json.dumps({
27+
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
28+
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
29+
"issuer": "https://{}/common/v2.0".format(host),
30+
}))
31+
32+
def is_instance_discovery(call):
33+
return (
34+
call["url"] == expected_instance_discovery_url
35+
and call["params"] == expected_instance_discovery_params
36+
)
37+
38+
def instance_discovery_response(_call):
39+
return MinimalResponse(status_code=200, text=json.dumps({
40+
"tenant_discovery_endpoint": (
41+
"https://login.microsoftonline.us/"
42+
"cab8a31a-1906-4287-a0d8-4eef66b95f6e/"
43+
"v2.0/.well-known/openid-configuration"
44+
),
45+
"api-version": "1.1",
46+
"metadata": [
47+
{
48+
"preferred_network": "login.microsoftonline.com",
49+
"preferred_cache": "login.windows.net",
50+
"aliases": [
51+
"login.microsoftonline.com",
52+
"login.windows.net",
53+
"login.microsoft.com",
54+
"sts.windows.net",
55+
],
56+
},
57+
{
58+
"preferred_network": "login.partner.microsoftonline.cn",
59+
"preferred_cache": "login.partner.microsoftonline.cn",
60+
"aliases": [
61+
"login.partner.microsoftonline.cn",
62+
"login.chinacloudapi.cn",
63+
],
64+
},
65+
{
66+
"preferred_network": "login.microsoftonline.de",
67+
"preferred_cache": "login.microsoftonline.de",
68+
"aliases": ["login.microsoftonline.de"],
69+
},
70+
{
71+
"preferred_network": "login.microsoftonline.us",
72+
"preferred_cache": "login.microsoftonline.us",
73+
"aliases": [
74+
"login.microsoftonline.us",
75+
"login.usgovcloudapi.net",
76+
],
77+
},
78+
{
79+
"preferred_network": "login-us.microsoftonline.com",
80+
"preferred_cache": "login-us.microsoftonline.com",
81+
"aliases": ["login-us.microsoftonline.com"],
82+
},
83+
],
84+
}))
85+
86+
token_counter = {"value": 0}
87+
88+
def is_token_call(call):
89+
return call["url"].startswith("https://{}/common/oauth2/v2.0/token".format(host))
90+
91+
def token_response(_call):
92+
token_counter["value"] += 1
93+
return MinimalResponse(status_code=200, text=json.dumps({
94+
"access_token": "AT_{}".format(token_counter["value"]),
95+
"expires_in": 3600,
96+
}))
97+
98+
http_client.add_get_route(is_oidc_discovery, oidc_discovery_response)
99+
http_client.add_get_route(is_instance_discovery, instance_discovery_response)
100+
http_client.add_post_route(is_token_call, token_response)
101+
102+
app = ConfidentialClientApplication(
103+
"client_id",
104+
client_credential="secret",
105+
authority="https://{}/common".format(host),
106+
http_client=http_client,
107+
)
108+
109+
result1 = app.acquire_token_for_client(["scope1"])
110+
self.assertEqual("AT_1", result1.get("access_token"))
111+
112+
get_calls_after_first = list(http_client.get_calls)
113+
post_calls_after_first = list(http_client.post_calls)
114+
115+
result2 = app.acquire_token_for_client(["scope2"])
116+
self.assertEqual("AT_2", result2.get("access_token"))
117+
118+
post_count_after_scope2 = len(http_client.post_calls)
119+
get_count_after_scope2 = len(http_client.get_calls)
120+
121+
cached_result1 = app.acquire_token_for_client(["scope1"])
122+
self.assertEqual("AT_1", cached_result1.get("access_token"))
123+
124+
cached_result2 = app.acquire_token_for_client(["scope2"])
125+
self.assertEqual("AT_2", cached_result2.get("access_token"))
126+
127+
cached_result3 = app.acquire_token_for_client(["scope1"])
128+
self.assertEqual("AT_1", cached_result3.get("access_token"))
129+
130+
self.assertEqual(
131+
post_count_after_scope2,
132+
len(http_client.post_calls),
133+
"Subsequent same-scope calls should be served from cache without token POST")
134+
self.assertEqual(
135+
get_count_after_scope2,
136+
len(http_client.get_calls),
137+
"Subsequent same-authority calls should not trigger additional discovery GET")
138+
139+
self.assertEqual(1, len(get_calls_after_first), "First acquire should trigger one discovery GET")
140+
self.assertTrue(
141+
get_calls_after_first[0]["url"].startswith(
142+
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)))
143+
144+
self.assertEqual(1, len(post_calls_after_first), "First acquire should trigger one token POST")
145+
self.assertTrue(
146+
post_calls_after_first[0]["url"].startswith("https://{}/common/oauth2/v2.0/token".format(host)))
147+
148+
self.assertEqual(1, len(http_client.get_calls), "Second acquire on same authority should not re-discover")
149+
self.assertEqual(2, len(http_client.post_calls), "Second acquire with a different scope should request another token")
150+
self.assertTrue(
151+
http_client.post_calls[1]["url"].startswith("https://{}/common/oauth2/v2.0/token".format(host)))
152+
153+
all_urls = [c["url"] for c in http_client.get_calls + http_client.post_calls]
154+
self.assertTrue(all("login.microsoftonline.com" not in url for url in all_urls))
155+
self.assertTrue(all("https://{}/".format(host) in url for url in all_urls))

0 commit comments

Comments
 (0)