Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions dojo/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from urllib.parse import quote

import pghistory.middleware
import requests
from auditlog.context import set_actor
from auditlog.middleware import AuditlogMiddleware as _AuditlogMiddleware
from django.conf import settings
from django.contrib import messages
from django.db import models
from django.http import HttpResponseRedirect
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.functional import SimpleLazyObject
from social_core.exceptions import AuthCanceled, AuthFailed
from social_django.middleware import SocialAuthExceptionMiddleware
from watson.middleware import SearchContextMiddleware
from watson.search import search_context_manager

Expand Down Expand Up @@ -75,6 +80,20 @@ def __call__(self, request):
return self.get_response(request)


class CustomSocialAuthExceptionMiddleware(SocialAuthExceptionMiddleware):
def process_exception(self, request, exception):
if isinstance(exception, requests.exceptions.RequestException):
messages.error(request, "Login via social authentication is temporarily unavailable. Please use the standard login below.")
Comment thread
manuel-sommer marked this conversation as resolved.
Outdated
return redirect("/login")
Comment thread
manuel-sommer marked this conversation as resolved.
Outdated
if isinstance(exception, AuthCanceled):
messages.warning(request, "Social login was canceled. Please try again or use the standard login.")
return redirect("/login")
if isinstance(exception, AuthFailed):
messages.error(request, "Social login failed. Please try again or use the standard login.")
return redirect("/login")
return super().process_exception(request, exception)


class DojoSytemSettingsMiddleware:
_thread_local = local()

Expand Down
2 changes: 1 addition & 1 deletion dojo/settings/settings.dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ def generate_url(scheme, double_slashes, user, password, host, port, path, param
"django.middleware.clickjacking.XFrameOptionsMiddleware",
"dojo.middleware.LoginRequiredMiddleware",
"dojo.middleware.AdditionalHeaderMiddleware",
"social_django.middleware.SocialAuthExceptionMiddleware",
"dojo.middleware.CustomSocialAuthExceptionMiddleware",
"crum.CurrentRequestUserMiddleware",
"dojo.middleware.AuditlogMiddleware",
"dojo.middleware.AsyncSearchContextMiddleware",
Expand Down
122 changes: 122 additions & 0 deletions unittests/test_social_auth_failure_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from unittest.mock import patch

from django.contrib import messages
from django.contrib.auth.models import AnonymousUser
from django.contrib.messages.storage.fallback import FallbackStorage
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import HttpResponse
from django.test import RequestFactory, override_settings
from requests.exceptions import ConnectionError as RequestsConnectionError
from social_core.exceptions import AuthCanceled, AuthFailed

from dojo.middleware import CustomSocialAuthExceptionMiddleware

from .dojo_test_case import DojoTestCase


class TestSocialAuthMiddlewareUnit(DojoTestCase):

"""
Unit tests:
Directly test CustomSocialAuthExceptionMiddleware behavior
by simulating exceptions (ConnectionError, AuthCanceled, AuthFailed),
without relying on actual backend configuration or whether the
/complete/<backend>/ URLs are registered and accessible.
"""

def setUp(self):
self.factory = RequestFactory()
self.middleware = CustomSocialAuthExceptionMiddleware(lambda *_: HttpResponse("OK"))

def _prepare_request(self, path):
request = self.factory.get(path)
request.user = AnonymousUser()
SessionMiddleware(lambda *_: None).process_request(request)
request.session.save()
request._messages = FallbackStorage(request)
return request

def test_social_auth_exception_redirects_to_login(self):
login_paths = [
"/login/oidc/",
"/login/auth0/",
"/login/google-oauth2/",
"/login/okta-oauth2/",
"/login/azuread-tenant-oauth2/",
"/login/gitlab/",
"/login/keycloak-oauth2/",
"/login/github/",
]
exceptions = [
(RequestsConnectionError("Host unreachable"), "Login via social authentication is temporarily unavailable. Please use the standard login below."),
(AuthCanceled("User canceled login"), "Social login was canceled. Please try again or use the standard login."),
(AuthFailed("Token exchange failed"), "Social login failed. Please try again or use the standard login."),
]
for path in login_paths:
for exception, expected_message in exceptions:
with self.subTest(path=path, exception=type(exception).__name__):
request = self._prepare_request(path)
response = self.middleware.process_exception(request, exception)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/login")
storage = list(messages.get_messages(request))
self.assertTrue(any(expected_message in str(msg) for msg in storage))

def test_non_social_auth_path_still_redirects_on_auth_exception(self):
"""Ensure middleware handles AuthFailed even on unrelated paths."""
request = self._prepare_request("/some/other/path/")
exception = AuthFailed("Should be handled globally")
response = self.middleware.process_exception(request, exception)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/login")
storage = list(messages.get_messages(request))
self.assertTrue(any("Social login failed. Please try again or use the standard login." in str(msg) for msg in storage))


@override_settings(
AUTHENTICATION_BACKENDS=(
"social_core.backends.github.GithubOAuth2",
"social_core.backends.gitlab.GitLabOAuth2",
"social_core.backends.keycloak.KeycloakOAuth2",
"social_core.backends.azuread_tenant.AzureADTenantOAuth2",
"social_core.backends.auth0.Auth0OAuth2",
"social_core.backends.okta.OktaOAuth2",
"social_core.backends.open_id_connect.OpenIdConnectAuth",
"django.contrib.auth.backends.ModelBackend",
),
)
class TestSocialAuthIntegrationFailures(DojoTestCase):

"""
Integration tests:
Simulates social login failures by calling /complete/<backend>/ URLs
and mocking auth_complete() to raise AuthFailed and AuthCanceled.
Verifies that the middleware is correctly integrated and handles backend failures.
"""

BACKEND_CLASS_PATHS = {
"github": "social_core.backends.github.GithubOAuth2",
"gitlab": "social_core.backends.gitlab.GitLabOAuth2",
"keycloak": "social_core.backends.keycloak.KeycloakOAuth2",
"azuread-tenant-oauth2": "social_core.backends.azuread_tenant.AzureADTenantOAuth2",
"auth0": "social_core.backends.auth0.Auth0OAuth2",
"okta-oauth2": "social_core.backends.okta.OktaOAuth2",
"oidc": "social_core.backends.open_id_connect.OpenIdConnectAuth",
}

def _test_backend_exception(self, backend_slug, exception, expected_message):
backend_class_path = self.BACKEND_CLASS_PATHS[backend_slug]
with patch(f"{backend_class_path}.auth_complete", side_effect=exception):
response = self.client.get(f"/complete/{backend_slug}/", follow=True)
self.assertEqual(response.status_code, 200)
self.assertContains(response, expected_message)

def test_all_backends_auth_failed(self):
for backend in self.BACKEND_CLASS_PATHS:
with self.subTest(backend=backend):
self._test_backend_exception(backend, AuthFailed(backend=None), "Social login failed. Please try again or use the standard login.")

def test_all_backends_auth_canceled(self):
for backend in self.BACKEND_CLASS_PATHS:
with self.subTest(backend=backend):
self._test_backend_exception(backend, AuthCanceled(backend=None), "Social login was canceled. Please try again or use the standard login.")