2121import inspect
2222import logging
2323import sys
24+ import threading
2425import time
2526import traceback
2627import uuid
@@ -347,6 +348,8 @@ def __init__(
347348 self ._shutdown_timeout = shutdown_timeout
348349 self ._in_flight : set [asyncio .Task [Any ]] = set ()
349350 self ._query_tasks_supported = False
351+ self ._query_thread_stop = threading .Event ()
352+ self ._query_thread : threading .Thread | None = None
350353 # In-flight slot accounting feeds the periodic heartbeat so operators
351354 # see free-slot counts without the worker having to re-derive them at
352355 # shutdown. Counters are bumped/decremented around dispatch.
@@ -897,15 +900,15 @@ async def call_interceptor(
897900
898901 return await handler (context )
899902
900- async def _run_query_task (self , task : dict [str , Any ]) -> str :
903+ async def _run_query_task (self , task : dict [str , Any ], * , client : Client | None = None ) -> str :
901904 context = QueryTaskInterceptorContext (
902905 worker_id = self .worker_id ,
903906 task_queue = self .task_queue ,
904907 task = task ,
905908 )
906909
907910 async def call_core (ctx : QueryTaskInterceptorContext ) -> str :
908- return await self ._run_query_task_core (ctx .task )
911+ return await self ._run_query_task_core (ctx .task , client = client )
909912
910913 handler = call_core
911914 for interceptor in reversed (self .interceptors ):
@@ -923,7 +926,8 @@ async def call_interceptor(
923926
924927 return await handler (context )
925928
926- async def _run_query_task_core (self , task : dict [str , Any ]) -> str :
929+ async def _run_query_task_core (self , task : dict [str , Any ], * , client : Client | None = None ) -> str :
930+ client = client or self .client
927931 query_task_id : str = task ["query_task_id" ]
928932 attempt : int = task .get ("query_task_attempt" , 1 )
929933 wf_type : str = task .get ("workflow_type" , "" )
@@ -939,6 +943,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
939943 reason = "query_payload_decode_failed" ,
940944 failure_type = type (e ).__name__ ,
941945 stack_trace = traceback .format_exc (),
946+ client = client ,
942947 )
943948 return "failed"
944949
@@ -953,6 +958,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
953958 f"no workflow registered for type { wf_type !r} " ,
954959 reason = "query_workflow_type_not_registered" ,
955960 failure_type = "WorkflowTypeNotRegistered" ,
961+ client = client ,
956962 )
957963 return "failed"
958964
@@ -992,6 +998,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
992998 reason = "query_payload_decode_failed" ,
993999 failure_type = type (e ).__name__ ,
9941000 stack_trace = traceback .format_exc (),
1001+ client = client ,
9951002 )
9961003 return "failed"
9971004 except QueryFailed as e :
@@ -1003,6 +1010,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10031010 reason = reason ,
10041011 failure_type = type (e ).__name__ ,
10051012 stack_trace = traceback .format_exc (),
1013+ client = client ,
10061014 )
10071015 return "failed"
10081016 except Exception as e :
@@ -1013,11 +1021,12 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10131021 reason = "query_rejected" ,
10141022 failure_type = type (e ).__name__ ,
10151023 stack_trace = traceback .format_exc (),
1024+ client = client ,
10161025 )
10171026 return "failed"
10181027
10191028 try :
1020- await self . client .complete_query_task (
1029+ await client .complete_query_task (
10211030 query_task_id = query_task_id ,
10221031 lease_owner = self .worker_id ,
10231032 query_task_attempt = attempt ,
@@ -1044,6 +1053,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10441053 reason = server_reason if server_reason else "query_result_completion_failed" ,
10451054 failure_type = type (e ).__name__ ,
10461055 stack_trace = traceback .format_exc (),
1056+ client = client ,
10471057 )
10481058 return "failed"
10491059 except AvroNotInstalledError as e :
@@ -1057,6 +1067,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10571067 reason = "query_result_encode_failed" ,
10581068 failure_type = type (e ).__name__ ,
10591069 stack_trace = traceback .format_exc (),
1070+ client = client ,
10601071 )
10611072 return "failed"
10621073 except (TypeError , ValueError ) as e :
@@ -1067,6 +1078,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10671078 reason = "query_result_encode_failed" ,
10681079 failure_type = type (e ).__name__ ,
10691080 stack_trace = traceback .format_exc (),
1081+ client = client ,
10701082 )
10711083 return "failed"
10721084 except Exception as e :
@@ -1077,6 +1089,7 @@ async def _run_query_task_core(self, task: dict[str, Any]) -> str:
10771089 reason = "query_result_completion_failed" ,
10781090 failure_type = type (e ).__name__ ,
10791091 stack_trace = traceback .format_exc (),
1092+ client = client ,
10801093 )
10811094 return "failed"
10821095
@@ -1107,9 +1120,11 @@ async def _fail_query_task(
11071120 reason : str ,
11081121 failure_type : str | None = None ,
11091122 stack_trace : str | None = None ,
1123+ client : Client | None = None ,
11101124 ) -> None :
1125+ client = client or self .client
11111126 try :
1112- await self . client .fail_query_task (
1127+ await client .fail_query_task (
11131128 query_task_id = query_task_id ,
11141129 lease_owner = self .worker_id ,
11151130 query_task_attempt = attempt ,
@@ -1217,11 +1232,12 @@ async def _dispatch_activity_task(self, task: dict[str, Any]) -> None:
12171232 self ._record_task_metrics ("activity" , outcome , time .perf_counter () - task_start )
12181233 self ._act_semaphore .release ()
12191234
1220- async def _poll_query_tasks (self ) -> None :
1221- while not self ._stop .is_set ():
1235+ async def _poll_query_tasks (self , * , client : Client | None = None , track_tasks : bool = True ) -> None :
1236+ client = client or self .client
1237+ while not self ._stop .is_set () and not self ._query_thread_stop .is_set ():
12221238 try :
12231239 poll_start = time .perf_counter ()
1224- task = await self . client .poll_query_task (
1240+ task = await client .poll_query_task (
12251241 worker_id = self .worker_id ,
12261242 task_queue = self .task_queue ,
12271243 timeout = self ._poll_timeout ,
@@ -1236,18 +1252,86 @@ async def _poll_query_tasks(self) -> None:
12361252 await asyncio .sleep (0 )
12371253 continue
12381254 self ._record_poll_metrics ("query" , "task" , time .perf_counter () - poll_start )
1239- self ._track (self ._dispatch_query_task (task ))
1255+ if track_tasks :
1256+ self ._track (self ._dispatch_query_task (task , client = client ))
1257+ else :
1258+ await self ._dispatch_query_task (task , client = client )
12401259
1241- async def _dispatch_query_task (self , task : dict [str , Any ]) -> None :
1260+ async def _dispatch_query_task (self , task : dict [str , Any ], * , client : Client | None = None ) -> None :
12421261 task_start = time .perf_counter ()
12431262 outcome = "error"
12441263 try :
1245- outcome = await self ._run_query_task (task )
1264+ outcome = await self ._run_query_task (task , client = client )
12461265 except Exception :
12471266 log .exception ("unhandled error in query task execution" )
12481267 finally :
12491268 self ._record_task_metrics ("query" , outcome , time .perf_counter () - task_start )
12501269
1270+ def _clone_client_for_query_tasks (self ) -> Client :
1271+ warning_config = self ._payload_size_warning_config ()
1272+
1273+ return Client (
1274+ self .client .base_url ,
1275+ token = self .client .token ,
1276+ control_token = self .client .control_token ,
1277+ worker_token = self .client .worker_token ,
1278+ namespace = self .client .namespace ,
1279+ timeout = getattr (self .client , "timeout" , 60.0 ),
1280+ retry_policy = self .client .retry_policy ,
1281+ metrics = self .metrics ,
1282+ payload_size_limit_bytes = (
1283+ warning_config .limit_bytes
1284+ if warning_config is not None
1285+ else serializer .DEFAULT_PAYLOAD_SIZE_BYTES
1286+ ),
1287+ payload_size_warning_threshold_percent = (
1288+ warning_config .threshold_percent
1289+ if warning_config is not None
1290+ else serializer .DEFAULT_WARNING_THRESHOLD_PERCENT
1291+ ),
1292+ payload_size_warnings = warning_config is not None ,
1293+ external_storage = self .external_storage ,
1294+ external_storage_threshold_bytes = self .external_storage_threshold_bytes ,
1295+ external_storage_cache = self .external_storage_cache ,
1296+ )
1297+
1298+ def _can_clone_client_for_query_tasks (self ) -> bool :
1299+ return type (self .client ) is Client
1300+
1301+ def _start_query_task_thread (self ) -> None :
1302+ if self ._query_thread is not None and self ._query_thread .is_alive ():
1303+ return
1304+
1305+ self ._query_thread_stop .clear ()
1306+ self ._query_thread = threading .Thread (
1307+ target = self ._run_query_task_thread ,
1308+ name = f"durable-workflow-query-poller-{ self .worker_id } " ,
1309+ daemon = True ,
1310+ )
1311+ self ._query_thread .start ()
1312+
1313+ def _run_query_task_thread (self ) -> None :
1314+ try :
1315+ asyncio .run (self ._query_task_thread_main ())
1316+ except Exception :
1317+ log .exception ("query task poller thread stopped unexpectedly" )
1318+
1319+ async def _query_task_thread_main (self ) -> None :
1320+ async with self ._clone_client_for_query_tasks () as client :
1321+ await self ._poll_query_tasks (client = client , track_tasks = False )
1322+
1323+ async def _stop_query_task_thread (self ) -> None :
1324+ self ._query_thread_stop .set ()
1325+ thread = self ._query_thread
1326+
1327+ if thread is None or not thread .is_alive ():
1328+ return
1329+
1330+ await asyncio .to_thread (
1331+ thread .join ,
1332+ min (self ._shutdown_timeout , self ._poll_timeout + 1 ),
1333+ )
1334+
12511335 async def run (self ) -> None :
12521336 """Register the worker and poll until `stop()` is called or the task is cancelled."""
12531337 await self ._register ()
@@ -1256,9 +1340,15 @@ async def run(self) -> None:
12561340 hb_loop = asyncio .create_task (self ._heartbeat_loop ())
12571341 loops = [wf_loop , act_loop , hb_loop ]
12581342 if self ._query_tasks_supported :
1259- loops .append (asyncio .create_task (self ._poll_query_tasks ()))
1260- with contextlib .suppress (asyncio .CancelledError ):
1261- await asyncio .gather (* loops )
1343+ if self ._can_clone_client_for_query_tasks ():
1344+ self ._start_query_task_thread ()
1345+ else :
1346+ loops .append (asyncio .create_task (self ._poll_query_tasks ()))
1347+ try :
1348+ with contextlib .suppress (asyncio .CancelledError ):
1349+ await asyncio .gather (* loops )
1350+ finally :
1351+ await self ._stop_query_task_thread ()
12621352
12631353 async def _heartbeat_loop (self ) -> None :
12641354 """Periodically refresh the server-side worker registration.
@@ -1386,7 +1476,10 @@ async def run_until(
13861476 await self ._register ()
13871477 background_tasks .append (asyncio .create_task (self ._heartbeat_loop ()))
13881478 if self ._query_tasks_supported :
1389- background_tasks .append (asyncio .create_task (self ._poll_query_tasks ()))
1479+ if self ._can_clone_client_for_query_tasks ():
1480+ self ._start_query_task_thread ()
1481+ else :
1482+ background_tasks .append (asyncio .create_task (self ._poll_query_tasks ()))
13901483
13911484 while True :
13921485 desc = await self .client .describe_workflow (workflow_id )
@@ -1465,6 +1558,7 @@ async def run_until(
14651558 async def stop (self ) -> None :
14661559 """Stop polling and drain in-flight tasks up to the configured shutdown timeout."""
14671560 self ._stop .set ()
1561+ await self ._stop_query_task_thread ()
14681562 if self ._in_flight :
14691563 log .info ("draining %d in-flight task(s)…" , len (self ._in_flight ))
14701564 done , pending = await asyncio .wait (
0 commit comments