@@ -82,9 +82,25 @@ def __init__(
8282 )
8383 )
8484
85- # Auth failure backoff tracking to prevent retry storms
85+ # Auth failure backoff tracking to prevent retry storms.
86+ # `_auth_failures` is capped at `_max_auth_failure_exp` so that
87+ # 2**N cannot overflow on a long-lived worker whose auth is broken.
88+ # The resulting sleep is further clamped to `_auth_backoff_cap_seconds`.
8689 self ._auth_failures = 0
8790 self ._last_auth_failure = 0
91+ self ._auth_backoff_cap_seconds = 60
92+ self ._max_auth_failure_exp = 6 # 2**6 = 64s, sleep clamped to cap
93+
94+ # Generic poll-failure backoff. This is distinct from the empty-poll
95+ # adaptive delay (`_consecutive_empty_polls`) and from the auth-error
96+ # backoff above. It kicks in when batch_poll raises an exception
97+ # (server 5xx, NGINX 502/504 under load, DNS hiccup, a closed httpx
98+ # client that couldn't heal, etc.) so we don't hot-loop the log with
99+ # stack traces while waiting for the server to recover.
100+ self ._poll_failures = 0
101+ self ._last_poll_failure = 0
102+ self ._poll_backoff_cap_seconds = 120 # max 2 minutes between retries
103+ self ._max_poll_failure_exp = 7 # 2**7 = 128s, sleep clamped to cap
88104
89105 # Thread pool for concurrent task execution
90106 # thread_count from worker configuration controls concurrency
@@ -567,15 +583,33 @@ def __batch_poll_tasks(self, count: int) -> list:
567583 logger .debug ("Stop polling task for: %s" , task_definition_name )
568584 return []
569585
570- # Apply exponential backoff if we have recent auth failures
586+ # Apply exponential backoff if we have recent auth failures.
571587 if self ._auth_failures > 0 :
572588 now = time .time ()
573- backoff_seconds = min (2 ** self ._auth_failures , 60 )
589+ backoff_seconds = min (
590+ 2 ** min (self ._auth_failures , self ._max_auth_failure_exp ),
591+ self ._auth_backoff_cap_seconds ,
592+ )
574593 time_since_last_failure = now - self ._last_auth_failure
575594 if time_since_last_failure < backoff_seconds :
576595 time .sleep (0.1 )
577596 return []
578597
598+ # Apply exponential backoff for generic poll failures (5xx, network
599+ # errors, closed-client runtime errors that couldn't self-heal, etc.).
600+ # Bounded at `_poll_backoff_cap_seconds` (2 min) to avoid log floods
601+ # without giving up on recovery.
602+ if self ._poll_failures > 0 :
603+ now = time .time ()
604+ backoff_seconds = min (
605+ 2 ** min (self ._poll_failures , self ._max_poll_failure_exp ),
606+ self ._poll_backoff_cap_seconds ,
607+ )
608+ time_since_last_failure = now - self ._last_poll_failure
609+ if time_since_last_failure < backoff_seconds :
610+ time .sleep (0.1 )
611+ return []
612+
579613 # Publish PollStarted event (metrics collector will handle via event)
580614 self .event_dispatcher .publish (PollStarted (
581615 task_type = task_definition_name ,
@@ -607,15 +641,20 @@ def __batch_poll_tasks(self, count: int) -> list:
607641 tasks_received = len (tasks ) if tasks else 0
608642 ))
609643
610- # Success - reset auth failure counter (any successful HTTP response means auth is working)
644+ # Success - reset both failure counters (any successful HTTP
645+ # response means auth and connectivity are working).
611646 self ._auth_failures = 0
647+ self ._poll_failures = 0
612648
613649 return tasks if tasks else []
614650
615651 except AuthorizationException as auth_exception :
616652 self ._auth_failures += 1
617653 self ._last_auth_failure = time .time ()
618- backoff_seconds = min (2 ** self ._auth_failures , 60 )
654+ backoff_seconds = min (
655+ 2 ** min (self ._auth_failures , self ._max_auth_failure_exp ),
656+ self ._auth_backoff_cap_seconds ,
657+ )
619658
620659 # Publish PollFailure event (metrics collector will handle via event)
621660 self .event_dispatcher .publish (PollFailure (
@@ -643,10 +682,55 @@ def __batch_poll_tasks(self, count: int) -> list:
643682 duration_ms = (time .time () - start_time ) * 1000 ,
644683 cause = e
645684 ))
646- logger .error (
647- "Failed to batch poll task for: %s, reason: %s" ,
685+
686+ # Bump the poll-failure counter so the next poll waits with
687+ # exponential backoff instead of hot-looping on a broken server
688+ # or connection.
689+ self ._poll_failures += 1
690+ self ._last_poll_failure = time .time ()
691+ backoff_seconds = min (
692+ 2 ** min (self ._poll_failures , self ._max_poll_failure_exp ),
693+ self ._poll_backoff_cap_seconds ,
694+ )
695+
696+ # Belt-and-suspenders: if the underlying httpx client got closed
697+ # and rest.request() couldn't heal it (e.g. because the error
698+ # arrived as a non-RuntimeError), nudge it here. Pass the current
699+ # connection as `expected` so concurrent threads racing to heal
700+ # can't cause a reset storm: only the first caller per client
701+ # generation actually replaces it.
702+ try :
703+ rest_client = getattr (
704+ getattr (self .task_client , "api_client" , None ),
705+ "rest_client" ,
706+ None ,
707+ )
708+ if rest_client is not None and getattr (rest_client , "_is_client_closed" , lambda : False )():
709+ current_conn = getattr (rest_client , "connection" , None )
710+ reset = rest_client ._reset_connection (expected = current_conn )
711+ if reset :
712+ logger .warning (
713+ "rest_client was closed after poll failure; reset"
714+ )
715+ except Exception :
716+ # Healing is best-effort; never let it mask the original error.
717+ pass
718+
719+ # Log a single-line warning at a modest level to avoid drowning
720+ # ops in tracebacks when the server is flapping. Full traceback
721+ # goes to debug for when operators need it.
722+ logger .warning (
723+ "Failed to batch poll task for: %s (failure #%d). Will retry with exponential backoff (%ss). Reason: %s: %s" ,
648724 task_definition_name ,
649- traceback .format_exc ()
725+ self ._poll_failures ,
726+ backoff_seconds ,
727+ type (e ).__name__ ,
728+ e ,
729+ )
730+ logger .debug (
731+ "batch poll failure traceback for %s:\n %s" ,
732+ task_definition_name ,
733+ traceback .format_exc (),
650734 )
651735 return []
652736
@@ -915,15 +999,33 @@ def __update_task(self, task_result: TaskResult):
915999 self .metrics_collector .increment_task_update_error (
9161000 task_definition_name , type (e )
9171001 )
918- logger .error (
919- "Failed to update task (attempt %d/%d), id: %s, workflow_instance_id: %s, task_definition_name: %s, reason: %s" ,
920- attempt + 1 ,
921- retry_count ,
922- task_result .task_id ,
923- task_result .workflow_instance_id ,
924- task_definition_name ,
925- traceback .format_exc ()
926- )
1002+ is_last_attempt = (attempt + 1 ) >= retry_count
1003+ # Known recoverable transport hiccups (stale keep-alive,
1004+ # HTTP/2 GOAWAY race, client closed mid-request) are flagged
1005+ # `transient=True` by the REST layer after it self-heals. For
1006+ # those, skip the stack trace until the final attempt — the
1007+ # retry normally succeeds immediately and a full traceback per
1008+ # in-flight task just spams the log.
1009+ if getattr (e , "transient" , False ) and not is_last_attempt :
1010+ logger .warning (
1011+ "Transient transport error updating task; will retry (attempt %d/%d), id: %s, workflow_instance_id: %s, task_definition_name: %s, reason: %s" ,
1012+ attempt + 1 ,
1013+ retry_count ,
1014+ task_result .task_id ,
1015+ task_result .workflow_instance_id ,
1016+ task_definition_name ,
1017+ getattr (e , "reason" , None ) or str (e ),
1018+ )
1019+ else :
1020+ logger .error (
1021+ "Failed to update task (attempt %d/%d), id: %s, workflow_instance_id: %s, task_definition_name: %s, reason: %s" ,
1022+ attempt + 1 ,
1023+ retry_count ,
1024+ task_result .task_id ,
1025+ task_result .workflow_instance_id ,
1026+ task_definition_name ,
1027+ traceback .format_exc ()
1028+ )
9271029 continue
9281030 except Exception as e :
9291031 last_exception = e
0 commit comments