Skip to content

Commit 0b9a51b

Browse files
committed
Fixed merge conflicts
2 parents 96bd38e + 6de9f41 commit 0b9a51b

3 files changed

Lines changed: 164 additions & 11 deletions

File tree

Tests/test_csrf_json.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import asyncio
2+
import httpx
3+
import subprocess
4+
import sys
5+
import time
6+
import os
7+
8+
# Construct absolute path to the test application directory
9+
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
10+
TEST_APP_DIR = os.path.join(TESTS_DIR, "test")
11+
12+
# Ensure the test application is in the python path
13+
sys.path.insert(0, TEST_APP_DIR)
14+
15+
BASE_URL = "http://127.0.0.1:8000"
16+
17+
async def run_csrf_test():
18+
"""
19+
Tests that CSRF protection works correctly for various request types.
20+
"""
21+
print("--- Starting CSRF Logic Test ---")
22+
async with httpx.AsyncClient(base_url=BASE_URL) as client:
23+
try:
24+
# 1. Make a GET request to a page to get a CSRF token from the cookie
25+
print("Step 1: Getting CSRF token from homepage...")
26+
get_response = await client.get("/")
27+
get_response.raise_for_status()
28+
assert "csrf_token" in client.cookies, "CSRF token not found in cookie"
29+
csrf_token = client.cookies["csrf_token"]
30+
print(f" [PASS] CSRF token received: {csrf_token[:10]}...")
31+
32+
# 2. Test POST without any CSRF token (should fail)
33+
print("\nStep 2: Testing POST to /api/test without CSRF token (expecting 403)...")
34+
fail_response = await client.post("/api/test", json={"message": "hello"})
35+
assert fail_response.status_code == 403, f"Expected status 403, but got {fail_response.status_code}"
36+
assert "CSRF token missing or invalid" in fail_response.text
37+
print(" [PASS] Request was correctly forbidden.")
38+
39+
# 3. Test POST with CSRF token in JSON body (should pass)
40+
print("\nStep 3: Testing POST to /api/test with CSRF token in JSON body (expecting 200)...")
41+
payload_with_token = {"message": "hello", "csrf_token": csrf_token}
42+
success_response_body = await client.post("/api/test", json=payload_with_token)
43+
assert success_response_body.status_code == 200, f"Expected status 200, but got {success_response_body.status_code}"
44+
assert success_response_body.json()["message"] == "hello"
45+
print(" [PASS] Request with token in body was successful.")
46+
47+
# 4. Test POST with CSRF token in header (should pass)
48+
print("\nStep 4: Testing POST to /api/test with CSRF token in header (expecting 200)...")
49+
headers = {"X-CSRF-Token": csrf_token}
50+
success_response_header = await client.post("/api/test", json={"message": "world"}, headers=headers)
51+
assert success_response_header.status_code == 200, f"Expected status 200, but got {success_response_header.status_code}"
52+
assert success_response_header.json()["message"] == "world"
53+
print(" [PASS] Request with token in header was successful.")
54+
55+
# 5. Test empty-body POST with CSRF token in header (should pass validation, then redirect)
56+
print("\nStep 5: Testing empty-body POST to /logout with CSRF token in header (expecting 302)...")
57+
# Note: The /logout endpoint redirects after success, so we expect a 302
58+
# We disable auto-redirects to verify the 302 status directly
59+
empty_body_response = await client.post("/logout", headers=headers, follow_redirects=False)
60+
61+
# If we got a 403, the CSRF check failed. If we got a 302, it passed!
62+
assert empty_body_response.status_code == 302, f"Expected status 302 (Redirect), but got {empty_body_response.status_code}. (403 means CSRF failed)"
63+
print(" [PASS] Empty-body request passed CSRF check and redirected.")
64+
65+
except Exception as e:
66+
print(f"\n--- TEST FAILED ---")
67+
print(f"An error occurred: {e}")
68+
import traceback
69+
traceback.print_exc()
70+
return False
71+
72+
print("\n--- ALL CSRF TESTS PASSED ---")
73+
return True
74+
75+
76+
def main():
77+
print("Starting test server...")
78+
server_process = subprocess.Popen(
79+
[sys.executable, "-m", "uvicorn", "app:app"],
80+
cwd=TEST_APP_DIR,
81+
stdout=subprocess.PIPE,
82+
stderr=subprocess.PIPE,
83+
text=True, # Decode stdout/stderr as text
84+
)
85+
86+
# Give the server more time to start up
87+
print("Waiting 5 seconds for server to start...")
88+
time.sleep(5)
89+
90+
# Check if the server process has terminated unexpectedly
91+
if server_process.poll() is not None:
92+
print("\n--- SERVER FAILED TO START ---")
93+
stdout, stderr = server_process.communicate()
94+
print("STDOUT:")
95+
print(stdout)
96+
print("\nSTDERR:")
97+
print(stderr)
98+
sys.exit(1)
99+
100+
print("Server seems to be running. Starting tests.")
101+
test_passed = False
102+
try:
103+
test_passed = asyncio.run(run_csrf_test())
104+
finally:
105+
print("\nStopping test server...")
106+
server_process.terminate()
107+
# Get remaining output
108+
try:
109+
stdout, stderr = server_process.communicate(timeout=5)
110+
print("\n--- Server Output ---")
111+
print("STDOUT:")
112+
print(stdout)
113+
print("\nSTDERR:")
114+
print(stderr)
115+
except subprocess.TimeoutExpired:
116+
print("Server did not terminate gracefully.")
117+
118+
if not test_passed:
119+
print("\nExiting with status 1 due to test failure.")
120+
sys.exit(1)
121+
122+
123+
if __name__ == "__main__":
124+
main()

jsweb/middleware.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from .static import serve_static
44
from .response import Forbidden
5+
import json
56

67
logger = logging.getLogger(__name__)
78

@@ -30,8 +31,10 @@ class CSRFMiddleware(Middleware):
3031
"""
3132
Middleware to protect against Cross-Site Request Forgery (CSRF) attacks.
3233
33-
This middleware checks for a valid CSRF token in POST, PUT, PATCH, and DELETE
34-
requests. It compares a token from the form data against a token stored in a cookie.
34+
This middleware enforces CSRF protection for all state-changing HTTP methods
35+
(POST, PUT, PATCH, DELETE). It requires a valid CSRF token to be present
36+
in the request, either in the 'X-CSRF-Token' header or in the request body
37+
(JSON or Form Data).
3538
"""
3639
async def __call__(self, scope, receive, send):
3740
"""
@@ -52,18 +55,44 @@ async def __call__(self, scope, receive, send):
5255
req = scope['jsweb.request']
5356

5457
if req.method in ("POST", "PUT", "PATCH", "DELETE"):
55-
form = await req.form()
56-
form_token = form.get("csrf_token")
5758
cookie_token = req.cookies.get("csrf_token")
58-
59-
if not form_token or not cookie_token or not secrets.compare_digest(form_token, cookie_token):
59+
submitted_token = None
60+
61+
# 1. Check header first (Best practice for AJAX/APIs)
62+
submitted_token = req.headers.get("x-csrf-token")
63+
64+
# 2. If no header token, check the body based on content type
65+
if not submitted_token:
66+
content_type = req.headers.get("content-type", "")
67+
68+
if "application/json" in content_type:
69+
try:
70+
# Request.json() safely returns {} for empty/invalid bodies
71+
data = await req.json()
72+
submitted_token = data.get("csrf_token")
73+
except Exception:
74+
# If JSON parsing fails, we treat it as no token found
75+
pass
76+
77+
elif "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
78+
try:
79+
# Request.form() safely returns {} for empty/non-form bodies
80+
form = await req.form()
81+
submitted_token = form.get("csrf_token")
82+
except Exception:
83+
# If form parsing fails, we treat it as no token found
84+
pass
85+
86+
# 3. Perform the validation
87+
# Both the cookie token and the submitted token MUST be present and match.
88+
if not cookie_token or not submitted_token or not secrets.compare_digest(submitted_token, cookie_token):
6089
# Log CSRF failure with context (but never log the actual tokens)
6190
client_ip = scope.get("client", ["unknown"])[0]
62-
logger.error(
91+
logger.warning(
6392
f"CSRF validation failed - Method: {req.method}, "
6493
f"Path: {req.path}, Client IP: {client_ip}, "
65-
f"Form token present: {bool(form_token)}, "
66-
f"Cookie token present: {bool(cookie_token)}"
94+
f"Cookie set: {'Yes' if cookie_token else 'No'}, "
95+
f"Token submitted: {'Yes' if submitted_token else 'No'}."
6796
)
6897
response = Forbidden("CSRF token missing or invalid.")
6998
await response(scope, receive, send)

jsweb/request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, scope, receive, app):
3737
self.path = self.scope.get("path", "/")
3838
self.query = self._parse_query(self.scope.get("query_string", b"").decode())
3939
self.headers = self._parse_headers(self.scope.get("headers", []))
40+
self.content_type = self.headers.get("content-type", "")
4041
self.cookies = self._parse_cookies(self.headers)
4142
self.user = None
4243

@@ -118,8 +119,7 @@ async def json(self):
118119
dict: The parsed JSON data.
119120
"""
120121
if self._json is None:
121-
content_type = self.headers.get("content-type", "")
122-
if "application/json" in content_type:
122+
if "application/json" in self.content_type:
123123
try:
124124
body_bytes = await self.body()
125125
self._json = json.loads(body_bytes) if body_bytes else {}

0 commit comments

Comments
 (0)