diff --git a/CHANGELOG.md b/CHANGELOG.md index ac9d1737..a81cee5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ This page tries to contain all use facing changes made on DocHub. # Unreleased + * Auto-recover from common CAS login errors (stale bookmark, refreshed callback) instead of showing an error page + # 2026.5.4 * Show the deployed version in the footer, linked to the changelog diff --git a/users/authBackend.py b/users/authBackend.py index 8b12cdf7..e06a6f95 100644 --- a/users/authBackend.py +++ b/users/authBackend.py @@ -99,7 +99,14 @@ def _parse_response(self, xml): "./cas:authenticationFailure", namespaces=self.XML_NAMESPACES ) if failure is not None: - raise CasRejectError(failure.attrib.get("code"), failure.text) + # Distinct subclasses for the common codes + code = failure.attrib.get("code") + if code == "INVALID_TICKET": + raise CasInvalidTicket(code, failure.text) + elif code == "INVALID_SERVICE": + raise CasInvalidService(code, failure.text) + else: + raise CasRejectError(code, failure.text) else: raise CasParseError("UNKNOWN_STRUCTURE", xml) @@ -157,8 +164,22 @@ class CasRequestError(CasError): class CasParseError(CasError): - pass + def __init__(self, code, debug): + super().__init__(code, debug) + self.code = code + self.debug = debug class CasRejectError(CasError): + def __init__(self, code, debug): + super().__init__(code, debug) + self.code = code + self.debug = debug + + +class CasInvalidTicket(CasRejectError): + pass + + +class CasInvalidService(CasRejectError): pass diff --git a/users/tests/auth_backend_parser_test.py b/users/tests/auth_backend_parser_test.py index 7ea73d93..be0d406a 100644 --- a/users/tests/auth_backend_parser_test.py +++ b/users/tests/auth_backend_parser_test.py @@ -1,6 +1,12 @@ import pytest -from users.authBackend import CasParseError, CasRejectError, UlbCasBackend +from users.authBackend import ( + CasInvalidService, + CasInvalidTicket, + CasParseError, + CasRejectError, + UlbCasBackend, +) # Parse valid cases @@ -60,25 +66,28 @@ def test_unknown_structure(path): @pytest.mark.parametrize( - ("path", "expected_error", "expected_text"), + ("path", "expected_exc", "expected_error", "expected_text"), [ ( "users/tests/xml-fixtures/invalid-service.xml", + CasInvalidService, "INVALID_SERVICE", "does not match supplied service", ), ( "users/tests/xml-fixtures/invalid-ticket.xml", + CasInvalidTicket, "INVALID_TICKET", "not recognized", ), ], ) -def test_invalid_service(path, expected_error, expected_text): +def test_invalid_service(path, expected_exc, expected_error, expected_text): with open(path) as fd: xml = fd.read() - with pytest.raises(CasRejectError) as e: + with pytest.raises(expected_exc) as e: UlbCasBackend()._parse_response(xml) + assert isinstance(e.value, CasRejectError) assert e.value.args[0] == expected_error assert expected_text in e.value.args[1] diff --git a/users/tests/auth_ulb_view_test.py b/users/tests/auth_ulb_view_test.py new file mode 100644 index 00000000..5731ded2 --- /dev/null +++ b/users/tests/auth_ulb_view_test.py @@ -0,0 +1,70 @@ +from django.urls import reverse + +import pytest +import responses + +from users.models import CasFailure + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def fake_base_url(settings): + settings.BASE_URL = "http://example.com/" + + +def _mock_cas_response(fixture_path): + with open(fixture_path) as fd: + xml = fd.read() + responses.add( + responses.GET, + "https://auth.ulb.be/proxyValidate", + body=xml, + status=200, + ) + + +@pytest.mark.parametrize( + ("fixture", "code"), + [ + ("users/tests/xml-fixtures/invalid-service.xml", "INVALID_SERVICE"), + ("users/tests/xml-fixtures/invalid-ticket.xml", "INVALID_TICKET"), + ], +) +@responses.activate +def test_recoverable_reject_redirects_to_login(client, fake_base_url, fixture, code): + """CAS rejecting with a recoverable code triggers a quiet retry through /login.""" + _mock_cas_response(fixture) + + response = client.get(reverse("auth-ulb"), {"ticket": "ST-x"}) + + assert response.status_code == 302 + assert response.url == reverse("login") + assert response.cookies["cas_autoretry"].value == "1" + + failure = CasFailure.objects.get() + assert failure.code == f"AUTORETRY__{code}" + assert failure.ticket == "ST-x" + + +@pytest.mark.parametrize( + ("fixture", "code"), + [ + ("users/tests/xml-fixtures/invalid-service.xml", "INVALID_SERVICE"), + ("users/tests/xml-fixtures/invalid-ticket.xml", "INVALID_TICKET"), + ], +) +@responses.activate +def test_recoverable_reject_does_not_loop_when_cookie_set( + client, fake_base_url, fixture, code +): + """Once we've already tried to recover, surface the error instead of looping.""" + _mock_cas_response(fixture) + client.cookies["cas_autoretry"] = "1" + + response = client.get(reverse("auth-ulb"), {"ticket": "ST-x"}) + + assert response.status_code == 200 + assert f"CAS_{code}".encode() in response.content + assert CasFailure.objects.filter(code=code).exists() + assert not CasFailure.objects.filter(code__startswith="AUTORETRY__").exists() diff --git a/users/views.py b/users/views.py index 36c06a33..f5e2a8ae 100644 --- a/users/views.py +++ b/users/views.py @@ -10,6 +10,8 @@ from requests.exceptions import ConnectionError, SSLError from users.authBackend import ( + CasInvalidService, + CasInvalidTicket, CasParseError, CasRejectError, CasRequestError, @@ -56,13 +58,28 @@ def auth_ulb(request): try: user = authenticate(ticket=ticket) + except (CasInvalidService, CasInvalidTicket) as e: + already_retried = request.COOKIES.get("cas_autoretry") == "1" + if already_retried: + # Show an error page to the user + logger.exception("CAS rejected after recovery attempt") + _log_cas_failure(request, ticket, e.code, e.debug) + return TemplateResponse( + request, "users/auth/error.html", {"code": e.code, "debug": e.debug} + ) + else: + # If it's the first try, just redirect the user to the login page + # so they get an automatic retry. + # Also set a 60s cookie to avoid a second (or more) retry and an infinite redirect loop + _log_cas_failure(request, ticket, f"AUTORETRY__{e.code}", e.debug) + resp = HttpResponseRedirect(reverse("login")) + resp.set_cookie("cas_autoretry", "1", max_age=60) + return resp except CasRejectError as e: logger.exception("CAS rejected") - code = e.args[0] - debug = e.args[1] - _log_cas_failure(request, ticket, code, debug) + _log_cas_failure(request, ticket, e.code, e.debug) return TemplateResponse( - request, "users/auth/error.html", {"code": code, "debug": debug} + request, "users/auth/error.html", {"code": e.code, "debug": e.debug} ) except CasRequestError as e: logger.exception("CAS request error") @@ -75,11 +92,9 @@ def auth_ulb(request): ) except CasParseError as e: logger.exception("CAS parse error") - code = e.args[0] - debug = e.args[1] - _log_cas_failure(request, ticket, code, debug) + _log_cas_failure(request, ticket, e.code, e.debug) return TemplateResponse( - request, "users/auth/error.html", {"code": code, "debug": debug} + request, "users/auth/error.html", {"code": e.code, "debug": e.debug} ) except (ConnectionError, SSLError) as e: logger.exception("CAS SSL error")