Skip to content

Commit 32cdce1

Browse files
committed
Update quart patches to use new system
1 parent 9e0c8b8 commit 32cdce1

2 files changed

Lines changed: 79 additions & 82 deletions

File tree

aikido_zen/sinks/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,25 @@ def decorator(func, instance, args, kwargs):
5555
return decorator
5656

5757

58+
def before_async(wrapper):
59+
"""
60+
Surrounds an async patch with try-except and calls the original asynchronous function at the end
61+
"""
62+
63+
async def decorator(func, instance, args, kwargs):
64+
try:
65+
await wrapper(func, instance, args, kwargs) # Call the patch
66+
except AikidoException as e:
67+
raise e # Re-raise AikidoException
68+
except Exception as e:
69+
logger.debug(
70+
"%s:%s wrapping-before error: %s", func.__module__, func.__name__, e
71+
)
72+
return await func(*args, **kwargs) # Call the original function
73+
74+
return decorator
75+
76+
5877
def before_modify_return(wrapper):
5978
"""
6079
Surrounds a patch with try-except and calls the original function at the end unless a return value is present.

aikido_zen/sources/quart.py

Lines changed: 60 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,71 @@
1-
"""
2-
Quart source module, intercepts quart import and adds Aikido middleware
3-
"""
4-
5-
import copy
6-
import aikido_zen.importhook as importhook
7-
from aikido_zen.helpers.logging import logger
81
from aikido_zen.context import Context, get_current_context
9-
from aikido_zen.background_process.packages import is_package_compatible, ANY_VERSION
102
from .functions.request_handler import request_handler
3+
from ..helpers.get_argument import get_argument
4+
from ..sinks import on_import, patch_function, before, before_async
115

126

13-
async def aikido___call___wrapper(former_call, quart_app, scope, receive, send):
14-
"""Aikido's __call__ wrapper"""
15-
# We don't want to install werkzeug :
16-
# pylint: disable=import-outside-toplevel
17-
try:
18-
if scope["type"] != "http":
19-
return await former_call(quart_app, scope, receive, send)
20-
context1 = Context(req=scope, source="quart")
21-
context1.set_as_current_context()
7+
@before
8+
def _call(func, instance, args, kwargs):
9+
scope = get_argument(args, kwargs, 0, "scope")
10+
if not scope or scope.get("type") != "http":
11+
return
2212

23-
request_handler(stage="init")
24-
except Exception as e:
25-
logger.debug("Exception on aikido __call__ function : %s", e)
26-
return await former_call(quart_app, scope, receive, send)
13+
new_context = Context(req=scope, source="quart")
14+
new_context.set_as_current_context()
15+
request_handler(stage="init")
2716

2817

29-
async def handle_request_wrapper(former_handle_request, quart_app, req):
30-
"""
31-
https://github.com/pallets/quart/blob/2fc6d4fa6e3df017e8eef1411ec80b5a6dce25a5/src/quart/app.py#L1400
32-
Wraps the handle_request function
33-
"""
34-
# At this stage no middleware is called yet, running pre_response is
35-
# not what we need to do now, but we can store the body inside context :
36-
try:
37-
context = get_current_context()
38-
if context:
39-
form = await req.form
40-
if req.is_json:
41-
context.set_body(await req.get_json())
42-
elif form:
43-
context.set_body(form)
44-
else:
45-
data = await req.data
46-
context.set_body(data.decode("utf-8"))
47-
context.cookies = req.cookies.to_dict()
48-
context.set_as_current_context()
49-
except Exception as e:
50-
logger.debug("Exception in handle_request : %s", e)
51-
52-
# Fetch response and run post_response handler :
18+
@before_async
19+
async def _handle_request_before(func, instance, args, kwargs):
20+
context = get_current_context()
21+
if not context:
22+
return
23+
24+
request = get_argument(args, kwargs, 0, "request")
25+
if not request:
26+
return
27+
28+
form = await request.form
29+
if request.is_json:
30+
context.set_body(await request.get_json())
31+
elif form:
32+
context.set_body(form)
33+
else:
34+
data = await request.data
35+
context.set_body(data.decode("utf-8"))
36+
context.cookies = request.cookies.to_dict()
37+
context.set_as_current_context()
38+
39+
40+
async def _handle_request_after(func, instance, args, kwargs):
5341
# pylint:disable=import-outside-toplevel # We don't want to install this by default
5442
from werkzeug.exceptions import HTTPException
5543

5644
try:
57-
response = await former_handle_request(quart_app, req)
58-
status_code = response.status_code
59-
request_handler(stage="post_response", status_code=status_code)
45+
response = await func(*args, **kwargs)
46+
if hasattr(response, "status_code"):
47+
request_handler(stage="post_response", status_code=response.status_code)
6048
return response
6149
except HTTPException as e:
6250
request_handler(stage="post_response", status_code=e.code)
6351
raise e
6452

6553

54+
async def _asgi_app(func, instance, args, kwargs):
55+
scope = get_argument(args, kwargs, 0, "scope")
56+
if not scope or scope.get("type") != "http":
57+
return await func(*args, **kwargs)
58+
send = get_argument(args, kwargs, 2, "send")
59+
if not send:
60+
return await func(*args, **kwargs)
61+
62+
pre_response = request_handler(stage="pre_response")
63+
if pre_response:
64+
return await send_status_code_and_text(send, pre_response)
65+
return await func(*args, **kwargs)
66+
67+
6668
async def send_status_code_and_text(send, pre_response):
67-
"""Sends a status code and text"""
6869
await send(
6970
{
7071
"type": "http.response.start",
@@ -81,38 +82,15 @@ async def send_status_code_and_text(send, pre_response):
8182
)
8283

8384

84-
@importhook.on_import("quart.app")
85-
def on_quart_import(quart):
85+
@on_import("quart.app", "quart")
86+
def patch(m):
8687
"""
87-
Hook 'n wrap on `quart.app`
88-
Our goal is to wrap the __call__, handle_request, asgi_app functios of the "Quart" class
88+
patching module quart.app
89+
- patches Quart.__call__ (creates Context)
90+
- patches Quart.handle_request (Stores body/cookies, checks status code)
91+
- patches Quart.asgi_app (Pre-response: puts in messages when request is blocked)
8992
"""
90-
if not is_package_compatible("quart", required_version=ANY_VERSION):
91-
return quart
92-
modified_quart = importhook.copy_module(quart)
93-
94-
former_handle_request = copy.deepcopy(quart.Quart.handle_request)
95-
former_asgi_app = copy.deepcopy(quart.Quart.asgi_app)
96-
former_call = copy.deepcopy(quart.Quart.__call__)
97-
98-
async def aikido___call__(quart_app, scope, receive=None, send=None):
99-
return await aikido___call___wrapper(
100-
former_call, quart_app, scope, receive, send
101-
)
102-
103-
async def aikido_handle_request(quart_app, request):
104-
return await handle_request_wrapper(former_handle_request, quart_app, request)
105-
106-
async def aikido_asgi_app(quart_app, scope, receive=None, send=None):
107-
if scope["type"] == "http":
108-
# Run pre_response code :
109-
pre_response = request_handler(stage="pre_response")
110-
if pre_response:
111-
return await send_status_code_and_text(send, pre_response)
112-
return await former_asgi_app(quart_app, scope, receive, send)
113-
114-
# pylint:disable=no-member # Pylint has issues with the wrapping
115-
setattr(modified_quart.Quart, "__call__", aikido___call__)
116-
setattr(modified_quart.Quart, "handle_request", aikido_handle_request)
117-
setattr(modified_quart.Quart, "asgi_app", aikido_asgi_app)
118-
return modified_quart
93+
patch_function(m, "Quart.__call__", _call)
94+
patch_function(m, "Quart.handle_request", _handle_request_before)
95+
patch_function(m, "Quart.handle_request", _handle_request_after)
96+
patch_function(m, "Quart.asgi_app", _asgi_app)

0 commit comments

Comments
 (0)