Skip to content

Commit 80159d1

Browse files
authored
[Identity] Use form_post for InteractiveBrowserCredential (#46598)
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 2aba6e7 commit 80159d1

3 files changed

Lines changed: 48 additions & 7 deletions

File tree

sdk/identity/azure-identity/azure/identity/_credentials/browser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def _request_token(self, *scopes: str, **kwargs) -> Dict:
115115
prompt="select_account",
116116
claims_challenge=claims,
117117
login_hint=self._login_hint,
118+
response_mode="form_post",
118119
)
119120
if "auth_uri" not in flow:
120121
raise CredentialUnavailableError("Failed to begin authentication flow")

sdk/identity/azure-identity/azure/identity/_internal/auth_code_redirect_handler.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,20 @@ def do_GET(self):
2020

2121
query = self.path.split("?", 1)[-1]
2222
parsed = parse_qs(query, keep_blank_values=True)
23-
self.server.query_params = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in parsed.items()}
23+
self.server.auth_response = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in parsed.items()}
24+
self._send_success_response()
2425

26+
def do_POST(self):
27+
content_length = int(self.headers.get("Content-Length", 0))
28+
body = self.rfile.read(content_length).decode("utf-8")
29+
parsed = parse_qs(body, keep_blank_values=True)
30+
self.server.auth_response = {k: v[0] if isinstance(v, list) and len(v) == 1 else v for k, v in parsed.items()}
31+
self._send_success_response()
32+
33+
def _send_success_response(self):
2534
self.send_response(200)
2635
self.send_header("Content-Type", "text/html")
2736
self.end_headers()
28-
2937
self.wfile.write(b"Authentication complete. You can close this window.")
3038

3139
def log_message(self, format, *args): # pylint: disable=redefined-builtin
@@ -35,14 +43,13 @@ def log_message(self, format, *args): # pylint: disable=redefined-builtin
3543
class AuthCodeRedirectServer(HTTPServer):
3644
"""HTTP server that listens for the redirect request following an authorization code authentication"""
3745

38-
query_params: Mapping[str, Any] = {}
39-
4046
def __init__(self, hostname: str, port: int, timeout: int) -> None:
4147
HTTPServer.__init__(self, (hostname, port), AuthCodeRedirectHandler)
4248
self.timeout = timeout
49+
self.auth_response: Mapping[str, Any] = {}
4350

4451
def wait_for_redirect(self) -> Mapping[str, Any]:
45-
while not self.query_params:
52+
while not self.auth_response:
4653
try:
4754
self.handle_request()
4855
except (IOError, ValueError):
@@ -53,7 +60,7 @@ def wait_for_redirect(self) -> Mapping[str, Any]:
5360
self.server_close()
5461

5562
# if we timed out, this returns an empty dict
56-
return self.query_params
63+
return self.auth_response
5764

5865
def handle_timeout(self):
5966
"""Break the request-handling loop by tearing down the server"""

sdk/identity/azure-identity/tests/test_browser_credential.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,40 @@ def test_redirect_server(get_token_method):
167167
response = urllib.request.urlopen(url) # nosec
168168

169169
assert response.code == 200
170-
assert server.query_params[expected_param] == expected_value
170+
assert server.auth_response[expected_param] == expected_value
171+
172+
173+
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
174+
def test_redirect_server_post(get_token_method):
175+
"""The redirect server should handle POST requests with form-encoded bodies (form_post response mode)"""
176+
177+
server = None
178+
hostname = "127.0.0.1"
179+
for _ in range(4):
180+
try:
181+
port = random.randint(1024, 65535)
182+
server = AuthCodeRedirectServer(hostname, port, timeout=10)
183+
break
184+
except socket.error:
185+
continue
186+
187+
assert server, "failed to start redirect server"
188+
189+
expected_param = "code"
190+
expected_value = "test-auth-code"
191+
192+
thread = threading.Thread(target=server.wait_for_redirect)
193+
thread.daemon = True
194+
thread.start()
195+
196+
# send a POST request with form-encoded body, simulating form_post response mode
197+
url = "http://{}:{}".format(hostname, port)
198+
data = urllib.parse.urlencode({expected_param: expected_value}).encode("utf-8")
199+
request = urllib.request.Request(url, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"})
200+
response = urllib.request.urlopen(request) # nosec
201+
202+
assert response.code == 200
203+
assert server.auth_response[expected_param] == expected_value
171204

172205

173206
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)

0 commit comments

Comments
 (0)