Skip to content

Commit 3d78c38

Browse files
committed
fix tests
1 parent b312836 commit 3d78c38

2 files changed

Lines changed: 128 additions & 104 deletions

File tree

tests/unit/automator/test_api_metrics.py

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ async def run_test():
8585
runner.http_client = AsyncMock()
8686
runner.http_client.get = AsyncMock(return_value=mock_response)
8787

88-
# Call poll
89-
await runner.poll_and_execute_task()
88+
# Call poll using the internal method
89+
await runner._poll_tasks_from_server(count=1)
9090

9191
# Verify API timing was recorded
9292
runner.metrics_collector.record_api_request_time.assert_called()
@@ -106,9 +106,12 @@ def test_api_timing_failed_poll_with_status_code(self):
106106
runner = TaskRunnerAsyncIO(
107107
worker=self.worker,
108108
configuration=self.config,
109-
metrics_collector=self.metrics_collector
109+
metrics_settings=self.metrics_settings
110110
)
111111

112+
# Mock the metrics_collector\'s record method
113+
runner.metrics_collector.record_api_request_time = Mock()
114+
112115
# Mock HTTP error with response
113116
mock_response = Mock()
114117
mock_response.status_code = 500
@@ -120,13 +123,13 @@ async def run_test():
120123

121124
# Call poll (should handle exception)
122125
try:
123-
await runner.poll_and_execute_task()
126+
await runner._poll_tasks_from_server(count=1)
124127
except:
125128
pass
126129

127130
# Verify API timing was recorded with error status
128-
self.metrics_collector.record_api_request_time.assert_called()
129-
call_args = self.metrics_collector.record_api_request_time.call_args
131+
runner.metrics_collector.record_api_request_time.assert_called()
132+
call_args = runner.metrics_collector.record_api_request_time.call_args
130133

131134
self.assertEqual(call_args.kwargs['method'], 'GET')
132135
self.assertEqual(call_args.kwargs['status'], '500')
@@ -139,9 +142,12 @@ def test_api_timing_failed_poll_without_status_code(self):
139142
runner = TaskRunnerAsyncIO(
140143
worker=self.worker,
141144
configuration=self.config,
142-
metrics_collector=self.metrics_collector
145+
metrics_settings=self.metrics_settings
143146
)
144147

148+
# Mock the metrics_collector\'s record method
149+
runner.metrics_collector.record_api_request_time = Mock()
150+
145151
# Mock generic network error
146152
error = httpx.ConnectError("Connection refused")
147153

@@ -151,13 +157,13 @@ async def run_test():
151157

152158
# Call poll
153159
try:
154-
await runner.poll_and_execute_task()
160+
await runner._poll_tasks_from_server(count=1)
155161
except:
156162
pass
157163

158164
# Verify API timing was recorded with "error" status
159-
self.metrics_collector.record_api_request_time.assert_called()
160-
call_args = self.metrics_collector.record_api_request_time.call_args
165+
runner.metrics_collector.record_api_request_time.assert_called()
166+
call_args = runner.metrics_collector.record_api_request_time.call_args
161167

162168
self.assertEqual(call_args.kwargs['method'], 'GET')
163169
self.assertEqual(call_args.kwargs['status'], 'error')
@@ -169,13 +175,16 @@ def test_api_timing_successful_update(self):
169175
runner = TaskRunnerAsyncIO(
170176
worker=self.worker,
171177
configuration=self.config,
172-
metrics_collector=self.metrics_collector
178+
metrics_settings=self.metrics_settings
173179
)
174180

175-
# Create task and result
176-
task = Task(task_id='task1', task_def_name='test_task')
181+
# Mock the metrics_collector's record method
182+
runner.metrics_collector.record_api_request_time = Mock()
183+
184+
# Create task result
177185
task_result = TaskResult(
178186
task_id='task1',
187+
workflow_instance_id='wf1',
179188
status=TaskResultStatus.COMPLETED,
180189
output_data={'result': 'success'}
181190
)
@@ -189,12 +198,12 @@ async def run_test():
189198
runner.http_client = AsyncMock()
190199
runner.http_client.post = AsyncMock(return_value=mock_response)
191200

192-
# Call update
193-
await runner._update_task(task, task_result)
201+
# Call update (only needs task_result)
202+
await runner._update_task(task_result)
194203

195204
# Verify API timing was recorded
196-
self.metrics_collector.record_api_request_time.assert_called()
197-
call_args = self.metrics_collector.record_api_request_time.call_args
205+
runner.metrics_collector.record_api_request_time.assert_called()
206+
call_args = runner.metrics_collector.record_api_request_time.call_args
198207

199208
self.assertEqual(call_args.kwargs['method'], 'POST')
200209
self.assertIn('/tasks/update', call_args.kwargs['uri'])
@@ -208,12 +217,16 @@ def test_api_timing_failed_update(self):
208217
runner = TaskRunnerAsyncIO(
209218
worker=self.worker,
210219
configuration=self.config,
211-
metrics_collector=self.metrics_collector
220+
metrics_settings=self.metrics_settings
212221
)
213222

214-
task = Task(task_id='task1', task_def_name='test_task')
223+
# Mock the metrics_collector's record method
224+
runner.metrics_collector.record_api_request_time = Mock()
225+
226+
# Create task result with required fields
215227
task_result = TaskResult(
216228
task_id='task1',
229+
workflow_instance_id='wf1',
217230
status=TaskResultStatus.COMPLETED
218231
)
219232

@@ -226,15 +239,15 @@ async def run_test():
226239
runner.http_client = AsyncMock()
227240
runner.http_client.post = AsyncMock(side_effect=error)
228241

229-
# Call update
242+
# Call update (only needs task_result)
230243
try:
231-
await runner._update_task(task, task_result)
244+
await runner._update_task(task_result)
232245
except:
233246
pass
234247

235248
# Verify API timing was recorded
236-
self.metrics_collector.record_api_request_time.assert_called()
237-
call_args = self.metrics_collector.record_api_request_time.call_args
249+
runner.metrics_collector.record_api_request_time.assert_called()
250+
call_args = runner.metrics_collector.record_api_request_time.call_args
238251

239252
self.assertEqual(call_args.kwargs['method'], 'POST')
240253
self.assertEqual(call_args.kwargs['status'], '503')
@@ -246,9 +259,12 @@ def test_api_timing_multiple_requests(self):
246259
runner = TaskRunnerAsyncIO(
247260
worker=self.worker,
248261
configuration=self.config,
249-
metrics_collector=self.metrics_collector
262+
metrics_settings=self.metrics_settings
250263
)
251264

265+
# Mock the metrics_collector's record method
266+
runner.metrics_collector.record_api_request_time = Mock()
267+
252268
mock_response = Mock()
253269
mock_response.status_code = 200
254270
mock_response.json.return_value = []
@@ -258,15 +274,15 @@ async def run_test():
258274
runner.http_client.get = AsyncMock(return_value=mock_response)
259275

260276
# Poll 3 times
261-
await runner.poll_and_execute_task()
262-
await runner.poll_and_execute_task()
263-
await runner.poll_and_execute_task()
277+
await runner._poll_tasks_from_server(count=1)
278+
await runner._poll_tasks_from_server(count=1)
279+
await runner._poll_tasks_from_server(count=1)
264280

265281
# Should have 3 API timing records
266-
self.assertEqual(self.metrics_collector.record_api_request_time.call_count, 3)
282+
self.assertEqual(runner.metrics_collector.record_api_request_time.call_count, 3)
267283

268284
# All should be successful
269-
for call in self.metrics_collector.record_api_request_time.call_args_list:
285+
for call in runner.metrics_collector.record_api_request_time.call_args_list:
270286
self.assertEqual(call.kwargs['status'], '200')
271287

272288
asyncio.run(run_test())
@@ -275,8 +291,7 @@ def test_api_timing_without_metrics_collector(self):
275291
"""Test that API requests work without metrics collector"""
276292
runner = TaskRunnerAsyncIO(
277293
worker=self.worker,
278-
configuration=self.config,
279-
metrics_collector=None
294+
configuration=self.config
280295
)
281296

282297
mock_response = Mock()
@@ -288,7 +303,7 @@ async def run_test():
288303
runner.http_client.get = AsyncMock(return_value=mock_response)
289304

290305
# Should not raise exception
291-
await runner.poll_and_execute_task()
306+
await runner._poll_tasks_from_server(count=1)
292307

293308
# No metrics recorded (metrics_collector is None)
294309
# Just verify no exception was raised
@@ -300,9 +315,12 @@ def test_api_timing_precision(self):
300315
runner = TaskRunnerAsyncIO(
301316
worker=self.worker,
302317
configuration=self.config,
303-
metrics_collector=self.metrics_collector
318+
metrics_settings=self.metrics_settings
304319
)
305320

321+
# Mock the metrics_collector\'s record method
322+
runner.metrics_collector.record_api_request_time = Mock()
323+
306324
# Mock fast response
307325
mock_response = Mock()
308326
mock_response.status_code = 200
@@ -318,10 +336,10 @@ async def mock_get(*args, **kwargs):
318336

319337
runner.http_client.get = mock_get
320338

321-
await runner.poll_and_execute_task()
339+
await runner._poll_tasks_from_server(count=1)
322340

323341
# Verify timing captured sub-second precision
324-
call_args = self.metrics_collector.record_api_request_time.call_args
342+
call_args = runner.metrics_collector.record_api_request_time.call_args
325343
time_spent = call_args.kwargs['time_spent']
326344

327345
# Should be at least 1ms, but less than 100ms
@@ -335,9 +353,12 @@ def test_api_timing_auth_error_401(self):
335353
runner = TaskRunnerAsyncIO(
336354
worker=self.worker,
337355
configuration=self.config,
338-
metrics_collector=self.metrics_collector
356+
metrics_settings=self.metrics_settings
339357
)
340358

359+
# Mock the metrics_collector's record method
360+
runner.metrics_collector.record_api_request_time = Mock()
361+
341362
mock_response = Mock()
342363
mock_response.status_code = 401
343364
error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=mock_response)
@@ -347,12 +368,12 @@ async def run_test():
347368
runner.http_client.get = AsyncMock(side_effect=error)
348369

349370
try:
350-
await runner.poll_and_execute_task()
371+
await runner._poll_tasks_from_server(count=1)
351372
except:
352373
pass
353374

354375
# Verify 401 status captured
355-
call_args = self.metrics_collector.record_api_request_time.call_args
376+
call_args = runner.metrics_collector.record_api_request_time.call_args
356377
self.assertEqual(call_args.kwargs['status'], '401')
357378

358379
asyncio.run(run_test())
@@ -362,22 +383,25 @@ def test_api_timing_timeout_error(self):
362383
runner = TaskRunnerAsyncIO(
363384
worker=self.worker,
364385
configuration=self.config,
365-
metrics_collector=self.metrics_collector
386+
metrics_settings=self.metrics_settings
366387
)
367388

389+
# Mock the metrics_collector's record method
390+
runner.metrics_collector.record_api_request_time = Mock()
391+
368392
error = httpx.TimeoutException("Request timeout")
369393

370394
async def run_test():
371395
runner.http_client = AsyncMock()
372396
runner.http_client.get = AsyncMock(side_effect=error)
373397

374398
try:
375-
await runner.poll_and_execute_task()
399+
await runner._poll_tasks_from_server(count=1)
376400
except:
377401
pass
378402

379403
# Verify "error" status for timeout
380-
call_args = self.metrics_collector.record_api_request_time.call_args
404+
call_args = runner.metrics_collector.record_api_request_time.call_args
381405
self.assertEqual(call_args.kwargs['status'], 'error')
382406

383407
asyncio.run(run_test())
@@ -387,9 +411,12 @@ def test_api_timing_concurrent_requests(self):
387411
runner = TaskRunnerAsyncIO(
388412
worker=self.worker,
389413
configuration=self.config,
390-
metrics_collector=self.metrics_collector
414+
metrics_settings=self.metrics_settings
391415
)
392416

417+
# Mock the metrics_collector's record method
418+
runner.metrics_collector.record_api_request_time = Mock()
419+
393420
mock_response = Mock()
394421
mock_response.status_code = 200
395422
mock_response.json.return_value = []
@@ -400,11 +427,11 @@ async def run_test():
400427

401428
# Run 5 concurrent polls
402429
await asyncio.gather(*[
403-
runner.poll_and_execute_task() for _ in range(5)
430+
runner._poll_tasks_from_server(count=1) for _ in range(5)
404431
])
405432

406433
# Should have 5 timing records
407-
self.assertEqual(self.metrics_collector.record_api_request_time.call_count, 5)
434+
self.assertEqual(runner.metrics_collector.record_api_request_time.call_count, 5)
408435

409436
asyncio.run(run_test())
410437

0 commit comments

Comments
 (0)