diff --git a/dojo/middleware.py b/dojo/middleware.py index 5d63b1a35a0..ab89bf5a849 100644 --- a/dojo/middleware.py +++ b/dojo/middleware.py @@ -6,13 +6,18 @@ from urllib.parse import quote import pghistory.middleware +import requests from auditlog.context import set_actor from auditlog.middleware import AuditlogMiddleware as _AuditlogMiddleware from django.conf import settings +from django.contrib import messages from django.db import models from django.http import HttpResponseRedirect +from django.shortcuts import redirect from django.urls import reverse from django.utils.functional import SimpleLazyObject +from social_core.exceptions import AuthCanceled, AuthFailed, AuthForbidden +from social_django.middleware import SocialAuthExceptionMiddleware from watson.middleware import SearchContextMiddleware from watson.search import search_context_manager @@ -75,6 +80,23 @@ def __call__(self, request): return self.get_response(request) +class CustomSocialAuthExceptionMiddleware(SocialAuthExceptionMiddleware): + def process_exception(self, request, exception): + if isinstance(exception, requests.exceptions.RequestException): + messages.error(request, "Please use the standard login below.") + return redirect("/login?force_login_form") + if isinstance(exception, AuthCanceled): + messages.warning(request, "Social login was canceled. Please try again or use the standard login.") + return redirect("/login?force_login_form") + if isinstance(exception, AuthFailed): + messages.error(request, "Social login failed. Please try again or use the standard login.") + return redirect("/login?force_login_form") + if isinstance(exception, AuthForbidden): + messages.error(request, "You are not authorized to log in via this method. Please contact support or use the standard login.") + return redirect("/login?force_login_form") + return super().process_exception(request, exception) + + class DojoSytemSettingsMiddleware: _thread_local = local() diff --git a/dojo/settings/settings.dist.py b/dojo/settings/settings.dist.py index 6243e44a690..b2be58bc64d 100644 --- a/dojo/settings/settings.dist.py +++ b/dojo/settings/settings.dist.py @@ -936,7 +936,7 @@ def generate_url(scheme, double_slashes, user, password, host, port, path, param "django.middleware.clickjacking.XFrameOptionsMiddleware", "dojo.middleware.LoginRequiredMiddleware", "dojo.middleware.AdditionalHeaderMiddleware", - "social_django.middleware.SocialAuthExceptionMiddleware", + "dojo.middleware.CustomSocialAuthExceptionMiddleware", "crum.CurrentRequestUserMiddleware", "dojo.middleware.AuditlogMiddleware", "dojo.middleware.AsyncSearchContextMiddleware", diff --git a/unittests/test_social_auth_failure_handling.py b/unittests/test_social_auth_failure_handling.py new file mode 100644 index 00000000000..83f69471a02 --- /dev/null +++ b/unittests/test_social_auth_failure_handling.py @@ -0,0 +1,142 @@ +from unittest.mock import patch + +from django.contrib import messages +from django.contrib.auth.models import AnonymousUser +from django.contrib.messages.storage.fallback import FallbackStorage +from django.contrib.sessions.middleware import SessionMiddleware +from django.http import HttpResponse +from django.test import RequestFactory, override_settings +from requests.exceptions import ConnectionError as RequestsConnectionError +from social_core.exceptions import AuthCanceled, AuthFailed, AuthForbidden + +from dojo.middleware import CustomSocialAuthExceptionMiddleware + +from .dojo_test_case import DojoTestCase + + +class TestSocialAuthMiddlewareUnit(DojoTestCase): + + """ + Unit tests: + Directly test CustomSocialAuthExceptionMiddleware behavior + by simulating exceptions (ConnectionError, AuthCanceled, AuthFailed, AuthForbidden), + without relying on actual backend configuration or whether the + /complete// URLs are registered and accessible. + """ + + def setUp(self): + self.factory = RequestFactory() + self.middleware = CustomSocialAuthExceptionMiddleware(lambda *_: HttpResponse("OK")) + + def _prepare_request(self, path): + request = self.factory.get(path) + request.user = AnonymousUser() + SessionMiddleware(lambda *_: None).process_request(request) + request.session.save() + request._messages = FallbackStorage(request) + return request + + def test_social_auth_exception_redirects_to_login(self): + login_paths = [ + "/login/oidc/", + "/login/auth0/", + "/login/google-oauth2/", + "/login/okta-oauth2/", + "/login/azuread-tenant-oauth2/", + "/login/gitlab/", + "/login/keycloak-oauth2/", + "/login/github/", + ] + exceptions = [ + (RequestsConnectionError("Host unreachable"), "Please use the standard login below."), + (AuthCanceled("User canceled login"), "Social login was canceled. Please try again or use the standard login."), + (AuthFailed("Token exchange failed"), "Social login failed. Please try again or use the standard login."), + (AuthForbidden("User not allowed"), "You are not authorized to log in via this method. Please contact support or use the standard login."), + ] + for path in login_paths: + for exception, expected_message in exceptions: + with self.subTest(path=path, exception=type(exception).__name__): + request = self._prepare_request(path) + response = self.middleware.process_exception(request, exception) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, "/login?force_login_form") + storage = list(messages.get_messages(request)) + self.assertTrue(any(expected_message in str(msg) for msg in storage)) + + def test_non_social_auth_path_still_redirects_on_auth_exception(self): + """Ensure middleware handles AuthFailed even on unrelated paths.""" + request = self._prepare_request("/some/other/path/") + exception = AuthFailed("Should be handled globally") + response = self.middleware.process_exception(request, exception) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, "/login?force_login_form") + storage = list(messages.get_messages(request)) + self.assertTrue(any("Social login failed. Please try again or use the standard login." in str(msg) for msg in storage)) + + def test_non_social_auth_path_redirects_on_auth_forbidden(self): + """Ensure middleware handles AuthForbidden even on unrelated paths.""" + request = self._prepare_request("/some/other/path/") + exception = AuthForbidden("User not allowed") + response = self.middleware.process_exception(request, exception) + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url, "/login?force_login_form") + storage = list(messages.get_messages(request)) + self.assertTrue(any("You are not authorized to log in via this method." in str(msg) for msg in storage)) + + +@override_settings( + AUTHENTICATION_BACKENDS=( + "social_core.backends.github.GithubOAuth2", + "social_core.backends.gitlab.GitLabOAuth2", + "social_core.backends.keycloak.KeycloakOAuth2", + "social_core.backends.azuread_tenant.AzureADTenantOAuth2", + "social_core.backends.auth0.Auth0OAuth2", + "social_core.backends.okta.OktaOAuth2", + "social_core.backends.open_id_connect.OpenIdConnectAuth", + "django.contrib.auth.backends.ModelBackend", + ), +) +class TestSocialAuthIntegrationFailures(DojoTestCase): + + """ + Integration tests: + Simulate social login failures by calling /complete// URLs + and mocking auth_complete() to raise AuthFailed, AuthCanceled, and AuthForbidden. + Verifies that the middleware is correctly integrated and handles backend failures. + """ + + BACKEND_CLASS_PATHS = { + "github": "social_core.backends.github.GithubOAuth2", + "gitlab": "social_core.backends.gitlab.GitLabOAuth2", + "keycloak": "social_core.backends.keycloak.KeycloakOAuth2", + "azuread-tenant-oauth2": "social_core.backends.azuread_tenant.AzureADTenantOAuth2", + "auth0": "social_core.backends.auth0.Auth0OAuth2", + "okta-oauth2": "social_core.backends.okta.OktaOAuth2", + "oidc": "social_core.backends.open_id_connect.OpenIdConnectAuth", + } + + def _test_backend_exception(self, backend_slug, exception, expected_message): + backend_class_path = self.BACKEND_CLASS_PATHS[backend_slug] + with patch(f"{backend_class_path}.auth_complete", side_effect=exception): + response = self.client.get(f"/complete/{backend_slug}/", follow=True) + self.assertEqual(response.status_code, 200) + self.assertContains(response, expected_message) + + def test_all_backends_auth_failed(self): + for backend in self.BACKEND_CLASS_PATHS: + with self.subTest(backend=backend): + self._test_backend_exception(backend, AuthFailed(backend=None), "Social login failed. Please try again or use the standard login.") + + def test_all_backends_auth_canceled(self): + for backend in self.BACKEND_CLASS_PATHS: + with self.subTest(backend=backend): + self._test_backend_exception(backend, AuthCanceled(backend=None), "Social login was canceled. Please try again or use the standard login.") + + def test_all_backends_auth_forbidden(self): + for backend in self.BACKEND_CLASS_PATHS: + with self.subTest(backend=backend): + self._test_backend_exception( + backend, + AuthForbidden(backend=None), + "You are not authorized to log in via this method. Please contact support or use the standard login.", + )