Skip to content

Commit ce2ad90

Browse files
jsocolttys0dev
authored andcommitted
Add support for async functions to decorator
Tries to add support for async functions to the decorator, but trips over not having a failing test to fix.
1 parent afc0b4c commit ce2ad90

3 files changed

Lines changed: 71 additions & 8 deletions

File tree

.github/actions/test/action.yml

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,23 @@ runs:
1414
with:
1515
python-version: ${{ inputs.python-version }}
1616

17-
- name: Install dependencies
17+
- name: Update pip
1818
shell: sh
19-
run: |
20-
python -m pip install --upgrade pip
21-
if [[ ${{ inputs.django-version }} != 'main' ]]; then pip install --pre -q "Django>=${{ inputs.django-version }},<${{ inputs.django-version }}.99"; fi
22-
if [[ ${{ inputs.django-version }} == 'main' ]]; then pip install https://github.com/django/django/archive/main.tar.gz; fi
23-
pip install flake8 django-redis pymemcache
19+
run: python -m pip install --upgrade pip
20+
21+
- name: Install Django
22+
shell: sh
23+
run: python -m pip install "Django>=${{ inputs.django-version }},<${{ inputs.django-version }}.99"
24+
if: ${{ inputs.django-version != 'main' }}
25+
26+
- name: Install Django main
27+
shell: sh
28+
run: python -m pip install https://github.com/django/django/archive/main.tar.gz
29+
if: ${{ inputs.django-version == 'main' }}
30+
31+
- name: Install Django dependencies
32+
shell: sh
33+
run: pip install flake8 django-redis pymemcache
2434

2535
- name: Test
2636
shell: sh

django_ratelimit/decorators.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from functools import wraps
2+
import django
3+
if django.VERSION >= (4, 1):
4+
from asgiref.sync import iscoroutinefunction
5+
else:
6+
def iscoroutinefunction(func):
7+
return False
28

39
from django.conf import settings
410
from django.utils.module_loading import import_string
@@ -13,6 +19,23 @@
1319

1420
def ratelimit(group=None, key=None, rate=None, method=ALL, block=True):
1521
def decorator(fn):
22+
# if iscoroutinefunction(fn):
23+
# @wraps(fn)
24+
# async def _async_wrapped(request, *args, **kw):
25+
# old_limited = getattr(request, 'limited', False)
26+
# ratelimited = is_ratelimited(
27+
# request=request, group=group, fn=fn, key=key, rate=rate,
28+
# method=method, increment=True)
29+
# request.limited = ratelimited or old_limited
30+
# if ratelimited and block:
31+
# cls = getattr(
32+
# settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited)
33+
# if isinstance(cls, str):
34+
# cls = import_string(cls)
35+
# raise cls()
36+
# return await fn(request, *args, **kw)
37+
# return _async_wrapped
38+
1639
@wraps(fn)
1740
def _wrapped(request, *args, **kw):
1841
old_limited = getattr(request, 'limited', False)
@@ -23,7 +46,9 @@ def _wrapped(request, *args, **kw):
2346
if ratelimited and block:
2447
cls = getattr(
2548
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited)
26-
raise (import_string(cls) if isinstance(cls, str) else cls)()
49+
if isinstance(cls, str):
50+
cls = import_string(cls)
51+
raise cls()
2752
return fn(request, *args, **kw)
2853
return _wrapped
2954
return decorator

django_ratelimit/tests.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import asyncio
2+
3+
import django
14
from functools import partial
25

36
from django.core.cache import cache, InvalidCacheBackendError
@@ -12,7 +15,10 @@
1215
from django_ratelimit.core import (get_usage, is_ratelimited,
1316
_split_rate, _get_ip)
1417

15-
18+
if django.VERSION >= (4, 1):
19+
from asgiref.sync import iscoroutinefunction
20+
from django.test import AsyncRequestFactory
21+
arf = AsyncRequestFactory()
1622
rf = RequestFactory()
1723

1824

@@ -412,6 +418,28 @@ def view(request):
412418
assert not view(req)
413419

414420

421+
if django.VERSION >= (4, 1):
422+
class AsyncTests(TestCase):
423+
def setUp(self):
424+
cache.clear()
425+
426+
async def test_decorate_async_function(self):
427+
@ratelimit(key='ip', rate='1/m', block=False)
428+
async def view(request):
429+
await asyncio.sleep(0)
430+
return request.limited
431+
432+
req1 = arf.get('/')
433+
req1.META['REMOTE_ADDR'] = '1.2.3.4'
434+
435+
req2 = arf.get('/')
436+
req2.META['REMOTE_ADDR'] = '1.2.3.4'
437+
438+
assert iscoroutinefunction(view)
439+
assert await view(req1) is False
440+
assert await view(req2) is True
441+
442+
415443
class FunctionsTests(TestCase):
416444
def setUp(self):
417445
cache.clear()

0 commit comments

Comments
 (0)