@@ -85,8 +85,8 @@ async def run_test():
8585 runner .http_client = AsyncMock ()
8686 runner .http_client .get = AsyncMock (return_value = mock_response )
8787
88- # Call poll
89- await runner .poll_and_execute_task ( )
88+ # Call poll using the internal method
89+ await runner ._poll_tasks_from_server ( count = 1 )
9090
9191 # Verify API timing was recorded
9292 runner .metrics_collector .record_api_request_time .assert_called ()
@@ -106,9 +106,12 @@ def test_api_timing_failed_poll_with_status_code(self):
106106 runner = TaskRunnerAsyncIO (
107107 worker = self .worker ,
108108 configuration = self .config ,
109- metrics_collector = self .metrics_collector
109+ metrics_settings = self .metrics_settings
110110 )
111111
112+ # Mock the metrics_collector\'s record method
113+ runner .metrics_collector .record_api_request_time = Mock ()
114+
112115 # Mock HTTP error with response
113116 mock_response = Mock ()
114117 mock_response .status_code = 500
@@ -120,13 +123,13 @@ async def run_test():
120123
121124 # Call poll (should handle exception)
122125 try :
123- await runner .poll_and_execute_task ( )
126+ await runner ._poll_tasks_from_server ( count = 1 )
124127 except :
125128 pass
126129
127130 # Verify API timing was recorded with error status
128- self .metrics_collector .record_api_request_time .assert_called ()
129- call_args = self .metrics_collector .record_api_request_time .call_args
131+ runner .metrics_collector .record_api_request_time .assert_called ()
132+ call_args = runner .metrics_collector .record_api_request_time .call_args
130133
131134 self .assertEqual (call_args .kwargs ['method' ], 'GET' )
132135 self .assertEqual (call_args .kwargs ['status' ], '500' )
@@ -139,9 +142,12 @@ def test_api_timing_failed_poll_without_status_code(self):
139142 runner = TaskRunnerAsyncIO (
140143 worker = self .worker ,
141144 configuration = self .config ,
142- metrics_collector = self .metrics_collector
145+ metrics_settings = self .metrics_settings
143146 )
144147
148+ # Mock the metrics_collector\'s record method
149+ runner .metrics_collector .record_api_request_time = Mock ()
150+
145151 # Mock generic network error
146152 error = httpx .ConnectError ("Connection refused" )
147153
@@ -151,13 +157,13 @@ async def run_test():
151157
152158 # Call poll
153159 try :
154- await runner .poll_and_execute_task ( )
160+ await runner ._poll_tasks_from_server ( count = 1 )
155161 except :
156162 pass
157163
158164 # Verify API timing was recorded with "error" status
159- self .metrics_collector .record_api_request_time .assert_called ()
160- call_args = self .metrics_collector .record_api_request_time .call_args
165+ runner .metrics_collector .record_api_request_time .assert_called ()
166+ call_args = runner .metrics_collector .record_api_request_time .call_args
161167
162168 self .assertEqual (call_args .kwargs ['method' ], 'GET' )
163169 self .assertEqual (call_args .kwargs ['status' ], 'error' )
@@ -169,13 +175,16 @@ def test_api_timing_successful_update(self):
169175 runner = TaskRunnerAsyncIO (
170176 worker = self .worker ,
171177 configuration = self .config ,
172- metrics_collector = self .metrics_collector
178+ metrics_settings = self .metrics_settings
173179 )
174180
175- # Create task and result
176- task = Task (task_id = 'task1' , task_def_name = 'test_task' )
181+ # Mock the metrics_collector's record method
182+ runner .metrics_collector .record_api_request_time = Mock ()
183+
184+ # Create task result
177185 task_result = TaskResult (
178186 task_id = 'task1' ,
187+ workflow_instance_id = 'wf1' ,
179188 status = TaskResultStatus .COMPLETED ,
180189 output_data = {'result' : 'success' }
181190 )
@@ -189,12 +198,12 @@ async def run_test():
189198 runner .http_client = AsyncMock ()
190199 runner .http_client .post = AsyncMock (return_value = mock_response )
191200
192- # Call update
193- await runner ._update_task (task , task_result )
201+ # Call update (only needs task_result)
202+ await runner ._update_task (task_result )
194203
195204 # Verify API timing was recorded
196- self .metrics_collector .record_api_request_time .assert_called ()
197- call_args = self .metrics_collector .record_api_request_time .call_args
205+ runner .metrics_collector .record_api_request_time .assert_called ()
206+ call_args = runner .metrics_collector .record_api_request_time .call_args
198207
199208 self .assertEqual (call_args .kwargs ['method' ], 'POST' )
200209 self .assertIn ('/tasks/update' , call_args .kwargs ['uri' ])
@@ -208,12 +217,16 @@ def test_api_timing_failed_update(self):
208217 runner = TaskRunnerAsyncIO (
209218 worker = self .worker ,
210219 configuration = self .config ,
211- metrics_collector = self .metrics_collector
220+ metrics_settings = self .metrics_settings
212221 )
213222
214- task = Task (task_id = 'task1' , task_def_name = 'test_task' )
223+ # Mock the metrics_collector's record method
224+ runner .metrics_collector .record_api_request_time = Mock ()
225+
226+ # Create task result with required fields
215227 task_result = TaskResult (
216228 task_id = 'task1' ,
229+ workflow_instance_id = 'wf1' ,
217230 status = TaskResultStatus .COMPLETED
218231 )
219232
@@ -226,15 +239,15 @@ async def run_test():
226239 runner .http_client = AsyncMock ()
227240 runner .http_client .post = AsyncMock (side_effect = error )
228241
229- # Call update
242+ # Call update (only needs task_result)
230243 try :
231- await runner ._update_task (task , task_result )
244+ await runner ._update_task (task_result )
232245 except :
233246 pass
234247
235248 # Verify API timing was recorded
236- self .metrics_collector .record_api_request_time .assert_called ()
237- call_args = self .metrics_collector .record_api_request_time .call_args
249+ runner .metrics_collector .record_api_request_time .assert_called ()
250+ call_args = runner .metrics_collector .record_api_request_time .call_args
238251
239252 self .assertEqual (call_args .kwargs ['method' ], 'POST' )
240253 self .assertEqual (call_args .kwargs ['status' ], '503' )
@@ -246,9 +259,12 @@ def test_api_timing_multiple_requests(self):
246259 runner = TaskRunnerAsyncIO (
247260 worker = self .worker ,
248261 configuration = self .config ,
249- metrics_collector = self .metrics_collector
262+ metrics_settings = self .metrics_settings
250263 )
251264
265+ # Mock the metrics_collector's record method
266+ runner .metrics_collector .record_api_request_time = Mock ()
267+
252268 mock_response = Mock ()
253269 mock_response .status_code = 200
254270 mock_response .json .return_value = []
@@ -258,15 +274,15 @@ async def run_test():
258274 runner .http_client .get = AsyncMock (return_value = mock_response )
259275
260276 # Poll 3 times
261- await runner .poll_and_execute_task ( )
262- await runner .poll_and_execute_task ( )
263- await runner .poll_and_execute_task ( )
277+ await runner ._poll_tasks_from_server ( count = 1 )
278+ await runner ._poll_tasks_from_server ( count = 1 )
279+ await runner ._poll_tasks_from_server ( count = 1 )
264280
265281 # Should have 3 API timing records
266- self .assertEqual (self .metrics_collector .record_api_request_time .call_count , 3 )
282+ self .assertEqual (runner .metrics_collector .record_api_request_time .call_count , 3 )
267283
268284 # All should be successful
269- for call in self .metrics_collector .record_api_request_time .call_args_list :
285+ for call in runner .metrics_collector .record_api_request_time .call_args_list :
270286 self .assertEqual (call .kwargs ['status' ], '200' )
271287
272288 asyncio .run (run_test ())
@@ -275,8 +291,7 @@ def test_api_timing_without_metrics_collector(self):
275291 """Test that API requests work without metrics collector"""
276292 runner = TaskRunnerAsyncIO (
277293 worker = self .worker ,
278- configuration = self .config ,
279- metrics_collector = None
294+ configuration = self .config
280295 )
281296
282297 mock_response = Mock ()
@@ -288,7 +303,7 @@ async def run_test():
288303 runner .http_client .get = AsyncMock (return_value = mock_response )
289304
290305 # Should not raise exception
291- await runner .poll_and_execute_task ( )
306+ await runner ._poll_tasks_from_server ( count = 1 )
292307
293308 # No metrics recorded (metrics_collector is None)
294309 # Just verify no exception was raised
@@ -300,9 +315,12 @@ def test_api_timing_precision(self):
300315 runner = TaskRunnerAsyncIO (
301316 worker = self .worker ,
302317 configuration = self .config ,
303- metrics_collector = self .metrics_collector
318+ metrics_settings = self .metrics_settings
304319 )
305320
321+ # Mock the metrics_collector\'s record method
322+ runner .metrics_collector .record_api_request_time = Mock ()
323+
306324 # Mock fast response
307325 mock_response = Mock ()
308326 mock_response .status_code = 200
@@ -318,10 +336,10 @@ async def mock_get(*args, **kwargs):
318336
319337 runner .http_client .get = mock_get
320338
321- await runner .poll_and_execute_task ( )
339+ await runner ._poll_tasks_from_server ( count = 1 )
322340
323341 # Verify timing captured sub-second precision
324- call_args = self .metrics_collector .record_api_request_time .call_args
342+ call_args = runner .metrics_collector .record_api_request_time .call_args
325343 time_spent = call_args .kwargs ['time_spent' ]
326344
327345 # Should be at least 1ms, but less than 100ms
@@ -335,9 +353,12 @@ def test_api_timing_auth_error_401(self):
335353 runner = TaskRunnerAsyncIO (
336354 worker = self .worker ,
337355 configuration = self .config ,
338- metrics_collector = self .metrics_collector
356+ metrics_settings = self .metrics_settings
339357 )
340358
359+ # Mock the metrics_collector's record method
360+ runner .metrics_collector .record_api_request_time = Mock ()
361+
341362 mock_response = Mock ()
342363 mock_response .status_code = 401
343364 error = httpx .HTTPStatusError ("Unauthorized" , request = Mock (), response = mock_response )
@@ -347,12 +368,12 @@ async def run_test():
347368 runner .http_client .get = AsyncMock (side_effect = error )
348369
349370 try :
350- await runner .poll_and_execute_task ( )
371+ await runner ._poll_tasks_from_server ( count = 1 )
351372 except :
352373 pass
353374
354375 # Verify 401 status captured
355- call_args = self .metrics_collector .record_api_request_time .call_args
376+ call_args = runner .metrics_collector .record_api_request_time .call_args
356377 self .assertEqual (call_args .kwargs ['status' ], '401' )
357378
358379 asyncio .run (run_test ())
@@ -362,22 +383,25 @@ def test_api_timing_timeout_error(self):
362383 runner = TaskRunnerAsyncIO (
363384 worker = self .worker ,
364385 configuration = self .config ,
365- metrics_collector = self .metrics_collector
386+ metrics_settings = self .metrics_settings
366387 )
367388
389+ # Mock the metrics_collector's record method
390+ runner .metrics_collector .record_api_request_time = Mock ()
391+
368392 error = httpx .TimeoutException ("Request timeout" )
369393
370394 async def run_test ():
371395 runner .http_client = AsyncMock ()
372396 runner .http_client .get = AsyncMock (side_effect = error )
373397
374398 try :
375- await runner .poll_and_execute_task ( )
399+ await runner ._poll_tasks_from_server ( count = 1 )
376400 except :
377401 pass
378402
379403 # Verify "error" status for timeout
380- call_args = self .metrics_collector .record_api_request_time .call_args
404+ call_args = runner .metrics_collector .record_api_request_time .call_args
381405 self .assertEqual (call_args .kwargs ['status' ], 'error' )
382406
383407 asyncio .run (run_test ())
@@ -387,9 +411,12 @@ def test_api_timing_concurrent_requests(self):
387411 runner = TaskRunnerAsyncIO (
388412 worker = self .worker ,
389413 configuration = self .config ,
390- metrics_collector = self .metrics_collector
414+ metrics_settings = self .metrics_settings
391415 )
392416
417+ # Mock the metrics_collector's record method
418+ runner .metrics_collector .record_api_request_time = Mock ()
419+
393420 mock_response = Mock ()
394421 mock_response .status_code = 200
395422 mock_response .json .return_value = []
@@ -400,11 +427,11 @@ async def run_test():
400427
401428 # Run 5 concurrent polls
402429 await asyncio .gather (* [
403- runner .poll_and_execute_task ( ) for _ in range (5 )
430+ runner ._poll_tasks_from_server ( count = 1 ) for _ in range (5 )
404431 ])
405432
406433 # Should have 5 timing records
407- self .assertEqual (self .metrics_collector .record_api_request_time .call_count , 5 )
434+ self .assertEqual (runner .metrics_collector .record_api_request_time .call_count , 5 )
408435
409436 asyncio .run (run_test ())
410437
0 commit comments