Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aikido_zen/sources/functions/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def post_response(status_code):
if cache.is_bypassed_ip(context.remote_address):
return

attack_wave = attack_wave_detector_store.is_attack_wave(context)
attack_wave = attack_wave_detector_store.is_attack_wave(context, status_code)
if attack_wave:
cache.stats.on_detected_attack_wave(blocked=False)

Expand Down
4 changes: 2 additions & 2 deletions aikido_zen/storage/attack_wave_detector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ def __init__(self):
self._detector = AttackWaveDetector()
self._lock = threading.RLock() # Reentrant lock for thread safety

def is_attack_wave(self, context: Context) -> bool:
def is_attack_wave(self, context: Context, status_code: int) -> bool:
with self._lock:
return self._detector.is_attack_wave(context)
return self._detector.is_attack_wave(context, status_code)

def get_samples_for_ip(self, ip: str):
with self._lock:
Expand Down
54 changes: 27 additions & 27 deletions aikido_zen/storage/attack_wave_detector_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ def test_is_attack_wave_basic_functionality():
return_value=True,
):
# Should return False for first few calls
assert not store.is_attack_wave(context)
assert not store.is_attack_wave(context)
assert not store.is_attack_wave(context, 404)
assert not store.is_attack_wave(context, 404)

# Call 12 more times to get to 14 total (still below threshold)
for _ in range(12):
result = store.is_attack_wave(context)
result = store.is_attack_wave(context, 404)
assert not result

# The 15th call should trigger attack wave detection and return True
assert store.is_attack_wave(context)
assert store.is_attack_wave(context, 404)


def test_is_attack_wave_different_ips():
Expand All @@ -62,25 +62,25 @@ def test_is_attack_wave_different_ips():
):
# Call multiple times for different IPs
for _ in range(10):
store.is_attack_wave(context1)
store.is_attack_wave(context2)
store.is_attack_wave(context1, 404)
store.is_attack_wave(context2, 404)

# Neither should trigger attack wave yet
assert not store.is_attack_wave(context1)
assert not store.is_attack_wave(context2)
assert not store.is_attack_wave(context1, 404)
assert not store.is_attack_wave(context2, 404)


def test_is_attack_wave_none_context():
"""Test handling of None context"""
store = AttackWaveDetectorStore()
assert not store.is_attack_wave(None)
assert not store.is_attack_wave(None, 404)


def test_is_attack_wave_no_ip_in_context():
"""Test handling of context with no IP address"""
store = AttackWaveDetectorStore()
context = test_utils.generate_context(ip=None)
assert not store.is_attack_wave(context)
assert not store.is_attack_wave(context, 404)


def test_thread_safety_multiple_threads():
Expand All @@ -94,7 +94,7 @@ def worker(ip_suffix, result_list):
"""Worker function that calls is_attack_wave multiple times"""
context = test_utils.generate_context(ip=f"192.168.1.{ip_suffix}")
for _ in range(5):
result = store.is_attack_wave(context)
result = store.is_attack_wave(context, 404)
result_list.append((context.remote_address, result))
time.sleep(0.001) # Small delay to simulate real usage

Expand Down Expand Up @@ -127,7 +127,7 @@ def worker(result_list):
"""Worker function that calls is_attack_wave for the same IP"""
context = test_utils.generate_context(ip="10.0.0.1")
for _ in range(10):
result = store.is_attack_wave(context)
result = store.is_attack_wave(context, 404)
with lock:
result_list.append(result)
time.sleep(0.001)
Expand Down Expand Up @@ -161,13 +161,13 @@ def test_attack_wave_cooldown():
):
# Call 14 times to get close to threshold
for _ in range(14):
store.is_attack_wave(context)
store.is_attack_wave(context, 404)

# The 15th call should trigger attack wave detection and return True
assert store.is_attack_wave(context)
assert store.is_attack_wave(context, 404)

# Subsequent calls should return False due to cooldown
assert not store.is_attack_wave(context)
assert not store.is_attack_wave(context, 404)


def test_attack_wave_time_frame():
Expand All @@ -182,10 +182,10 @@ def test_attack_wave_time_frame():
):
# Make some calls
for _ in range(5):
store.is_attack_wave(context)
store.is_attack_wave(context, 404)

# Should not trigger attack wave yet
assert not store.is_attack_wave(context)
assert not store.is_attack_wave(context, 404)

# Wait for the time frame to expire (60 seconds)
# We can't actually wait 60 seconds in a test, but we can verify the behavior
Expand Down Expand Up @@ -232,7 +232,7 @@ def worker(worker_id):
try:
for i in range(10):
context = test_utils.generate_context(ip=f"192.168.{worker_id}.{i}")
result = store.is_attack_wave(context)
result = store.is_attack_wave(context, 404)
results.append((worker_id, context.remote_address, result))
except Exception as e:
results.append((worker_id, "error", str(e)))
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_samples_tracking_in_store():
):
# Make a few requests
for i in range(3):
store.is_attack_wave(context)
store.is_attack_wave(context, 404)

# Check that samples are being tracked (should have 1 unique sample)
samples = store.get_samples_for_ip(context.remote_address)
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_samples_structure_and_content():
):
# Make enough requests to trigger attack wave
for i in range(15):
store.is_attack_wave(context)
store.is_attack_wave(context, 404)

# Get samples
samples = store.get_samples_for_ip(context.remote_address)
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_samples_json_serialization():
):
# Make enough requests to trigger attack wave
for i in range(15):
store.is_attack_wave(context)
store.is_attack_wave(context, 404)

# Get samples
samples = store.get_samples_for_ip(context.remote_address)
Expand Down Expand Up @@ -356,8 +356,8 @@ def test_samples_with_different_contexts():
):
# Make requests for both contexts
for i in range(15):
store.is_attack_wave(context1)
store.is_attack_wave(context2)
store.is_attack_wave(context1, 404)
store.is_attack_wave(context2, 404)

# Get samples for each IP
samples1 = store.get_samples_for_ip(context1.remote_address)
Expand Down Expand Up @@ -424,7 +424,7 @@ def create_context_with_url(ip, url, method="GET"):

# Make enough requests to trigger attack wave
for j in range(15):
store.is_attack_wave(context)
store.is_attack_wave(context, 404)

# Check a few IPs to verify sample structure
for i in range(5):
Expand All @@ -446,7 +446,7 @@ def test_mock_detector_integration(mock_detector_class):
store = AttackWaveDetectorStore()
context = test_utils.generate_context()

# Should use the mocked detector
result = store.is_attack_wave(context)
# Should use the mocked detector (default status_code=404)
result = store.is_attack_wave(context, 404)
assert result is True
mock_detector.is_attack_wave.assert_called_once_with(context)
mock_detector.is_attack_wave.assert_called_once_with(context, 404)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
time_to_live_in_ms=self.attack_wave_time_frame,
)

def is_attack_wave(self, context: Context) -> bool:
def is_attack_wave(self, context: Context, status_code: int) -> bool:
"""
Function gets called with context to check if there is an attack wave request.
"""
Expand All @@ -45,7 +45,7 @@ def is_attack_wave(self, context: Context) -> bool:
if self.sent_events_map.get(ip) is not None:
return False

if not is_web_scanner(context):
if not is_web_scanner(context, status_code):
return False

# Increment suspicious requests count -> there is a new or first suspicious request
Expand Down
Loading
Loading