Skip to content

Commit bbe8275

Browse files
committed
Cleanup on_detected_attack_test test cases
1 parent 67e0638 commit bbe8275

2 files changed

Lines changed: 78 additions & 50 deletions

File tree

aikido_zen/background_process/cloud_connection_manager/on_detected_attack_test.py

Lines changed: 63 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
2-
from unittest.mock import MagicMock, patch
2+
from unittest.mock import MagicMock
33
from .on_detected_attack import on_detected_attack
44
from ...context import Context
5+
import aikido_zen.test_utils as test_utils
56

67

78
@pytest.fixture
@@ -15,79 +16,76 @@ def mock_connection_manager():
1516
return connection_manager
1617

1718

18-
@pytest.fixture
19-
def mock_context():
20-
basic_wsgi_req = {
21-
"REQUEST_METHOD": "GET",
22-
"HTTP_HEADER_1": "header 1 value",
23-
"HTTP_HEADER_2": "Header 2 value",
24-
"RANDOM_VALUE": "Random value",
25-
"HTTP_COOKIE": "sessionId=abc123xyz456;",
26-
"wsgi.url_scheme": "http",
27-
"HTTP_HOST": "localhost:8080",
28-
"PATH_INFO": "/hello",
29-
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
30-
"CONTENT_TYPE": "application/json",
31-
"REMOTE_ADDR": "198.51.100.23",
32-
"HTTP_USER_AGENT": "Mozilla/5.0",
33-
}
34-
35-
return Context(req=basic_wsgi_req, body=123, source="django")
36-
37-
38-
def test_on_detected_attack_no_token(mock_context):
19+
def test_on_detected_attack_no_token():
3920
connection_manager = MagicMock()
4021
connection_manager.token = None
41-
on_detected_attack(connection_manager, {}, mock_context, blocked=False, stack=None)
22+
23+
on_detected_attack(
24+
connection_manager,
25+
attack={},
26+
context=test_utils.generate_context(),
27+
blocked=False,
28+
stack=None,
29+
)
30+
4231
connection_manager.api.report.assert_not_called()
4332

4433

45-
def test_on_detected_attack_with_long_payload(mock_connection_manager, mock_context):
34+
def test_on_detected_attack_with_long_payload(mock_connection_manager):
4635
long_payload = "x" * 5000 # Create a payload longer than 4096 characters
4736
attack = {
4837
"payload": long_payload,
4938
"metadata": {"test": "1"},
5039
}
5140

5241
on_detected_attack(
53-
mock_connection_manager, attack, mock_context, blocked=False, stack=None
42+
mock_connection_manager,
43+
attack=attack,
44+
context=test_utils.generate_context(),
45+
blocked=False,
46+
stack=None,
5447
)
48+
5549
assert len(attack["payload"]) == 4096 # Ensure payload is truncated
5650
mock_connection_manager.api.report.assert_called_once()
5751

5852

59-
def test_on_detected_attack_with_long_metadata(mock_connection_manager, mock_context):
53+
def test_on_detected_attack_with_long_metadata(mock_connection_manager):
6054
long_metadata = "x" * 5000 # Create metadata longer than 4096 characters
6155
attack = {
6256
"payload": {},
6357
"metadata": {"test": long_metadata},
6458
}
6559

6660
on_detected_attack(
67-
mock_connection_manager, attack, mock_context, blocked=False, stack=None
61+
mock_connection_manager,
62+
attack=attack,
63+
context=test_utils.generate_context(),
64+
blocked=False,
65+
stack=None,
6866
)
6967

70-
assert (
71-
attack["metadata"]["test"] == long_metadata[:4096]
72-
) # Ensure metadata is truncated
68+
assert attack["metadata"]["test"] == long_metadata[:4096]
7369
mock_connection_manager.api.report.assert_called_once()
7470

7571

76-
def test_on_detected_attack_success(mock_connection_manager, mock_context):
72+
def test_on_detected_attack_success(mock_connection_manager):
7773
attack = {
7874
"payload": {"key": "value"},
7975
"metadata": {},
8076
}
8177

8278
on_detected_attack(
83-
mock_connection_manager, attack, mock_context, blocked=False, stack=None
79+
mock_connection_manager,
80+
attack=attack,
81+
context=test_utils.generate_context(),
82+
blocked=False,
83+
stack=None,
8484
)
8585
assert mock_connection_manager.api.report.call_count == 1
8686

8787

88-
def test_on_detected_attack_exception_handling(
89-
mock_connection_manager, mock_context, caplog
90-
):
88+
def test_on_detected_attack_exception_handling(mock_connection_manager, caplog):
9189
attack = {
9290
"payload": {"key": "value"},
9391
"metadata": {"key": "value"},
@@ -97,15 +95,17 @@ def test_on_detected_attack_exception_handling(
9795
mock_connection_manager.api.report.side_effect = Exception("API error")
9896

9997
on_detected_attack(
100-
mock_connection_manager, attack, mock_context, blocked=False, stack=None
98+
mock_connection_manager,
99+
attack=attack,
100+
context=test_utils.generate_context(),
101+
blocked=False,
102+
stack=None,
101103
)
102104

103105
assert "Failed to report an attack" in caplog.text
104106

105107

106-
def test_on_detected_attack_with_blocked_and_stack(
107-
mock_connection_manager, mock_context
108-
):
108+
def test_on_detected_attack_with_blocked_and_stack(mock_connection_manager):
109109
attack = {
110110
"payload": {"key": "value"},
111111
"metadata": {},
@@ -114,7 +114,11 @@ def test_on_detected_attack_with_blocked_and_stack(
114114
stack = "sample stack trace"
115115

116116
on_detected_attack(
117-
mock_connection_manager, attack, mock_context, blocked=blocked, stack=stack
117+
mock_connection_manager,
118+
attack=attack,
119+
context=test_utils.generate_context(),
120+
blocked=blocked,
121+
stack=stack,
118122
)
119123

120124
# Check that the attack dictionary has the blocked and stack fields set
@@ -123,16 +127,24 @@ def test_on_detected_attack_with_blocked_and_stack(
123127
assert mock_connection_manager.api.report.call_count == 1
124128

125129

126-
def test_on_detected_attack_request_data_and_attack_data(
127-
mock_connection_manager, mock_context
128-
):
130+
def test_on_detected_attack_request_data_and_attack_data(mock_connection_manager):
129131
attack = {
130132
"payload": {"key": "value"},
131133
"metadata": {"test": "true"},
132134
}
133135

134136
on_detected_attack(
135-
mock_connection_manager, attack, mock_context, blocked=False, stack=None
137+
mock_connection_manager,
138+
attack=attack,
139+
context=test_utils.generate_context(
140+
method="GET",
141+
url="http://localhost:8080/hello",
142+
ip="198.51.100.23",
143+
route="/hello",
144+
headers={"user-agent": "Mozilla/5.0"},
145+
),
146+
blocked=False,
147+
stack=None,
136148
)
137149

138150
# Extract the call arguments for the report method
@@ -146,7 +158,7 @@ def test_on_detected_attack_request_data_and_attack_data(
146158
assert request_data["ipAddress"] == "198.51.100.23"
147159
assert not "body" in request_data
148160
assert not "headers" in request_data
149-
assert request_data["source"] == "django"
161+
assert request_data["source"] == "flask"
150162
assert request_data["route"] == "/hello"
151163
assert request_data["userAgent"] == "Mozilla/5.0"
152164

@@ -158,16 +170,18 @@ def test_on_detected_attack_request_data_and_attack_data(
158170
assert attack_data["user"] is None
159171

160172

161-
def test_on_detected_attack_with_user(mock_connection_manager, mock_context):
173+
def test_on_detected_attack_with_user(mock_connection_manager):
162174
attack = {
163175
"payload": {"key": "value"},
164176
"metadata": {},
165177
}
166-
# Simulate a user in the context
167-
mock_context.user = "test_user"
168178

169179
on_detected_attack(
170-
mock_connection_manager, attack, mock_context, blocked=False, stack=None
180+
mock_connection_manager,
181+
attack=attack,
182+
context=test_utils.generate_context(user="test_user"),
183+
blocked=False,
184+
stack=None,
171185
)
172186

173187
# Extract the call arguments for the report method

aikido_zen/test_utils/context_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@ def generate_and_set_context(*args, **kwargs) -> Context:
99

1010

1111
def generate_context(
12-
value=None, query_value=None, user=None, route=None, ip=None
12+
value=None,
13+
query_value=None,
14+
user=None,
15+
route=None,
16+
ip=None,
17+
method=None,
18+
url=None,
19+
headers=None,
1320
) -> Context:
1421
context = MockTestContext()
1522

@@ -23,6 +30,13 @@ def generate_context(
2330
context.route = route
2431
if ip is not None:
2532
context.remote_address = ip
33+
if method is not None:
34+
context.method = method
35+
if url is not None:
36+
context.url = url
37+
if headers is not None:
38+
for header_k, header_v in headers.items():
39+
context.headers.store_header(header_k, header_v)
2640

2741
return context
2842

0 commit comments

Comments
 (0)