|
| 1 | +import re |
| 2 | +import string |
| 3 | + |
1 | 4 | from tests import unittest |
2 | 5 |
|
3 | 6 | import msal |
4 | 7 | from msal import oauth2cli |
| 8 | +from msal.oauth2cli.oauth2 import _generate_pkce_code_verifier |
| 9 | + |
| 10 | + |
| 11 | +class TestCsprngUsage(unittest.TestCase): |
| 12 | + """Tests that security-critical parameters use cryptographically secure randomness.""" |
| 13 | + |
| 14 | + # RFC 7636 §4.1: code_verifier = 43*128unreserved |
| 15 | + _PKCE_ALPHABET = set(string.ascii_letters + string.digits + "-._~") |
| 16 | + |
| 17 | + def test_pkce_code_verifier_contains_only_valid_characters(self): |
| 18 | + for _ in range(50): |
| 19 | + result = _generate_pkce_code_verifier() |
| 20 | + self.assertTrue( |
| 21 | + set(result["code_verifier"]).issubset(self._PKCE_ALPHABET), |
| 22 | + "code_verifier contains invalid characters") |
| 23 | + |
| 24 | + def test_pkce_code_verifier_has_correct_default_length(self): |
| 25 | + result = _generate_pkce_code_verifier() |
| 26 | + self.assertEqual(len(result["code_verifier"]), 43) |
| 27 | + |
| 28 | + def test_pkce_code_verifier_respects_custom_length(self): |
| 29 | + for length in (43, 64, 128): |
| 30 | + result = _generate_pkce_code_verifier(length) |
| 31 | + self.assertEqual(len(result["code_verifier"]), length) |
| 32 | + |
| 33 | + def test_pkce_code_verifier_can_have_repeated_characters(self): |
| 34 | + """secrets.choice() samples with replacement, unlike the old random.sample().""" |
| 35 | + seen_repeat = False |
| 36 | + for _ in range(100): |
| 37 | + result = _generate_pkce_code_verifier(128) |
| 38 | + if len(set(result["code_verifier"])) < len(result["code_verifier"]): |
| 39 | + seen_repeat = True |
| 40 | + break |
| 41 | + self.assertTrue(seen_repeat, |
| 42 | + "At length 128 with a 66-char alphabet, repeated chars are expected") |
| 43 | + |
| 44 | + def test_pkce_code_verifier_is_not_deterministic(self): |
| 45 | + results = {_generate_pkce_code_verifier()["code_verifier"] for _ in range(10)} |
| 46 | + self.assertGreater(len(results), 1, "code_verifier should not be deterministic") |
| 47 | + |
| 48 | + def test_oauth2_state_is_url_safe_and_unpredictable(self): |
| 49 | + """State generated by initiate_auth_code_flow should be URL-safe.""" |
| 50 | + from msal.oauth2cli.oauth2 import Client |
| 51 | + client = Client( |
| 52 | + {"authorization_endpoint": "https://example.com/auth", |
| 53 | + "token_endpoint": "https://example.com/token"}, |
| 54 | + client_id="test_client") |
| 55 | + states = set() |
| 56 | + for _ in range(10): |
| 57 | + flow = client.initiate_auth_code_flow( |
| 58 | + redirect_uri="http://localhost", scope=["openid"]) |
| 59 | + state = flow["state"] |
| 60 | + self.assertRegex(state, r'^[A-Za-z0-9_-]+$', |
| 61 | + "state should be URL-safe") |
| 62 | + states.add(state) |
| 63 | + self.assertGreater(len(states), 1, "state should not be deterministic") |
| 64 | + |
| 65 | + def test_oidc_nonce_is_url_safe_and_unpredictable(self): |
| 66 | + """Nonce generated by OIDC initiate_auth_code_flow should be URL-safe.""" |
| 67 | + from msal.oauth2cli.oidc import Client |
| 68 | + client = Client( |
| 69 | + {"authorization_endpoint": "https://example.com/auth", |
| 70 | + "token_endpoint": "https://example.com/token"}, |
| 71 | + client_id="test_client") |
| 72 | + nonces = set() |
| 73 | + for _ in range(10): |
| 74 | + flow = client.initiate_auth_code_flow( |
| 75 | + redirect_uri="http://localhost", scope=["openid"]) |
| 76 | + nonce = flow["nonce"] |
| 77 | + self.assertRegex(nonce, r'^[A-Za-z0-9_-]+$', |
| 78 | + "nonce should be URL-safe") |
| 79 | + nonces.add(nonce) |
| 80 | + self.assertGreater(len(nonces), 1, "nonce should not be deterministic") |
5 | 81 |
|
6 | 82 |
|
7 | 83 | class TestIdToken(unittest.TestCase): |
|
0 commit comments