|
4 | 4 | from django.http import HttpRequest, JsonResponse |
5 | 5 | from django.views.decorators.csrf import csrf_exempt |
6 | 6 | from django.views.decorators.http import require_GET |
| 7 | +from rest_framework import status as drf_status |
| 8 | +from rest_framework.permissions import AllowAny |
| 9 | +from rest_framework.request import Request |
| 10 | +from rest_framework.response import Response |
| 11 | +from rest_framework.throttling import ScopedRateThrottle |
| 12 | +from rest_framework.views import APIView |
| 13 | + |
| 14 | +from oauth2_metadata.serializers import DCRRequestSerializer |
| 15 | +from oauth2_metadata.services import create_oauth2_application |
7 | 16 |
|
8 | 17 |
|
9 | 18 | @csrf_exempt |
@@ -35,3 +44,63 @@ def authorization_server_metadata(request: HttpRequest) -> JsonResponse: |
35 | 44 | } |
36 | 45 |
|
37 | 46 | return JsonResponse(metadata) |
| 47 | + |
| 48 | + |
| 49 | +class DynamicClientRegistrationView(APIView): |
| 50 | + """RFC 7591 Dynamic Client Registration endpoint.""" |
| 51 | + |
| 52 | + authentication_classes: list[type] = [] |
| 53 | + permission_classes = [AllowAny] |
| 54 | + throttle_classes = [ScopedRateThrottle] |
| 55 | + throttle_scope = "dcr_register" |
| 56 | + |
| 57 | + # Map DRF serializer field names to RFC 7591 error codes. |
| 58 | + _rfc7591_error_codes: dict[str, str] = { |
| 59 | + "redirect_uris": "invalid_redirect_uri", |
| 60 | + "client_name": "invalid_client_metadata", |
| 61 | + "grant_types": "invalid_client_metadata", |
| 62 | + "response_types": "invalid_client_metadata", |
| 63 | + "token_endpoint_auth_method": "invalid_client_metadata", |
| 64 | + } |
| 65 | + |
| 66 | + def post(self, request: Request) -> Response: |
| 67 | + serializer = DCRRequestSerializer(data=request.data) |
| 68 | + if not serializer.is_valid(): |
| 69 | + return self._rfc7591_error_response(serializer.errors) |
| 70 | + |
| 71 | + data = serializer.validated_data |
| 72 | + |
| 73 | + application = create_oauth2_application( |
| 74 | + client_name=data["client_name"], |
| 75 | + redirect_uris=data["redirect_uris"], |
| 76 | + ) |
| 77 | + |
| 78 | + return Response( |
| 79 | + { |
| 80 | + "client_id": application.client_id, |
| 81 | + "client_name": application.name, |
| 82 | + "redirect_uris": data["redirect_uris"], |
| 83 | + "grant_types": data["grant_types"], |
| 84 | + "response_types": data["response_types"], |
| 85 | + "token_endpoint_auth_method": data["token_endpoint_auth_method"], |
| 86 | + "client_id_issued_at": int(application.created.timestamp()), |
| 87 | + }, |
| 88 | + status=drf_status.HTTP_201_CREATED, |
| 89 | + ) |
| 90 | + |
| 91 | + def _rfc7591_error_response(self, errors: dict[str, list[str]]) -> Response: |
| 92 | + """Format validation errors per RFC 7591 section 3.2.2.""" |
| 93 | + first_field = next(iter(errors)) |
| 94 | + error_code = self._rfc7591_error_codes.get( |
| 95 | + first_field, "invalid_client_metadata" |
| 96 | + ) |
| 97 | + messages = errors[first_field] |
| 98 | + description = messages[0] if isinstance(messages[0], str) else str(messages[0]) |
| 99 | + |
| 100 | + return Response( |
| 101 | + { |
| 102 | + "error": error_code, |
| 103 | + "error_description": description, |
| 104 | + }, |
| 105 | + status=drf_status.HTTP_400_BAD_REQUEST, |
| 106 | + ) |
0 commit comments