Skip to content

Commit be12443

Browse files
authored
Add unit tests for each function (#11)
1 parent c7b0336 commit be12443

File tree

7 files changed

+1194
-6
lines changed

7 files changed

+1194
-6
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
4+
from crowdstrike.foundry.function import Request
5+
6+
7+
def mock_handler(*args, **kwargs):
8+
def identity(func):
9+
return func
10+
11+
return identity
12+
13+
14+
class FnTestCase(unittest.TestCase):
15+
def setUp(self):
16+
patcher = patch("crowdstrike.foundry.function.Function.handler", new=mock_handler)
17+
self.addCleanup(patcher.stop)
18+
self.handler_patch = patcher.start()
19+
20+
import importlib
21+
import main
22+
importlib.reload(main)
23+
24+
@patch("main.Hosts")
25+
def test_on_post_success(self, mock_hosts_class):
26+
from main import on_post
27+
28+
# Mock the Hosts instance and its response
29+
mock_hosts_instance = MagicMock()
30+
mock_hosts_class.return_value = mock_hosts_instance
31+
mock_hosts_instance.get_device_details.return_value = {
32+
"status_code": 200,
33+
"body": {
34+
"resources": [{
35+
"device_id": "test-host-123",
36+
"hostname": "test-host",
37+
"platform_name": "Windows"
38+
}]
39+
}
40+
}
41+
42+
request = Request()
43+
request.body = {
44+
"host_id": "test-host-123"
45+
}
46+
47+
response = on_post(request)
48+
49+
self.assertEqual(response.code, 200)
50+
self.assertIn("host_details", response.body)
51+
self.assertEqual(response.body["host_details"]["device_id"], "test-host-123")
52+
mock_hosts_instance.get_device_details.assert_called_once_with(ids="test-host-123")
53+
54+
def test_on_post_missing_host_id(self):
55+
from main import on_post
56+
request = Request()
57+
58+
response = on_post(request)
59+
60+
self.assertEqual(response.code, 400)
61+
self.assertEqual(len(response.errors), 1)
62+
self.assertEqual(response.errors[0].message, "missing host_id from request body")
63+
64+
@patch("main.Hosts")
65+
def test_on_post_api_error(self, mock_hosts_class):
66+
from main import on_post
67+
68+
# Mock the Hosts instance to return an error
69+
mock_hosts_instance = MagicMock()
70+
mock_hosts_class.return_value = mock_hosts_instance
71+
mock_hosts_instance.get_device_details.return_value = {
72+
"status_code": 404,
73+
"body": {"errors": [{"message": "Host not found"}]}
74+
}
75+
76+
request = Request()
77+
request.body = {
78+
"host_id": "nonexistent-host"
79+
}
80+
81+
response = on_post(request)
82+
83+
self.assertEqual(response.code, 404)
84+
self.assertEqual(len(response.errors), 1)
85+
self.assertIn("Error retrieving host:", response.errors[0].message)
86+
mock_hosts_instance.get_device_details.assert_called_once_with(ids="nonexistent-host")
87+
88+
89+
if __name__ == "__main__":
90+
unittest.main()

functions/host-info/test_main.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
4+
from crowdstrike.foundry.function import Request
5+
6+
7+
def mock_handler(*args, **kwargs):
8+
def identity(func):
9+
return func
10+
11+
return identity
12+
13+
14+
class FnTestCase(unittest.TestCase):
15+
def setUp(self):
16+
patcher = patch("crowdstrike.foundry.function.Function.handler", new=mock_handler)
17+
self.addCleanup(patcher.stop)
18+
self.handler_patch = patcher.start()
19+
20+
import importlib
21+
import main
22+
importlib.reload(main)
23+
24+
@patch("main.validate_host_id")
25+
@patch("main.format_error_response")
26+
def test_on_post_success(self, mock_format_error, mock_validate_host_id):
27+
from main import on_post
28+
29+
# Mock validation to return True for valid host ID
30+
mock_validate_host_id.return_value = True
31+
32+
# Create mock logger
33+
mock_logger = MagicMock()
34+
35+
request = Request()
36+
request.body = {
37+
"host_id": "valid-host-123"
38+
}
39+
40+
response = on_post(request, config=None, logger=mock_logger)
41+
42+
self.assertEqual(response.code, 200)
43+
self.assertEqual(response.body["host"], "valid-host-123")
44+
45+
# Verify validation was called twice (once for logging, once for condition)
46+
self.assertEqual(mock_validate_host_id.call_count, 2)
47+
mock_validate_host_id.assert_called_with("valid-host-123")
48+
49+
# Verify logger was called
50+
self.assertEqual(mock_logger.info.call_count, 2)
51+
mock_logger.info.assert_any_call("Host ID: valid-host-123")
52+
mock_logger.info.assert_any_call("Is valid? True")
53+
54+
# Verify format_error_response was not called
55+
mock_format_error.assert_not_called()
56+
57+
@patch("main.validate_host_id")
58+
@patch("main.format_error_response")
59+
def test_on_post_invalid_host_id(self, mock_format_error, mock_validate_host_id):
60+
from main import on_post
61+
62+
# Mock validation to return False for invalid host ID
63+
mock_validate_host_id.return_value = False
64+
65+
# Mock error response formatting
66+
mock_format_error.return_value = [{"code": 400, "message": "Invalid host ID format"}]
67+
68+
# Create mock logger
69+
mock_logger = MagicMock()
70+
71+
request = Request()
72+
request.body = {
73+
"host_id": "invalid-host"
74+
}
75+
76+
response = on_post(request, config=None, logger=mock_logger)
77+
78+
# Should return error response (default code is likely 400)
79+
self.assertEqual(response.errors, [{"code": 400, "message": "Invalid host ID format"}])
80+
81+
# Verify validation was called twice (once for logging, once for condition)
82+
self.assertEqual(mock_validate_host_id.call_count, 2)
83+
mock_validate_host_id.assert_called_with("invalid-host")
84+
85+
# Verify error formatting was called
86+
mock_format_error.assert_called_once_with("Invalid host ID format")
87+
88+
# Verify logger was called
89+
self.assertEqual(mock_logger.info.call_count, 2)
90+
mock_logger.info.assert_any_call("Host ID: invalid-host")
91+
mock_logger.info.assert_any_call("Is valid? False")
92+
93+
@patch("main.validate_host_id")
94+
@patch("main.format_error_response")
95+
def test_on_post_missing_host_id(self, mock_format_error, mock_validate_host_id):
96+
from main import on_post
97+
98+
# Mock validation to return False for None host ID
99+
mock_validate_host_id.return_value = False
100+
101+
# Mock error response formatting
102+
mock_format_error.return_value = [{"code": 400, "message": "Invalid host ID format"}]
103+
104+
# Create mock logger
105+
mock_logger = MagicMock()
106+
107+
request = Request()
108+
request.body = {} # No host_id provided
109+
110+
response = on_post(request, config=None, logger=mock_logger)
111+
112+
# Should return error response
113+
self.assertEqual(response.errors, [{"code": 400, "message": "Invalid host ID format"}])
114+
115+
# Verify validation was called twice (once for logging, once for condition)
116+
self.assertEqual(mock_validate_host_id.call_count, 2)
117+
mock_validate_host_id.assert_called_with(None)
118+
119+
# Verify error formatting was called
120+
mock_format_error.assert_called_once_with("Invalid host ID format")
121+
122+
# Verify logger was called with None
123+
self.assertEqual(mock_logger.info.call_count, 2)
124+
mock_logger.info.assert_any_call("Host ID: None")
125+
mock_logger.info.assert_any_call("Is valid? False")
126+
127+
@patch("main.validate_host_id")
128+
@patch("main.format_error_response")
129+
def test_on_post_empty_host_id(self, mock_format_error, mock_validate_host_id):
130+
from main import on_post
131+
132+
# Mock validation to return False for empty string
133+
mock_validate_host_id.return_value = False
134+
135+
# Mock error response formatting
136+
mock_format_error.return_value = [{"code": 400, "message": "Invalid host ID format"}]
137+
138+
# Create mock logger
139+
mock_logger = MagicMock()
140+
141+
request = Request()
142+
request.body = {
143+
"host_id": ""
144+
}
145+
146+
response = on_post(request, config=None, logger=mock_logger)
147+
148+
# Should return error response
149+
self.assertEqual(response.errors, [{"code": 400, "message": "Invalid host ID format"}])
150+
151+
# Verify validation was called twice (once for logging, once for condition)
152+
self.assertEqual(mock_validate_host_id.call_count, 2)
153+
mock_validate_host_id.assert_called_with("")
154+
155+
# Verify error formatting was called
156+
mock_format_error.assert_called_once_with("Invalid host ID format")
157+
158+
# Verify logger was called with empty string
159+
self.assertEqual(mock_logger.info.call_count, 2)
160+
mock_logger.info.assert_any_call("Host ID: ")
161+
mock_logger.info.assert_any_call("Is valid? False")
162+
163+
@patch("main.validate_host_id")
164+
@patch("main.format_error_response")
165+
def test_on_post_with_config(self, mock_format_error, mock_validate_host_id):
166+
from main import on_post
167+
168+
# Mock validation to return True
169+
mock_validate_host_id.return_value = True
170+
171+
# Create mock logger
172+
mock_logger = MagicMock()
173+
174+
# Test with config parameter
175+
config = {"some_setting": "value"}
176+
177+
request = Request()
178+
request.body = {
179+
"host_id": "test-host-456"
180+
}
181+
182+
response = on_post(request, config=config, logger=mock_logger)
183+
184+
self.assertEqual(response.code, 200)
185+
self.assertEqual(response.body["host"], "test-host-456")
186+
187+
# Verify validation was called twice (once for logging, once for condition)
188+
self.assertEqual(mock_validate_host_id.call_count, 2)
189+
mock_validate_host_id.assert_called_with("test-host-456")
190+
191+
# Verify logger was called
192+
self.assertEqual(mock_logger.info.call_count, 2)
193+
194+
195+
if __name__ == "__main__":
196+
unittest.main()

functions/log-event/main.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88
func = Function.instance()
99

1010

11-
@func.handler(method='POST', path='/log-event')
11+
@func.handler(method="POST", path="/log-event")
1212
def on_post(request: Request) -> Response:
1313
# Validate request
14-
if 'event_data' not in request.body:
14+
if "event_data" not in request.body:
1515
return Response(
1616
code=400,
17-
errors=[APIError(code=400, message='missing event_data')]
17+
errors=[APIError(code=400, message="missing event_data")]
1818
)
1919

20-
event_data = request.body['event_data']
20+
event_data = request.body["event_data"]
2121

2222
try:
2323
# Store data in a collection
@@ -47,7 +47,7 @@ def on_post(request: Request) -> Response:
4747
)
4848

4949
if response["status_code"] != 200:
50-
error_message = response.get('error', {}).get('message', 'Unknown error')
50+
error_message = response.get("error", {}).get("message", "Unknown error")
5151
return Response(
5252
code=response["status_code"],
5353
errors=[APIError(
@@ -78,5 +78,5 @@ def on_post(request: Request) -> Response:
7878
)
7979

8080

81-
if __name__ == '__main__':
81+
if __name__ == "__main__":
8282
func.run()

0 commit comments

Comments
 (0)