|
8 | 8 | from sanic import Sanic |
9 | 9 | from sanic_testing.testing import SanicASGITestClient |
10 | 10 |
|
| 11 | +from renku_data_services.connected_services.apispec_base import CallbackParams |
| 12 | +from renku_data_services.connected_services.blueprints import OAuth2ClientsBP |
11 | 13 | from renku_data_services.data_api.app import register_all_handlers |
12 | 14 | from renku_data_services.data_api.dependencies import DependencyManager |
13 | 15 | from test.bases.renku_data_services.data_api.utils import create_dummy_oauth_client |
@@ -365,6 +367,130 @@ async def test_callback_oauth2_authorization_flow( |
365 | 367 | assert token_set.get("access_token") == "ACCESS_TOKEN" |
366 | 368 |
|
367 | 369 |
|
| 370 | +@pytest.mark.asyncio |
| 371 | +async def test_callback_oauth2_authorization_flow_error_redirects_to_next_url( |
| 372 | + oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_provider |
| 373 | +): |
| 374 | + provider = await create_oauth2_provider("provider_1") |
| 375 | + provider_id = provider["id"] |
| 376 | + |
| 377 | + next_url = "https://example.org/my-ui/callback" |
| 378 | + qs = f"next_url={quote(next_url)}" |
| 379 | + |
| 380 | + _, res = await oauth2_test_client.get( |
| 381 | + f"/api/data/oauth2/providers/{provider_id}/authorize?{qs}", headers=user_headers |
| 382 | + ) |
| 383 | + assert res.status_code == 302, res.text |
| 384 | + location = urlparse(res.headers["location"]) |
| 385 | + state = parse_qs(location.query).get("state", [None])[0] |
| 386 | + assert state |
| 387 | + |
| 388 | + callback_qs = f"state={quote(state)}&error=access_denied&error_description={quote('User canceled')}" |
| 389 | + _, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}") |
| 390 | + |
| 391 | + assert res.status_code == 302, res.text |
| 392 | + redirect_location = urlparse(res.headers["location"]) |
| 393 | + assert f"{redirect_location.scheme}://{redirect_location.netloc}{redirect_location.path}" == next_url |
| 394 | + redirected_query = parse_qs(redirect_location.query) |
| 395 | + assert redirected_query.get("error", [None])[0] == "access_denied" |
| 396 | + assert redirected_query.get("error_description", [None])[0] == "User canceled" |
| 397 | + assert redirected_query.get("state", [None])[0] == state |
| 398 | + |
| 399 | + |
| 400 | +@pytest.mark.asyncio |
| 401 | +async def test_callback_oauth2_authorization_flow_error_preserves_next_url_query( |
| 402 | + oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_provider |
| 403 | +): |
| 404 | + provider = await create_oauth2_provider("provider_1") |
| 405 | + provider_id = provider["id"] |
| 406 | + |
| 407 | + next_url = "https://example.org/my-ui/callback?existing=1" |
| 408 | + qs = f"next_url={quote(next_url)}" |
| 409 | + |
| 410 | + _, res = await oauth2_test_client.get( |
| 411 | + f"/api/data/oauth2/providers/{provider_id}/authorize?{qs}", headers=user_headers |
| 412 | + ) |
| 413 | + assert res.status_code == 302, res.text |
| 414 | + location = urlparse(res.headers["location"]) |
| 415 | + state = parse_qs(location.query).get("state", [None])[0] |
| 416 | + assert state |
| 417 | + |
| 418 | + callback_qs = ( |
| 419 | + f"state={quote(state)}" |
| 420 | + "&error=access_denied" |
| 421 | + f"&error_description={quote('User canceled')}" |
| 422 | + f"&error_uri={quote('https://example.org/oauth/errors/access_denied')}" |
| 423 | + ) |
| 424 | + _, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}") |
| 425 | + |
| 426 | + assert res.status_code == 302, res.text |
| 427 | + redirect_location = urlparse(res.headers["location"]) |
| 428 | + redirected_query = parse_qs(redirect_location.query) |
| 429 | + assert redirected_query.get("existing", [None])[0] == "1" |
| 430 | + assert redirected_query.get("error", [None])[0] == "access_denied" |
| 431 | + assert redirected_query.get("error_description", [None])[0] == "User canceled" |
| 432 | + assert redirected_query.get("error_uri", [None])[0] == "https://example.org/oauth/errors/access_denied" |
| 433 | + assert redirected_query.get("state", [None])[0] == state |
| 434 | + |
| 435 | + |
| 436 | +@pytest.mark.asyncio |
| 437 | +async def test_callback_oauth2_authorization_flow_error_redirects_with_all_optional_params( |
| 438 | + oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_provider |
| 439 | +): |
| 440 | + provider = await create_oauth2_provider("provider_1") |
| 441 | + provider_id = provider["id"] |
| 442 | + |
| 443 | + next_url = "https://example.org/my-ui/callback" |
| 444 | + qs = f"next_url={quote(next_url)}" |
| 445 | + |
| 446 | + _, res = await oauth2_test_client.get( |
| 447 | + f"/api/data/oauth2/providers/{provider_id}/authorize?{qs}", headers=user_headers |
| 448 | + ) |
| 449 | + assert res.status_code == 302, res.text |
| 450 | + state = parse_qs(urlparse(res.headers["location"]).query).get("state", [None])[0] |
| 451 | + assert state |
| 452 | + |
| 453 | + callback_qs = ( |
| 454 | + f"state={quote(state)}" |
| 455 | + "&error=access_denied" |
| 456 | + f"&error_description={quote('User canceled')}" |
| 457 | + f"&error_uri={quote('https://example.org/oauth/errors/access_denied')}" |
| 458 | + "&code=auth-code-123" |
| 459 | + "&iss=https%3A%2F%2Fissuer.example.org" |
| 460 | + ) |
| 461 | + _, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}") |
| 462 | + |
| 463 | + assert res.status_code == 302, res.text |
| 464 | + redirected_query = parse_qs(urlparse(res.headers["location"]).query) |
| 465 | + assert redirected_query.get("state", [None])[0] == state |
| 466 | + assert redirected_query.get("error", [None])[0] == "access_denied" |
| 467 | + assert redirected_query.get("error_description", [None])[0] == "User canceled" |
| 468 | + assert redirected_query.get("error_uri", [None])[0] == "https://example.org/oauth/errors/access_denied" |
| 469 | + assert redirected_query.get("code", [None])[0] == "auth-code-123" |
| 470 | + assert redirected_query.get("iss", [None])[0] == "https://issuer.example.org" |
| 471 | + |
| 472 | + |
| 473 | +def test_append_query_params_returns_original_url_when_no_allowed_values() -> None: |
| 474 | + next_url = "https://example.org/my-ui/callback?existing=1" |
| 475 | + params = CallbackParams(state="") |
| 476 | + |
| 477 | + result = OAuth2ClientsBP._append_query_params(next_url, params) |
| 478 | + |
| 479 | + assert result == next_url |
| 480 | + |
| 481 | + |
| 482 | +@pytest.mark.asyncio |
| 483 | +async def test_callback_oauth2_authorization_flow_error_without_pending_connection_forbidden( |
| 484 | + oauth2_test_client: SanicASGITestClient, |
| 485 | + sanic_client: SanicASGITestClient, |
| 486 | +): |
| 487 | + _ = sanic_client |
| 488 | + callback_qs = "state=missing-state&error=access_denied&error_description=No+pending+connection" |
| 489 | + _, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}") |
| 490 | + |
| 491 | + assert res.status_code == 403, res.text |
| 492 | + |
| 493 | + |
368 | 494 | @pytest.mark.asyncio |
369 | 495 | async def test_get_account(oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_connection): |
370 | 496 | connection = await create_oauth2_connection("provider_1") |
|
0 commit comments