diff --git a/src/seclab_taskflow_agent/__main__.py b/src/seclab_taskflow_agent/__main__.py index 4f96cea..7b1d865 100644 --- a/src/seclab_taskflow_agent/__main__.py +++ b/src/seclab_taskflow_agent/__main__.py @@ -18,6 +18,8 @@ # from agents.run import DEFAULT_MAX_TURNS # XXX: this is 10, we need more than that from agents.exceptions import AgentsException, MaxTurnsExceeded from agents.extensions.handoff_prompt import prompt_with_handoff_instructions +from agents import Tool, RunContextWrapper, TContext, Agent +from openai import BadRequestError, APITimeoutError, RateLimitError, APIStatusError from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter from dotenv import find_dotenv, load_dotenv from openai import APITimeoutError, BadRequestError, RateLimitError @@ -350,7 +352,16 @@ async def _run_streamed(): return except APITimeoutError: if not max_retry: - logging.exception("Max retries for APITimeoutError reached") + logging.error(f"Max API retries reached") + raise + max_retry -= 1 + except APIStatusError as e: + # Retry transient “client closed request / upstream cancelled” style errors + if getattr(e, "status_code", None) != 499: + raise # propagate non-499 errors + # 499: retry + if not max_retry: + logging.error(f"Max API retries reached") raise max_retry -= 1 except RateLimitError: @@ -377,8 +388,15 @@ async def _run_streamed(): await render_model_output(f"** 🤖❗ Request Error: {e}\n", async_task=async_task, task_id=task_id) logging.exception("Bad Request") except APITimeoutError as e: - await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", async_task=async_task, task_id=task_id) - logging.exception("Bad Request") + await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", + async_task=async_task, + task_id=task_id) + logging.error(f"Bad Request: {e}") + except APIStatusError as e: + await render_model_output(f"** 🤖❗ API Status Error: {e}\n", + async_task=async_task, + task_id=task_id) + logging.error(f"API Status Error: Status={e.status_code}, Response={e.response}") if async_task: await flush_async_output(task_id)