Skip to content

Commit b2de890

Browse files
committed
token refresh logic
1 parent 64f11d2 commit b2de890

4 files changed

Lines changed: 295 additions & 34 deletions

File tree

examples/asyncio_workers.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -91,12 +91,12 @@ async def fetch_user(user_id: str) -> dict:
9191
"""
9292
try:
9393
import httpx
94-
print(f'fetching user {user_id}')
94+
# print(f'fetching user {user_id}')
9595
async with httpx.AsyncClient() as client:
9696
response = await client.get(
9797
f'https://jsonplaceholder.typicode.com/users/{user_id}'
9898
)
99-
print(f'response {response.json()}')
99+
# print(f'response {response.json()}')
100100
return response.json()
101101

102102
except Exception as e:
@@ -111,12 +111,12 @@ async def process_user(user: User) -> dict:
111111
"""
112112
try:
113113
import httpx
114-
print(f'fetching user details for {user.id}')
114+
# print(f'fetching user details for {user.id}')
115115
async with httpx.AsyncClient() as client:
116116
response = await client.get(
117117
f'https://jsonplaceholder.typicode.com/users/{user.id + 1}'
118118
)
119-
print(f'response {response.json()}')
119+
# print(f'response {response.json()}')
120120
return response.json()
121121

122122
except Exception as e:

src/conductor/client/automator/task_runner.py

Lines changed: 39 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,10 @@ def __init__(
4747
)
4848
)
4949

50+
# Auth failure backoff tracking to prevent retry storms
51+
self._auth_failures = 0
52+
self._last_auth_failure = 0
53+
5054
def run(self) -> None:
5155
if self.configuration is not None:
5256
self.configuration.apply_logging_config()
@@ -80,6 +84,19 @@ def __poll_task(self) -> Task:
8084
if self.worker.paused():
8185
logger.debug("Stop polling task for: %s", task_definition_name)
8286
return None
87+
88+
# Apply exponential backoff if we have recent auth failures
89+
if self._auth_failures > 0:
90+
now = time.time()
91+
# Exponential backoff: 2^failures seconds (2s, 4s, 8s, 16s, 32s)
92+
backoff_seconds = min(2 ** self._auth_failures, 60) # Cap at 60s
93+
time_since_last_failure = now - self._last_auth_failure
94+
95+
if time_since_last_failure < backoff_seconds:
96+
# Still in backoff period - skip polling
97+
time.sleep(0.1) # Small sleep to prevent tight loop
98+
return None
99+
83100
if self.metrics_collector is not None:
84101
self.metrics_collector.increment_task_poll(
85102
task_definition_name
@@ -97,12 +114,25 @@ def __poll_task(self) -> Task:
97114
if self.metrics_collector is not None:
98115
self.metrics_collector.record_task_poll_time(task_definition_name, time_spent)
99116
except AuthorizationException as auth_exception:
117+
# Track auth failure for backoff
118+
self._auth_failures += 1
119+
self._last_auth_failure = time.time()
120+
backoff_seconds = min(2 ** self._auth_failures, 60)
121+
100122
if self.metrics_collector is not None:
101123
self.metrics_collector.increment_task_poll_error(task_definition_name, type(auth_exception))
124+
102125
if auth_exception.invalid_token:
103-
logger.fatal(f"failed to poll task {task_definition_name} due to invalid auth token")
126+
logger.error(
127+
f"Failed to poll task {task_definition_name} due to invalid auth token "
128+
f"(failure #{self._auth_failures}). Will retry with exponential backoff ({backoff_seconds}s). "
129+
"Please check your CONDUCTOR_AUTH_KEY and CONDUCTOR_AUTH_SECRET."
130+
)
104131
else:
105-
logger.fatal(f"failed to poll task {task_definition_name} error: {auth_exception.status} - {auth_exception.error_code}")
132+
logger.error(
133+
f"Failed to poll task {task_definition_name} error: {auth_exception.status} - {auth_exception.error_code} "
134+
f"(failure #{self._auth_failures}). Will retry with exponential backoff ({backoff_seconds}s)."
135+
)
106136
return None
107137
except Exception as e:
108138
if self.metrics_collector is not None:
@@ -113,13 +143,20 @@ def __poll_task(self) -> Task:
113143
traceback.format_exc()
114144
)
115145
return None
146+
147+
# Success - reset auth failure counter
116148
if task is not None:
149+
self._auth_failures = 0
117150
logger.debug(
118151
"Polled task: %s, worker_id: %s, domain: %s",
119152
task_definition_name,
120153
self.worker.get_identity(),
121154
self.worker.get_domain()
122155
)
156+
else:
157+
# No task available - also reset auth failures since poll succeeded
158+
self._auth_failures = 0
159+
123160
return task
124161

125162
def __execute_task(self, task: Task) -> TaskResult:

src/conductor/client/automator/task_runner_asyncio.py

Lines changed: 149 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,10 @@ def __init__(
116116
# Track background tasks for proper cleanup
117117
self._background_tasks = set()
118118

119+
# Auth failure backoff tracking to prevent retry storms
120+
self._auth_failures = 0
121+
self._last_auth_failure = 0
122+
119123
self._running = False
120124
self._owns_client = http_client is None
121125

@@ -131,9 +135,9 @@ def _get_auth_headers(self) -> dict:
131135
if self.configuration.authentication_settings is None:
132136
return headers
133137

134-
# Use ApiClient's private method to get auth headers
138+
# Use ApiClient's method to get auth headers
135139
# This handles token generation and refresh automatically
136-
auth_headers = self._api_client._ApiClient__get_authentication_headers()
140+
auth_headers = self._api_client.get_authentication_headers()
137141

138142
if auth_headers and 'header' in auth_headers:
139143
headers.update(auth_headers['header'])
@@ -243,6 +247,18 @@ async def _poll_task(self) -> Optional[Task]:
243247
logger.debug("Worker paused for: %s", task_definition_name)
244248
return None
245249

250+
# Apply exponential backoff if we have recent auth failures
251+
if self._auth_failures > 0:
252+
now = time.time()
253+
# Exponential backoff: 2^failures seconds (2s, 4s, 8s, 16s, 32s)
254+
backoff_seconds = min(2 ** self._auth_failures, 60) # Cap at 60s
255+
time_since_last_failure = now - self._last_auth_failure
256+
257+
if time_since_last_failure < backoff_seconds:
258+
# Still in backoff period - skip polling
259+
await asyncio.sleep(0.1) # Small sleep to prevent tight loop
260+
return None
261+
246262
if self.metrics_collector is not None:
247263
self.metrics_collector.increment_task_poll(task_definition_name)
248264

@@ -283,22 +299,88 @@ async def _poll_task(self) -> Optional[Task]:
283299
# Convert to Task object using cached ApiClient
284300
task = self._api_client.deserialize_class(task_data, Task) if task_data else None
285301

302+
# Success - reset auth failure counter
286303
if task is not None:
304+
self._auth_failures = 0
287305
logger.debug(
288306
"Polled task: %s, worker_id: %s, domain: %s",
289307
task_definition_name,
290308
self.worker.get_identity(),
291309
self.worker.get_domain()
292310
)
311+
else:
312+
# No task available (204) - also reset auth failures
313+
self._auth_failures = 0
293314

294315
return task
295316

296317
except httpx.HTTPStatusError as e:
297318
if e.response.status_code == 401:
298-
logger.fatal(
299-
"Authentication failed for task %s: %s",
300-
task_definition_name, e
301-
)
319+
# Check if this is a token expiry/invalid token (renewable) vs invalid credentials
320+
error_code = None
321+
try:
322+
response_data = e.response.json()
323+
error_code = response_data.get('error', '')
324+
except Exception:
325+
pass
326+
327+
# If token is expired or invalid, try to renew it
328+
if error_code in ('EXPIRED_TOKEN', 'INVALID_TOKEN'):
329+
token_status = "expired" if error_code == 'EXPIRED_TOKEN' else "invalid"
330+
logger.info(
331+
"Authentication token is %s, renewing token... (task: %s)",
332+
token_status,
333+
task_definition_name
334+
)
335+
336+
# Force token refresh (skip backoff - this is a legitimate renewal)
337+
success = self._api_client.force_refresh_auth_token()
338+
339+
if success:
340+
logger.info('Authentication token successfully renewed')
341+
# Retry the poll request with new token
342+
try:
343+
headers = self._get_auth_headers()
344+
response = await self.http_client.get(
345+
f"/tasks/poll/{task_definition_name}",
346+
params=params,
347+
headers=headers if headers else None
348+
)
349+
350+
if response.status_code == 204:
351+
return None
352+
353+
response.raise_for_status()
354+
task_data = response.json()
355+
task = self._api_client.deserialize_class(task_data, Task) if task_data else None
356+
357+
# Success - reset auth failures
358+
self._auth_failures = 0
359+
return task
360+
except Exception as retry_error:
361+
logger.error(
362+
"Failed to poll task %s after token renewal: %s",
363+
task_definition_name,
364+
retry_error
365+
)
366+
return None
367+
else:
368+
logger.error('Failed to renew authentication token')
369+
else:
370+
# Not a token expiry - invalid credentials, apply backoff
371+
self._auth_failures += 1
372+
self._last_auth_failure = time.time()
373+
backoff_seconds = min(2 ** self._auth_failures, 60)
374+
375+
logger.error(
376+
"Authentication failed for task %s (failure #%d): %s. "
377+
"Will retry with exponential backoff (%ds). "
378+
"Please check your CONDUCTOR_AUTH_KEY and CONDUCTOR_AUTH_SECRET.",
379+
task_definition_name,
380+
self._auth_failures,
381+
e,
382+
backoff_seconds
383+
)
302384
else:
303385
logger.error(
304386
"HTTP error polling task %s: %s",
@@ -615,6 +697,67 @@ async def _update_task(self, task_result: TaskResult) -> Optional[str]:
615697

616698
return result
617699

700+
except httpx.HTTPStatusError as e:
701+
# Handle 401 authentication errors specially
702+
if e.response.status_code == 401:
703+
# Check if this is a token expiry/invalid token (renewable) vs invalid credentials
704+
error_code = None
705+
try:
706+
response_data = e.response.json()
707+
error_code = response_data.get('error', '')
708+
except Exception:
709+
pass
710+
711+
# If token is expired or invalid, try to renew it and retry
712+
if error_code in ('EXPIRED_TOKEN', 'INVALID_TOKEN'):
713+
token_status = "expired" if error_code == 'EXPIRED_TOKEN' else "invalid"
714+
logger.info(
715+
"Authentication token is %s, renewing token... (updating task: %s)",
716+
token_status,
717+
task_result.task_id
718+
)
719+
720+
# Force token refresh (skip backoff - this is a legitimate renewal)
721+
success = self._api_client.force_refresh_auth_token()
722+
723+
if success:
724+
logger.info('Authentication token successfully renewed, retrying update')
725+
# Retry the update request with new token once
726+
try:
727+
headers = self._get_auth_headers()
728+
response = await self.http_client.post(
729+
"/tasks",
730+
json=task_result_dict,
731+
headers=headers if headers else None
732+
)
733+
response.raise_for_status()
734+
return response.text
735+
except Exception as retry_error:
736+
logger.error(
737+
"Failed to update task after token renewal: %s",
738+
retry_error
739+
)
740+
# Continue to retry loop
741+
else:
742+
logger.error('Failed to renew authentication token')
743+
# Continue to retry loop
744+
745+
# Fall through to generic exception handling for retries
746+
if self.metrics_collector is not None:
747+
self.metrics_collector.increment_task_update_error(
748+
task_definition_name, type(e)
749+
)
750+
751+
logger.error(
752+
"Failed to update task (attempt %d/4), id: %s, "
753+
"workflow_instance_id: %s, task_definition_name: %s, reason: %s",
754+
attempt + 1,
755+
task_result.task_id,
756+
task_result.workflow_instance_id,
757+
task_definition_name,
758+
traceback.format_exc()
759+
)
760+
618761
except Exception as e:
619762
if self.metrics_collector is not None:
620763
self.metrics_collector.increment_task_update_error(

0 commit comments

Comments
 (0)