Skip to content

Commit ac935e7

Browse files
authored
fix: redirect to frontend with error params on oauth callback failure (#1254)
1 parent 8eeedd1 commit ac935e7

4 files changed

Lines changed: 166 additions & 1 deletion

File tree

components/renku_data_services/connected_services/apispec_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,8 @@ class CallbackParams(BaseAPISpec):
3636
model_config = ConfigDict(extra="ignore")
3737

3838
state: str = Field(default="")
39+
error: str | None = Field(default=None)
40+
error_description: str | None = Field(default=None)
41+
error_uri: str | None = Field(default=None)
42+
code: str | None = Field(default=None)
43+
iss: str | None = Field(default=None)

components/renku_data_services/connected_services/blueprints.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass
44
from typing import Any
5-
from urllib.parse import unquote, urlparse, urlunparse
5+
from urllib.parse import parse_qsl, unquote, urlencode, urlparse, urlunparse
66

77
from sanic import HTTPResponse, Request, empty, json, redirect
88
from sanic.response import JSONResponse
@@ -133,6 +133,16 @@ async def _callback(request: Request) -> HTTPResponse:
133133

134134
callback_url = self._get_callback_url(request)
135135

136+
if params.error:
137+
next_url = await self.connected_services_repo.get_oauth2_connection_next_url_by_state(params.state)
138+
if next_url:
139+
return redirect(to=self._append_query_params(next_url, params))
140+
logger.info(
141+
"OAuth callback returned an error but no pending connection next_url was found "
142+
f"for state={params.state!r}"
143+
)
144+
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")
145+
136146
client = await self.oauth_http_client_factory.fetch_token(
137147
state=params.state, raw_url=request.url, callback_url=callback_url
138148
)
@@ -156,6 +166,17 @@ def _get_callback_url(self, request: Request) -> str:
156166
logger.warning("Forcing the callback URL to use https. Trusted proxies configuration may be incorrect.")
157167
return https_callback_url
158168

169+
@staticmethod
170+
def _append_query_params(url: str, params: CallbackParams) -> str:
171+
allowed_keys = {"error", "error_description", "state", "error_uri", "code", "iss"}
172+
allowed_params = [(k, v) for k, v in params.model_dump(include=allowed_keys, exclude_none=True).items() if v]
173+
if not allowed_params:
174+
return url
175+
parsed = urlparse(url)
176+
existing_params = parse_qsl(parsed.query, keep_blank_values=True)
177+
merged_query = urlencode([*existing_params, *allowed_params])
178+
return urlunparse(parsed._replace(query=merged_query))
179+
159180

160181
@dataclass(kw_only=True)
161182
class OAuth2ConnectionsBP(CustomBlueprint):

components/renku_data_services/connected_services/db.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,19 @@ async def get_oauth2_connection(self, connection_id: ULID, user: base_models.API
241241

242242
return connection
243243

244+
async def get_oauth2_connection_next_url_by_state(self, state: str) -> str | None:
245+
"""Get the saved next_url for a pending OAuth2 connection using state."""
246+
if not state:
247+
return None
248+
async with self.session_maker() as session:
249+
result = await session.scalars(
250+
select(schemas.OAuth2ConnectionORM)
251+
.where(schemas.OAuth2ConnectionORM.state == state)
252+
.where(schemas.OAuth2ConnectionORM.status == models.ConnectionStatus.pending)
253+
)
254+
connection = result.one_or_none()
255+
return connection.next_url if connection else None
256+
244257
async def get_provider_for_image(self, user: APIUser, image: Image) -> models.ImageProvider | None:
245258
"""Find a provider supporting the given image."""
246259
registry_urls = [f"http://{image.hostname}", f"https://{image.hostname}"]

test/bases/renku_data_services/data_api/test_connected_services.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sanic import Sanic
99
from sanic_testing.testing import SanicASGITestClient
1010

11+
from renku_data_services.connected_services.apispec_base import CallbackParams
12+
from renku_data_services.connected_services.blueprints import OAuth2ClientsBP
1113
from renku_data_services.data_api.app import register_all_handlers
1214
from renku_data_services.data_api.dependencies import DependencyManager
1315
from test.bases.renku_data_services.data_api.utils import create_dummy_oauth_client
@@ -365,6 +367,130 @@ async def test_callback_oauth2_authorization_flow(
365367
assert token_set.get("access_token") == "ACCESS_TOKEN"
366368

367369

370+
@pytest.mark.asyncio
371+
async def test_callback_oauth2_authorization_flow_error_redirects_to_next_url(
372+
oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_provider
373+
):
374+
provider = await create_oauth2_provider("provider_1")
375+
provider_id = provider["id"]
376+
377+
next_url = "https://example.org/my-ui/callback"
378+
qs = f"next_url={quote(next_url)}"
379+
380+
_, res = await oauth2_test_client.get(
381+
f"/api/data/oauth2/providers/{provider_id}/authorize?{qs}", headers=user_headers
382+
)
383+
assert res.status_code == 302, res.text
384+
location = urlparse(res.headers["location"])
385+
state = parse_qs(location.query).get("state", [None])[0]
386+
assert state
387+
388+
callback_qs = f"state={quote(state)}&error=access_denied&error_description={quote('User canceled')}"
389+
_, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}")
390+
391+
assert res.status_code == 302, res.text
392+
redirect_location = urlparse(res.headers["location"])
393+
assert f"{redirect_location.scheme}://{redirect_location.netloc}{redirect_location.path}" == next_url
394+
redirected_query = parse_qs(redirect_location.query)
395+
assert redirected_query.get("error", [None])[0] == "access_denied"
396+
assert redirected_query.get("error_description", [None])[0] == "User canceled"
397+
assert redirected_query.get("state", [None])[0] == state
398+
399+
400+
@pytest.mark.asyncio
401+
async def test_callback_oauth2_authorization_flow_error_preserves_next_url_query(
402+
oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_provider
403+
):
404+
provider = await create_oauth2_provider("provider_1")
405+
provider_id = provider["id"]
406+
407+
next_url = "https://example.org/my-ui/callback?existing=1"
408+
qs = f"next_url={quote(next_url)}"
409+
410+
_, res = await oauth2_test_client.get(
411+
f"/api/data/oauth2/providers/{provider_id}/authorize?{qs}", headers=user_headers
412+
)
413+
assert res.status_code == 302, res.text
414+
location = urlparse(res.headers["location"])
415+
state = parse_qs(location.query).get("state", [None])[0]
416+
assert state
417+
418+
callback_qs = (
419+
f"state={quote(state)}"
420+
"&error=access_denied"
421+
f"&error_description={quote('User canceled')}"
422+
f"&error_uri={quote('https://example.org/oauth/errors/access_denied')}"
423+
)
424+
_, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}")
425+
426+
assert res.status_code == 302, res.text
427+
redirect_location = urlparse(res.headers["location"])
428+
redirected_query = parse_qs(redirect_location.query)
429+
assert redirected_query.get("existing", [None])[0] == "1"
430+
assert redirected_query.get("error", [None])[0] == "access_denied"
431+
assert redirected_query.get("error_description", [None])[0] == "User canceled"
432+
assert redirected_query.get("error_uri", [None])[0] == "https://example.org/oauth/errors/access_denied"
433+
assert redirected_query.get("state", [None])[0] == state
434+
435+
436+
@pytest.mark.asyncio
437+
async def test_callback_oauth2_authorization_flow_error_redirects_with_all_optional_params(
438+
oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_provider
439+
):
440+
provider = await create_oauth2_provider("provider_1")
441+
provider_id = provider["id"]
442+
443+
next_url = "https://example.org/my-ui/callback"
444+
qs = f"next_url={quote(next_url)}"
445+
446+
_, res = await oauth2_test_client.get(
447+
f"/api/data/oauth2/providers/{provider_id}/authorize?{qs}", headers=user_headers
448+
)
449+
assert res.status_code == 302, res.text
450+
state = parse_qs(urlparse(res.headers["location"]).query).get("state", [None])[0]
451+
assert state
452+
453+
callback_qs = (
454+
f"state={quote(state)}"
455+
"&error=access_denied"
456+
f"&error_description={quote('User canceled')}"
457+
f"&error_uri={quote('https://example.org/oauth/errors/access_denied')}"
458+
"&code=auth-code-123"
459+
"&iss=https%3A%2F%2Fissuer.example.org"
460+
)
461+
_, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}")
462+
463+
assert res.status_code == 302, res.text
464+
redirected_query = parse_qs(urlparse(res.headers["location"]).query)
465+
assert redirected_query.get("state", [None])[0] == state
466+
assert redirected_query.get("error", [None])[0] == "access_denied"
467+
assert redirected_query.get("error_description", [None])[0] == "User canceled"
468+
assert redirected_query.get("error_uri", [None])[0] == "https://example.org/oauth/errors/access_denied"
469+
assert redirected_query.get("code", [None])[0] == "auth-code-123"
470+
assert redirected_query.get("iss", [None])[0] == "https://issuer.example.org"
471+
472+
473+
def test_append_query_params_returns_original_url_when_no_allowed_values() -> None:
474+
next_url = "https://example.org/my-ui/callback?existing=1"
475+
params = CallbackParams(state="")
476+
477+
result = OAuth2ClientsBP._append_query_params(next_url, params)
478+
479+
assert result == next_url
480+
481+
482+
@pytest.mark.asyncio
483+
async def test_callback_oauth2_authorization_flow_error_without_pending_connection_forbidden(
484+
oauth2_test_client: SanicASGITestClient,
485+
sanic_client: SanicASGITestClient,
486+
):
487+
_ = sanic_client
488+
callback_qs = "state=missing-state&error=access_denied&error_description=No+pending+connection"
489+
_, res = await oauth2_test_client.get(f"/api/data/oauth2/callback?{callback_qs}")
490+
491+
assert res.status_code == 403, res.text
492+
493+
368494
@pytest.mark.asyncio
369495
async def test_get_account(oauth2_test_client: SanicASGITestClient, user_headers, create_oauth2_connection):
370496
connection = await create_oauth2_connection("provider_1")

0 commit comments

Comments
 (0)