11import pytest
2- from unittest .mock import MagicMock , patch
2+ from unittest .mock import MagicMock
33from .on_detected_attack import on_detected_attack
44from ...context import Context
5+ import aikido_zen .test_utils as test_utils
56
67
78@pytest .fixture
@@ -15,79 +16,76 @@ def mock_connection_manager():
1516 return connection_manager
1617
1718
18- @pytest .fixture
19- def mock_context ():
20- basic_wsgi_req = {
21- "REQUEST_METHOD" : "GET" ,
22- "HTTP_HEADER_1" : "header 1 value" ,
23- "HTTP_HEADER_2" : "Header 2 value" ,
24- "RANDOM_VALUE" : "Random value" ,
25- "HTTP_COOKIE" : "sessionId=abc123xyz456;" ,
26- "wsgi.url_scheme" : "http" ,
27- "HTTP_HOST" : "localhost:8080" ,
28- "PATH_INFO" : "/hello" ,
29- "QUERY_STRING" : "user=JohnDoe&age=30&age=35" ,
30- "CONTENT_TYPE" : "application/json" ,
31- "REMOTE_ADDR" : "198.51.100.23" ,
32- "HTTP_USER_AGENT" : "Mozilla/5.0" ,
33- }
34-
35- return Context (req = basic_wsgi_req , body = 123 , source = "django" )
36-
37-
38- def test_on_detected_attack_no_token (mock_context ):
19+ def test_on_detected_attack_no_token ():
3920 connection_manager = MagicMock ()
4021 connection_manager .token = None
41- on_detected_attack (connection_manager , {}, mock_context , blocked = False , stack = None )
22+
23+ on_detected_attack (
24+ connection_manager ,
25+ attack = {},
26+ context = test_utils .generate_context (),
27+ blocked = False ,
28+ stack = None ,
29+ )
30+
4231 connection_manager .api .report .assert_not_called ()
4332
4433
45- def test_on_detected_attack_with_long_payload (mock_connection_manager , mock_context ):
34+ def test_on_detected_attack_with_long_payload (mock_connection_manager ):
4635 long_payload = "x" * 5000 # Create a payload longer than 4096 characters
4736 attack = {
4837 "payload" : long_payload ,
4938 "metadata" : {"test" : "1" },
5039 }
5140
5241 on_detected_attack (
53- mock_connection_manager , attack , mock_context , blocked = False , stack = None
42+ mock_connection_manager ,
43+ attack = attack ,
44+ context = test_utils .generate_context (),
45+ blocked = False ,
46+ stack = None ,
5447 )
48+
5549 assert len (attack ["payload" ]) == 4096 # Ensure payload is truncated
5650 mock_connection_manager .api .report .assert_called_once ()
5751
5852
59- def test_on_detected_attack_with_long_metadata (mock_connection_manager , mock_context ):
53+ def test_on_detected_attack_with_long_metadata (mock_connection_manager ):
6054 long_metadata = "x" * 5000 # Create metadata longer than 4096 characters
6155 attack = {
6256 "payload" : {},
6357 "metadata" : {"test" : long_metadata },
6458 }
6559
6660 on_detected_attack (
67- mock_connection_manager , attack , mock_context , blocked = False , stack = None
61+ mock_connection_manager ,
62+ attack = attack ,
63+ context = test_utils .generate_context (),
64+ blocked = False ,
65+ stack = None ,
6866 )
6967
70- assert (
71- attack ["metadata" ]["test" ] == long_metadata [:4096 ]
72- ) # Ensure metadata is truncated
68+ assert attack ["metadata" ]["test" ] == long_metadata [:4096 ]
7369 mock_connection_manager .api .report .assert_called_once ()
7470
7571
76- def test_on_detected_attack_success (mock_connection_manager , mock_context ):
72+ def test_on_detected_attack_success (mock_connection_manager ):
7773 attack = {
7874 "payload" : {"key" : "value" },
7975 "metadata" : {},
8076 }
8177
8278 on_detected_attack (
83- mock_connection_manager , attack , mock_context , blocked = False , stack = None
79+ mock_connection_manager ,
80+ attack = attack ,
81+ context = test_utils .generate_context (),
82+ blocked = False ,
83+ stack = None ,
8484 )
8585 assert mock_connection_manager .api .report .call_count == 1
8686
8787
88- def test_on_detected_attack_exception_handling (
89- mock_connection_manager , mock_context , caplog
90- ):
88+ def test_on_detected_attack_exception_handling (mock_connection_manager , caplog ):
9189 attack = {
9290 "payload" : {"key" : "value" },
9391 "metadata" : {"key" : "value" },
@@ -97,15 +95,17 @@ def test_on_detected_attack_exception_handling(
9795 mock_connection_manager .api .report .side_effect = Exception ("API error" )
9896
9997 on_detected_attack (
100- mock_connection_manager , attack , mock_context , blocked = False , stack = None
98+ mock_connection_manager ,
99+ attack = attack ,
100+ context = test_utils .generate_context (),
101+ blocked = False ,
102+ stack = None ,
101103 )
102104
103105 assert "Failed to report an attack" in caplog .text
104106
105107
106- def test_on_detected_attack_with_blocked_and_stack (
107- mock_connection_manager , mock_context
108- ):
108+ def test_on_detected_attack_with_blocked_and_stack (mock_connection_manager ):
109109 attack = {
110110 "payload" : {"key" : "value" },
111111 "metadata" : {},
@@ -114,7 +114,11 @@ def test_on_detected_attack_with_blocked_and_stack(
114114 stack = "sample stack trace"
115115
116116 on_detected_attack (
117- mock_connection_manager , attack , mock_context , blocked = blocked , stack = stack
117+ mock_connection_manager ,
118+ attack = attack ,
119+ context = test_utils .generate_context (),
120+ blocked = blocked ,
121+ stack = stack ,
118122 )
119123
120124 # Check that the attack dictionary has the blocked and stack fields set
@@ -123,16 +127,24 @@ def test_on_detected_attack_with_blocked_and_stack(
123127 assert mock_connection_manager .api .report .call_count == 1
124128
125129
126- def test_on_detected_attack_request_data_and_attack_data (
127- mock_connection_manager , mock_context
128- ):
130+ def test_on_detected_attack_request_data_and_attack_data (mock_connection_manager ):
129131 attack = {
130132 "payload" : {"key" : "value" },
131133 "metadata" : {"test" : "true" },
132134 }
133135
134136 on_detected_attack (
135- mock_connection_manager , attack , mock_context , blocked = False , stack = None
137+ mock_connection_manager ,
138+ attack = attack ,
139+ context = test_utils .generate_context (
140+ method = "GET" ,
141+ url = "http://localhost:8080/hello" ,
142+ ip = "198.51.100.23" ,
143+ route = "/hello" ,
144+ headers = {"user-agent" : "Mozilla/5.0" },
145+ ),
146+ blocked = False ,
147+ stack = None ,
136148 )
137149
138150 # Extract the call arguments for the report method
@@ -146,7 +158,7 @@ def test_on_detected_attack_request_data_and_attack_data(
146158 assert request_data ["ipAddress" ] == "198.51.100.23"
147159 assert not "body" in request_data
148160 assert not "headers" in request_data
149- assert request_data ["source" ] == "django "
161+ assert request_data ["source" ] == "flask "
150162 assert request_data ["route" ] == "/hello"
151163 assert request_data ["userAgent" ] == "Mozilla/5.0"
152164
@@ -158,16 +170,18 @@ def test_on_detected_attack_request_data_and_attack_data(
158170 assert attack_data ["user" ] is None
159171
160172
161- def test_on_detected_attack_with_user (mock_connection_manager , mock_context ):
173+ def test_on_detected_attack_with_user (mock_connection_manager ):
162174 attack = {
163175 "payload" : {"key" : "value" },
164176 "metadata" : {},
165177 }
166- # Simulate a user in the context
167- mock_context .user = "test_user"
168178
169179 on_detected_attack (
170- mock_connection_manager , attack , mock_context , blocked = False , stack = None
180+ mock_connection_manager ,
181+ attack = attack ,
182+ context = test_utils .generate_context (user = "test_user" ),
183+ blocked = False ,
184+ stack = None ,
171185 )
172186
173187 # Extract the call arguments for the report method
0 commit comments