Skip to content

Commit f4e5a86

Browse files
committed
added RESTView
1 parent 985dd73 commit f4e5a86

2 files changed

Lines changed: 378 additions & 0 deletions

File tree

rest_framework/rest_views.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
from django import VERSION as DJANGO_VERSION
2+
from django.db import models
3+
from django.urls import path
4+
from django.utils.decorators import classonlymethod
5+
from django.views.decorators.csrf import csrf_exempt
6+
7+
from rest_framework.views import APIView
8+
9+
10+
class RESTViewMethod:
11+
12+
def __init__(
13+
self,
14+
http_method: str,
15+
path: str,
16+
url_name: str,
17+
view_method
18+
):
19+
self.http_method = http_method
20+
self.path = path
21+
self.url_name = url_name
22+
self.view_method = view_method
23+
24+
25+
def get(path: str, url_name: str):
26+
def decorator(view_method):
27+
return RESTViewMethod(
28+
http_method='get',
29+
path=path,
30+
url_name=url_name,
31+
view_method=view_method,
32+
)
33+
34+
return decorator
35+
36+
37+
def post(path: str, url_name: str):
38+
def decorator(view_method):
39+
return RESTViewMethod(
40+
http_method='post',
41+
path=path,
42+
url_name=url_name,
43+
view_method=view_method,
44+
)
45+
46+
return decorator
47+
48+
49+
def put(path: str, url_name: str):
50+
def decorator(view_method):
51+
return RESTViewMethod(
52+
http_method='put',
53+
path=path,
54+
url_name=url_name,
55+
view_method=view_method,
56+
)
57+
58+
return decorator
59+
60+
61+
def patch(path: str, url_name: str):
62+
def decorator(view_method):
63+
return RESTViewMethod(
64+
http_method='patch',
65+
path=path,
66+
url_name=url_name,
67+
view_method=view_method,
68+
)
69+
70+
return decorator
71+
72+
73+
def delete(path: str, url_name: str):
74+
def decorator(view_method):
75+
return RESTViewMethod(
76+
http_method='delete',
77+
path=path,
78+
url_name=url_name,
79+
view_method=view_method,
80+
)
81+
82+
return decorator
83+
84+
85+
class RESTViewMetaclass(type):
86+
87+
def __new__(cls, name, bases, attrs):
88+
_all_actions: dict[str, list[tuple[str, str, str]]] = {}
89+
http_method_path_pairs = set()
90+
url_names_by_path = {}
91+
92+
for key, value in attrs.items():
93+
if isinstance(value, RESTViewMethod):
94+
if (value.http_method, value.path) in http_method_path_pairs:
95+
raise ValueError(f"{cls.__name__} has multiple methods with the same HTTP method and path")
96+
97+
http_method_path_pairs.add((value.http_method, value.path))
98+
99+
url_names_by_path.setdefault(value.path, set()).add(value.url_name)
100+
if len(url_names_by_path[value.path]) > 1:
101+
raise ValueError(
102+
f"{cls.__name__} has multiple methods with the same path {value.path}, but different URL names"
103+
)
104+
105+
http_method_path_pairs.add((value.http_method, value.path))
106+
_all_actions.setdefault(value.path, [])
107+
_all_actions[value.path].append((value.http_method, value.view_method.__name__, value.url_name))
108+
attrs[key] = value.view_method
109+
110+
attrs['_all_actions'] = _all_actions
111+
return type.__new__(cls, name, bases, attrs)
112+
113+
114+
class RESTView(APIView, metaclass=RESTViewMetaclass):
115+
"""
116+
A View that allows handling any HTTP methods and URL paths. Use special decorators to specify URl path
117+
and URl name for handlers. These decorators moved to a class attribute at runtime.
118+
119+
Example:
120+
class UserAPI(RESTView):
121+
122+
@get(path='/v1/users/', url_name='users')
123+
def list(self, request):
124+
...
125+
126+
@get(path='/v1/users/<int:user_id>/', url_name='user_detail')
127+
def retrieve(self, request, user_id: int):
128+
...
129+
130+
@post(path='/v1/users/', url_name='users')
131+
def create(self, request):
132+
...
133+
134+
@patch(path='/v1/users/<int:user_id>/change_password/', url_name='user_change_password')
135+
def change_password(self, request, user_id: int):
136+
...
137+
138+
To use this View, you have to comply with these rules:
139+
1. Use special decorators for all handlers
140+
2. All identical URL paths must have identical URL names
141+
3. Special decorators have to be the last in order
142+
4. All custom decorators have to be wrapped with functools.wraps or manually copy the docstrings
143+
"""
144+
145+
@classonlymethod
146+
def unwrap_url_patterns(cls, **initkwargs):
147+
"""
148+
Create classes for all URL paths for Django urlpatterns interface
149+
150+
Example:
151+
urlpatterns = [
152+
*UserAPI.unwrap_url_patterns(),
153+
]
154+
"""
155+
urlpatterns = []
156+
for url_path, attrs in cls._all_actions.items():
157+
view = cls.as_view(url_path=url_path, **initkwargs)
158+
urlpatterns.append(path(url_path, view, name=attrs[0][2]))
159+
160+
return urlpatterns
161+
162+
@classmethod
163+
def as_view(cls, url_path: str, **initkwargs):
164+
"""
165+
Store the generated class on the view function for URL path. Don't use this method.
166+
"""
167+
if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
168+
def force_evaluation():
169+
raise RuntimeError(
170+
'Do not evaluate the `.queryset` attribute directly, '
171+
'as the result will be cached and reused between requests. '
172+
'Use `.all()` or call `.get_queryset()` instead.'
173+
)
174+
cls.queryset._fetch_all = force_evaluation
175+
176+
fork_cls = type(cls.__name__, cls.__bases__, dict(cls.__dict__))
177+
actions = {}
178+
for http_method, view_method_name, url_name in cls._all_actions[url_path]:
179+
actions[http_method] = view_method_name
180+
181+
if 'get' in actions and 'head' not in actions:
182+
actions['head'] = actions['get']
183+
184+
if 'options' not in actions:
185+
# use ApiView.options
186+
actions['options'] = 'options'
187+
188+
fork_cls.actions = actions
189+
view = super(APIView, fork_cls).as_view(**initkwargs)
190+
view.cls = fork_cls
191+
view.initkwargs = initkwargs
192+
193+
# Exempt all DRF views from Django's LoginRequiredMiddleware. Users should set
194+
# DEFAULT_PERMISSION_CLASSES to 'rest_framework.permissions.IsAuthenticated' instead
195+
if DJANGO_VERSION >= (5, 1):
196+
view.login_required = False
197+
198+
# Note: session based authentication is explicitly CSRF validated,
199+
# all other authentication is CSRF exempt.
200+
return csrf_exempt(view)
201+
202+
def dispatch(self, request, *args, **kwargs):
203+
"""
204+
`.dispatch()` is pretty much the same as ApiView dispatch
205+
"""
206+
self.args = args
207+
self.kwargs = kwargs
208+
request = self.initialize_request(request, *args, **kwargs)
209+
self.request = request
210+
self.headers = self.default_response_headers # deprecate?
211+
212+
try:
213+
self.initial(request, *args, **kwargs)
214+
215+
view_method_name = self.actions.get(request.method.lower())
216+
handler = (
217+
getattr(self, view_method_name, self.http_method_not_allowed)
218+
if view_method_name
219+
else self.http_method_not_allowed
220+
)
221+
response = handler(request, *args, **kwargs)
222+
223+
except Exception as exc:
224+
response = self.handle_exception(exc)
225+
226+
self.response = self.finalize_response(request, response, *args, **kwargs)
227+
return self.response
228+
229+
230+
__all__ = ['get', 'post', 'put', 'patch', 'delete', 'RESTView']

tests/test_rest_views.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
from functools import wraps
2+
3+
from django.test import TestCase, override_settings
4+
5+
from rest_framework import status
6+
from rest_framework.response import Response
7+
from rest_framework.rest_views import RESTView, get, post, patch
8+
from rest_framework.settings import api_settings
9+
10+
11+
class BasicRESTView(RESTView):
12+
13+
@get(path='instances/', url_name='instances_list')
14+
def list(self, request):
15+
return Response(status=200, data={'http_method': 'GET', 'view_method': 'list'})
16+
17+
@get(path='instances/<int:instance_id>/', url_name='detail_instance')
18+
def retrieve(self, request, instance_id: int):
19+
return Response(status=200, data={'http_method': 'GET', 'view_method': 'retrieve'})
20+
21+
@post(path='instances/', url_name='instances_list')
22+
def create(self, request):
23+
return Response(status=201, data={'http_method': 'POST', 'view_method': 'create'})
24+
25+
@patch(path='instances/<int:instance_id>/', url_name='detail_instance')
26+
def update(self, request, instance_id: int):
27+
return Response(status=200, data={'http_method': 'PATCH', 'view_method': 'update'})
28+
29+
@patch(path='instances/<int:instance_id>/change/', url_name='detail_instance_change')
30+
def change_status(self, request, instance_id: int):
31+
return Response(status=200, data={'http_method': 'PATCH', 'view_method': 'change_status'})
32+
33+
34+
class ErrorRESTView(RESTView):
35+
36+
@get(path='errors/', url_name='errors_list')
37+
def error_method(self, request):
38+
raise Exception
39+
40+
41+
def custom_decorator(view_method):
42+
@wraps(view_method)
43+
def wrapper(*args, **kwargs):
44+
response = view_method(*args, **kwargs)
45+
response.data['has_decorator'] = True
46+
return response
47+
48+
return wrapper
49+
50+
51+
class RESTViewWithCustomDecorators(RESTView):
52+
53+
@get(path='decorators/', url_name='errors_list')
54+
@custom_decorator
55+
def custom_decorator(self, request):
56+
return Response(status=200, data={'http_method': 'GET', 'view_method': 'custom_decorator'})
57+
58+
59+
urlpatterns = [
60+
*BasicRESTView.unwrap_url_patterns(),
61+
*ErrorRESTView.unwrap_url_patterns(),
62+
*RESTViewWithCustomDecorators.unwrap_url_patterns(),
63+
]
64+
65+
66+
class TestInitializeRESTView(TestCase):
67+
68+
@staticmethod
69+
def test_initialize_rest_view():
70+
assert BasicRESTView._all_actions == {
71+
'instances/': [
72+
('get', 'list', 'instances_list'),
73+
('post', 'create', 'instances_list'),
74+
],
75+
'instances/<int:instance_id>/': [
76+
('get', 'retrieve', 'detail_instance'),
77+
('patch', 'update', 'detail_instance'),
78+
],
79+
'instances/<int:instance_id>/change/': [('patch', 'change_status', 'detail_instance_change')],
80+
}
81+
assert not hasattr(BasicRESTView, 'actions')
82+
83+
84+
class TestRESTViewUnwrap(TestCase):
85+
86+
@staticmethod
87+
def test_unwrap_url_patterns():
88+
urlpatterns = BasicRESTView.unwrap_url_patterns()
89+
assert len(urlpatterns) == 3
90+
for pattern in urlpatterns:
91+
assert pattern.callback.cls is not BasicRESTView
92+
93+
94+
@override_settings(ROOT_URLCONF='tests.test_rest_views')
95+
class RESTViewIntegrationTests(TestCase):
96+
97+
def test_successful_get_request(self):
98+
response = self.client.get(path='/instances/')
99+
assert response.status_code == status.HTTP_200_OK
100+
assert response.data == {'http_method': 'GET', 'view_method': 'list'}
101+
102+
def test_successful_get_request_with_path_param(self):
103+
response = self.client.get(path='/instances/1/')
104+
assert response.status_code == status.HTTP_200_OK
105+
assert response.data == {'http_method': 'GET', 'view_method': 'retrieve'}
106+
107+
def test_successful_post_request(self):
108+
response = self.client.post(path='/instances/')
109+
assert response.status_code == status.HTTP_201_CREATED
110+
assert response.data == {'http_method': 'POST', 'view_method': 'create'}
111+
112+
def test_successful_head_request(self):
113+
response = self.client.head(path='/instances/')
114+
assert response.status_code == status.HTTP_200_OK
115+
assert response.data == {'http_method': 'GET', 'view_method': 'list'}
116+
117+
def test_successful_options_request(self):
118+
response = self.client.options(path='/instances/')
119+
assert response.status_code == status.HTTP_200_OK
120+
121+
def test_method_not_allowed(self):
122+
response = self.client.put(path='/instances/')
123+
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
124+
125+
def test_method_with_custom_decorator(self):
126+
response = self.client.get(path='/decorators/')
127+
assert response.status_code == status.HTTP_200_OK
128+
assert response.data == {'http_method': 'GET', 'view_method': 'custom_decorator', 'has_decorator': True}
129+
130+
131+
@override_settings(ROOT_URLCONF='tests.test_rest_views')
132+
class TestCustomExceptionHandler(TestCase):
133+
134+
def setUp(self):
135+
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
136+
137+
def exception_handler(exc, request):
138+
return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
139+
140+
api_settings.EXCEPTION_HANDLER = exception_handler
141+
142+
def tearDown(self):
143+
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
144+
145+
def test_class_based_view_exception_handler(self):
146+
response = self.client.get(path='/errors/')
147+
assert response.status_code == status.HTTP_400_BAD_REQUEST
148+
assert response.data == 'Error!'

0 commit comments

Comments
 (0)