|
| 1 | +from unittest.mock import MagicMock, patch |
| 2 | +from urllib.parse import parse_qs, urlparse |
| 3 | + |
| 4 | +import pytest |
| 5 | +import requests |
| 6 | + |
| 7 | +from jumpstarter_cli_common.oidc import Config, _OAuth2Client |
| 8 | + |
| 9 | + |
| 10 | +class TestOAuth2ClientInit: |
| 11 | + def test_init_sets_attributes(self): |
| 12 | + client = _OAuth2Client(client_id="my-client", scope=["openid", "profile"]) |
| 13 | + assert client.client_id == "my-client" |
| 14 | + assert client.scope == ["openid", "profile"] |
| 15 | + assert client.redirect_uri is None |
| 16 | + assert isinstance(client._session, requests.Session) |
| 17 | + |
| 18 | + def test_init_with_redirect_uri(self): |
| 19 | + client = _OAuth2Client( |
| 20 | + client_id="my-client", |
| 21 | + scope=["openid"], |
| 22 | + redirect_uri="http://localhost:8080/callback", |
| 23 | + ) |
| 24 | + assert client.redirect_uri == "http://localhost:8080/callback" |
| 25 | + |
| 26 | + |
| 27 | +class TestOAuth2ClientVerifyProperty: |
| 28 | + def test_verify_getter_returns_session_verify(self): |
| 29 | + client = _OAuth2Client(client_id="c", scope=[]) |
| 30 | + # requests.Session defaults verify to True |
| 31 | + assert client.verify is True |
| 32 | + |
| 33 | + def test_verify_setter_updates_session_verify(self): |
| 34 | + client = _OAuth2Client(client_id="c", scope=[]) |
| 35 | + client.verify = False |
| 36 | + assert client.verify is False |
| 37 | + assert client._session.verify is False |
| 38 | + |
| 39 | + def test_verify_setter_with_cert_path(self): |
| 40 | + client = _OAuth2Client(client_id="c", scope=[]) |
| 41 | + client.verify = "/path/to/ca-bundle.crt" |
| 42 | + assert client.verify == "/path/to/ca-bundle.crt" |
| 43 | + assert client._session.verify == "/path/to/ca-bundle.crt" |
| 44 | + |
| 45 | + |
| 46 | +class TestOAuth2ClientCreateAuthorizationUrl: |
| 47 | + def test_basic_url_construction(self): |
| 48 | + client = _OAuth2Client(client_id="test-client", scope=["openid", "profile"]) |
| 49 | + url, state = client.create_authorization_url("https://auth.example.com/authorize") |
| 50 | + |
| 51 | + parsed = urlparse(url) |
| 52 | + params = parse_qs(parsed.query) |
| 53 | + |
| 54 | + assert parsed.scheme == "https" |
| 55 | + assert parsed.netloc == "auth.example.com" |
| 56 | + assert parsed.path == "/authorize" |
| 57 | + assert params["response_type"] == ["code"] |
| 58 | + assert params["client_id"] == ["test-client"] |
| 59 | + assert params["scope"] == ["openid profile"] |
| 60 | + assert params["state"] == [state] |
| 61 | + assert len(state) > 0 |
| 62 | + |
| 63 | + def test_includes_redirect_uri_when_set(self): |
| 64 | + client = _OAuth2Client( |
| 65 | + client_id="test-client", |
| 66 | + scope=["openid"], |
| 67 | + redirect_uri="http://localhost:9999/callback", |
| 68 | + ) |
| 69 | + url, _state = client.create_authorization_url("https://auth.example.com/authorize") |
| 70 | + params = parse_qs(urlparse(url).query) |
| 71 | + assert params["redirect_uri"] == ["http://localhost:9999/callback"] |
| 72 | + |
| 73 | + def test_no_redirect_uri_when_not_set(self): |
| 74 | + client = _OAuth2Client(client_id="test-client", scope=["openid"]) |
| 75 | + url, _state = client.create_authorization_url("https://auth.example.com/authorize") |
| 76 | + params = parse_qs(urlparse(url).query) |
| 77 | + assert "redirect_uri" not in params |
| 78 | + |
| 79 | + def test_extra_kwargs_included(self): |
| 80 | + client = _OAuth2Client(client_id="test-client", scope=["openid"]) |
| 81 | + url, _state = client.create_authorization_url( |
| 82 | + "https://auth.example.com/authorize", |
| 83 | + prompt="consent", |
| 84 | + nonce="abc123", |
| 85 | + ) |
| 86 | + params = parse_qs(urlparse(url).query) |
| 87 | + assert params["prompt"] == ["consent"] |
| 88 | + assert params["nonce"] == ["abc123"] |
| 89 | + |
| 90 | + def test_url_with_existing_query_params(self): |
| 91 | + client = _OAuth2Client(client_id="test-client", scope=["openid"]) |
| 92 | + url, _state = client.create_authorization_url("https://auth.example.com/authorize?foo=bar") |
| 93 | + # Should use '&' separator since URL already has '?' |
| 94 | + assert "authorize?foo=bar&" in url |
| 95 | + |
| 96 | + def test_state_is_unique(self): |
| 97 | + client = _OAuth2Client(client_id="test-client", scope=["openid"]) |
| 98 | + _, state1 = client.create_authorization_url("https://auth.example.com/authorize") |
| 99 | + _, state2 = client.create_authorization_url("https://auth.example.com/authorize") |
| 100 | + assert state1 != state2 |
| 101 | + |
| 102 | + |
| 103 | +class TestOAuth2ClientFetchToken: |
| 104 | + def _mock_response(self, json_data, status_code=200): |
| 105 | + mock_resp = MagicMock() |
| 106 | + mock_resp.json.return_value = json_data |
| 107 | + mock_resp.raise_for_status.return_value = None |
| 108 | + mock_resp.status_code = status_code |
| 109 | + return mock_resp |
| 110 | + |
| 111 | + def test_fetch_token_with_grant_type(self): |
| 112 | + token_data = {"access_token": "tok123", "token_type": "Bearer"} |
| 113 | + client = _OAuth2Client(client_id="my-client", scope=["openid", "profile"]) |
| 114 | + |
| 115 | + with patch.object(client._session, "post", return_value=self._mock_response(token_data)) as mock_post: |
| 116 | + result = client.fetch_token( |
| 117 | + "https://auth.example.com/token", |
| 118 | + grant_type="password", |
| 119 | + username="user", |
| 120 | + password="pass", |
| 121 | + ) |
| 122 | + |
| 123 | + assert result == token_data |
| 124 | + call_kwargs = mock_post.call_args |
| 125 | + post_data = call_kwargs.kwargs["data"] |
| 126 | + assert post_data["client_id"] == "my-client" |
| 127 | + assert post_data["grant_type"] == "password" |
| 128 | + assert post_data["username"] == "user" |
| 129 | + assert post_data["password"] == "pass" |
| 130 | + assert post_data["scope"] == "openid profile" |
| 131 | + assert call_kwargs.kwargs["headers"] == {"Accept": "application/json"} |
| 132 | + |
| 133 | + def test_fetch_token_with_authorization_response(self): |
| 134 | + token_data = {"access_token": "tok456", "token_type": "Bearer"} |
| 135 | + client = _OAuth2Client( |
| 136 | + client_id="my-client", |
| 137 | + scope=["openid"], |
| 138 | + redirect_uri="http://localhost:8080/callback", |
| 139 | + ) |
| 140 | + |
| 141 | + callback_url = "http://localhost:8080/callback?code=authcode123&state=xyz" |
| 142 | + |
| 143 | + with patch.object(client._session, "post", return_value=self._mock_response(token_data)) as mock_post: |
| 144 | + result = client.fetch_token( |
| 145 | + "https://auth.example.com/token", |
| 146 | + authorization_response=callback_url, |
| 147 | + ) |
| 148 | + |
| 149 | + assert result == token_data |
| 150 | + post_data = mock_post.call_args.kwargs["data"] |
| 151 | + assert post_data["code"] == "authcode123" |
| 152 | + assert post_data["redirect_uri"] == "http://localhost:8080/callback" |
| 153 | + assert post_data["grant_type"] == "authorization_code" |
| 154 | + |
| 155 | + def test_fetch_token_authorization_response_no_code(self): |
| 156 | + token_data = {"access_token": "tok789"} |
| 157 | + client = _OAuth2Client(client_id="my-client", scope=["openid"]) |
| 158 | + |
| 159 | + callback_url = "http://localhost:8080/callback?state=xyz" |
| 160 | + |
| 161 | + with patch.object(client._session, "post", return_value=self._mock_response(token_data)) as mock_post: |
| 162 | + result = client.fetch_token( |
| 163 | + "https://auth.example.com/token", |
| 164 | + authorization_response=callback_url, |
| 165 | + ) |
| 166 | + |
| 167 | + assert result == token_data |
| 168 | + post_data = mock_post.call_args.kwargs["data"] |
| 169 | + assert "code" not in post_data |
| 170 | + |
| 171 | + def test_fetch_token_authorization_response_without_redirect_uri(self): |
| 172 | + token_data = {"access_token": "tok000"} |
| 173 | + client = _OAuth2Client(client_id="my-client", scope=["openid"]) |
| 174 | + # redirect_uri is None |
| 175 | + |
| 176 | + callback_url = "http://localhost:8080/callback?code=abc" |
| 177 | + |
| 178 | + with patch.object(client._session, "post", return_value=self._mock_response(token_data)) as mock_post: |
| 179 | + client.fetch_token( |
| 180 | + "https://auth.example.com/token", |
| 181 | + authorization_response=callback_url, |
| 182 | + ) |
| 183 | + |
| 184 | + post_data = mock_post.call_args.kwargs["data"] |
| 185 | + assert "redirect_uri" not in post_data |
| 186 | + |
| 187 | + def test_fetch_token_scope_provided_in_kwargs(self): |
| 188 | + token_data = {"access_token": "tok999"} |
| 189 | + client = _OAuth2Client(client_id="my-client", scope=["openid", "profile"]) |
| 190 | + |
| 191 | + with patch.object(client._session, "post", return_value=self._mock_response(token_data)) as mock_post: |
| 192 | + client.fetch_token( |
| 193 | + "https://auth.example.com/token", |
| 194 | + grant_type="client_credentials", |
| 195 | + scope="custom_scope", |
| 196 | + ) |
| 197 | + |
| 198 | + post_data = mock_post.call_args.kwargs["data"] |
| 199 | + # When scope is provided in kwargs, it should not be overridden |
| 200 | + assert post_data["scope"] == "custom_scope" |
| 201 | + |
| 202 | + def test_fetch_token_raises_on_http_error(self): |
| 203 | + client = _OAuth2Client(client_id="my-client", scope=["openid"]) |
| 204 | + mock_resp = MagicMock() |
| 205 | + mock_resp.raise_for_status.side_effect = requests.HTTPError("401 Unauthorized") |
| 206 | + |
| 207 | + with patch.object(client._session, "post", return_value=mock_resp): |
| 208 | + with pytest.raises(requests.HTTPError): |
| 209 | + client.fetch_token("https://auth.example.com/token", grant_type="password") |
| 210 | + |
| 211 | + |
| 212 | +class TestConfigClient: |
| 213 | + def test_client_returns_oauth2_client(self): |
| 214 | + config = Config(issuer="https://issuer.example.com", client_id="test-client") |
| 215 | + session = config.client() |
| 216 | + assert isinstance(session, _OAuth2Client) |
| 217 | + assert session.client_id == "test-client" |
| 218 | + assert session.scope == ["openid", "profile"] |
| 219 | + |
| 220 | + def test_client_with_insecure_tls(self): |
| 221 | + config = Config(issuer="https://issuer.example.com", client_id="test-client", insecure_tls=True) |
| 222 | + session = config.client() |
| 223 | + assert session.verify is False |
| 224 | + |
| 225 | + def test_client_passes_kwargs(self): |
| 226 | + config = Config(issuer="https://issuer.example.com", client_id="test-client") |
| 227 | + session = config.client(redirect_uri="http://localhost:9999/callback") |
| 228 | + assert session.redirect_uri == "http://localhost:9999/callback" |
0 commit comments