@@ -580,6 +580,30 @@ def _server_supports_query_tasks(info: dict[str, Any]) -> bool:
580580 return isinstance (capabilities , dict ) and capabilities .get ("query_tasks" ) is True
581581
582582
583+ def _server_long_poll_timeout (info : dict [str , Any ]) -> float | None :
584+ worker_protocol = info .get ("worker_protocol" )
585+ if not isinstance (worker_protocol , dict ):
586+ return None
587+
588+ capabilities = worker_protocol .get ("server_capabilities" )
589+ if not isinstance (capabilities , dict ):
590+ return None
591+
592+ timeout = capabilities .get ("long_poll_timeout" )
593+ if isinstance (timeout , bool ):
594+ return None
595+ if isinstance (timeout , (int , float )):
596+ return float (timeout ) if timeout > 0 else None
597+ if isinstance (timeout , str ):
598+ try :
599+ parsed = float (timeout )
600+ except ValueError :
601+ return None
602+ return parsed if parsed > 0 else None
603+
604+ return None
605+
606+
583607def _contract_version_matches (value : Any , expected : int ) -> bool :
584608 if isinstance (value , int ):
585609 return value == expected
@@ -635,6 +659,7 @@ def __init__(
635659 raise ValueError ("heartbeat_interval must be positive" )
636660
637661 self ._poll_timeout = poll_timeout
662+ self ._poll_http_timeout = poll_timeout
638663 self .max_concurrent_workflow_tasks = max_concurrent_workflow_tasks
639664 self .max_concurrent_activity_tasks = max_concurrent_activity_tasks
640665 self ._stop = asyncio .Event ()
@@ -757,6 +782,9 @@ async def _register(self) -> None:
757782
758783 _validate_server_compatibility (info )
759784 self ._query_tasks_supported = _server_supports_query_tasks (info )
785+ server_long_poll_timeout = _server_long_poll_timeout (info )
786+ if server_long_poll_timeout is not None :
787+ self ._poll_http_timeout = max (self ._poll_http_timeout , server_long_poll_timeout + 5.0 )
760788 log .debug (
761789 "server compatibility accepted: app_version=%s control_plane=%s worker_protocol=%s" ,
762790 info .get ("version" , "unknown" ),
@@ -1536,7 +1564,7 @@ async def _poll_workflow_tasks(self) -> None:
15361564 task = await self .client .poll_workflow_task (
15371565 worker_id = self .worker_id ,
15381566 task_queue = self .task_queue ,
1539- timeout = self ._poll_timeout ,
1567+ timeout = self ._poll_http_timeout ,
15401568 )
15411569 except Exception as e :
15421570 self ._wf_semaphore .release ()
@@ -1613,7 +1641,7 @@ async def _poll_activity_tasks(self) -> None:
16131641 task = await self .client .poll_activity_task (
16141642 worker_id = self .worker_id ,
16151643 task_queue = self .task_queue ,
1616- timeout = self ._poll_timeout ,
1644+ timeout = self ._poll_http_timeout ,
16171645 )
16181646 except Exception as e :
16191647 self ._act_semaphore .release ()
@@ -1650,7 +1678,7 @@ async def _poll_query_tasks(self, *, client: Client | None = None, track_tasks:
16501678 task = await client .poll_query_task (
16511679 worker_id = self .worker_id ,
16521680 task_queue = self .task_queue ,
1653- timeout = self ._poll_timeout ,
1681+ timeout = self ._poll_http_timeout ,
16541682 )
16551683 except Exception as e :
16561684 self ._record_poll_metrics ("query" , "error" , time .perf_counter () - poll_start )
@@ -1739,7 +1767,7 @@ async def _stop_query_task_thread(self) -> None:
17391767
17401768 await asyncio .to_thread (
17411769 thread .join ,
1742- min (self ._shutdown_timeout , self ._poll_timeout + 1 ),
1770+ min (self ._shutdown_timeout , self ._poll_http_timeout + 1 ),
17431771 )
17441772
17451773 async def run (self ) -> None :
@@ -1908,7 +1936,7 @@ async def run_until(
19081936 task = await self .client .poll_workflow_task (
19091937 worker_id = self .worker_id ,
19101938 task_queue = self .task_queue ,
1911- timeout = self ._poll_timeout ,
1939+ timeout = self ._poll_http_timeout ,
19121940 )
19131941 self ._record_poll_metrics (
19141942 "workflow" ,
@@ -1938,7 +1966,7 @@ async def run_until(
19381966 task = await self .client .poll_activity_task (
19391967 worker_id = self .worker_id ,
19401968 task_queue = self .task_queue ,
1941- timeout = self ._poll_timeout ,
1969+ timeout = self ._poll_http_timeout ,
19421970 )
19431971 self ._record_poll_metrics (
19441972 "activity" ,
0 commit comments