diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index a26222a..8ddbad0 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -24,6 +24,7 @@ from agents.run import DEFAULT_MAX_TURNS from dotenv import find_dotenv, load_dotenv from openai import AsyncOpenAI +import httpx from .capi import get_AI_endpoint, get_AI_token, get_provider @@ -182,6 +183,7 @@ def __init__( base_url=resolved_endpoint, api_key=resolved_token, default_headers=provider.extra_headers or None, + timeout=httpx.Timeout(connect=10.0, read=300.0, write=300.0, pool=60.0), ) set_tracing_disabled(True) self.run_hooks = run_hooks or TaskRunHooks() @@ -198,6 +200,7 @@ def _ToolsToFinalOutputFunction( else: model_impl = OpenAIChatCompletionsModel(model=model, openai_client=client) + self._openai_client = client self.agent = Agent( name=name, instructions=instructions, @@ -209,6 +212,11 @@ def _ToolsToFinalOutputFunction( hooks=agent_hooks or TaskAgentHooks(), ) + async def close(self) -> None: + """Close the underlying AsyncOpenAI client and its httpx connection pool.""" + if self._openai_client is not None: + await self._openai_client.close() + async def run(self, prompt: str, max_turns: int = DEFAULT_MAX_TURNS) -> result.RunResult: """Run the agent to completion and return the result.""" return await Runner.run(starting_agent=self.agent, input=prompt, max_turns=max_turns, hooks=self.run_hooks) diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 5869385..b35b833 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -23,6 +23,9 @@ import json import logging import os +import sys +import threading +import time import uuid from typing import Any @@ -51,6 +54,57 @@ MAX_API_RETRY = 5 # Maximum number of consecutive API error retries TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between 
# Application-level backstop: abort a streaming run if no events are yielded
# for 30 minutes. Complements the TCP-level httpx.Timeout(read=300s) in
# agent.py, which catches dead sockets; this catches subtler hangs where the
# connection stays open but the server (or async generator) stops producing
# events.
STREAM_IDLE_TIMEOUT = 1800


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int.

    Falls back to *default* when the variable is unset or is not a valid
    integer, so a typo in the environment cannot crash the module at import
    time.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        logging.warning("Ignoring invalid %s=%r; using default of %s seconds", name, raw, default)
        return default


# Watchdog: a non-asyncio thread that force-kills the process if the event
# loop stops making progress. Covers every hang variant (dead connections,
# asyncio cleanup spin, MCP cleanup, etc.) because it runs outside asyncio.
# A value <= 0 disables the watchdog (see start_watchdog).
WATCHDOG_IDLE_TIMEOUT = _env_int("WATCHDOG_IDLE_TIMEOUT", 2100)  # 35 min default

# Monotonic timestamp of the most recent sign of life, and the lock that
# guards both it and the started flag below.
_watchdog_last_activity = time.monotonic()
_watchdog_lock = threading.Lock()
_watchdog_started = False


def watchdog_ping() -> None:
    """Call from any coroutine/callback to signal the process is alive."""
    global _watchdog_last_activity
    with _watchdog_lock:
        _watchdog_last_activity = time.monotonic()


def _watchdog_thread(timeout: int) -> None:
    """Background thread body: force-exit after *timeout* idle seconds.

    Wakes up every ``timeout // 5`` seconds (clamped to 1..60) and compares
    the time since the last ``watchdog_ping()`` against *timeout*. When the
    limit is exceeded the process is terminated with exit status 2 via
    ``os._exit``, which does not depend on the (possibly stuck) event loop.
    """
    check_interval = min(60, max(1, timeout // 5))
    while True:
        time.sleep(check_interval)
        with _watchdog_lock:
            idle = time.monotonic() - _watchdog_last_activity
        if idle <= timeout:
            continue
        try:
            logging.error(
                "Watchdog: no activity for %.0fs (limit %ss) — force-exiting to prevent hang",
                idle,
                timeout,
            )
            sys.stderr.flush()
            sys.stdout.flush()
        finally:
            # Logging or flushing can raise (e.g. a closed or missing
            # stdout/stderr). The exit must still happen, otherwise the
            # watchdog thread dies and the hang it exists for goes unnoticed.
            os._exit(2)


def start_watchdog(timeout: int = WATCHDOG_IDLE_TIMEOUT) -> None:
    """Start the watchdog thread (idempotent, daemon thread).

    Args:
        timeout: Idle seconds tolerated before the process is force-exited.
            A value <= 0 disables the watchdog: no thread is started, since
            such a limit would otherwise kill the process on the first check.
    """
    global _watchdog_started, _watchdog_last_activity
    if timeout <= 0:
        logging.warning("Watchdog disabled: non-positive timeout %s", timeout)
        return
    # Check-and-set under the lock so concurrent callers cannot both spawn a
    # thread. The timestamp is reset here directly (watchdog_ping() would
    # deadlock on the non-reentrant lock) so late callers don't trigger
    # immediately.
    with _watchdog_lock:
        if _watchdog_started:
            return
        _watchdog_started = True
        _watchdog_last_activity = time.monotonic()
    threading.Thread(
        target=_watchdog_thread,
        args=(timeout,),
        name="watchdog",
        daemon=True,
    ).start()
servers are connected!") + agent0: TaskAgent | None = None + handoff_agents: list[TaskAgent] = [] + try: important_guidelines = [ "Do not prompt the user with questions.", @@ -334,29 +391,29 @@ async def deploy_task_agents( agent_names = list(agents.keys()) for handoff_name in agent_names[1:]: personality = agents[handoff_name] - handoffs.append( - TaskAgent( - name=compress_name(handoff_name), - instructions=prompt_with_handoff_instructions( - mcp_system_prompt( - personality.personality, - personality.task, - server_prompts=server_prompts, - important_guidelines=important_guidelines, - ) - ), - handoffs=[], - exclude_from_context=exclude_from_context, - mcp_servers=[e.server for e in entries], - model=model, - model_settings=model_settings, - api_type=api_type, - endpoint=endpoint, - token=token, - run_hooks=run_hooks, - agent_hooks=agent_hooks, - ).agent + ta = TaskAgent( + name=compress_name(handoff_name), + instructions=prompt_with_handoff_instructions( + mcp_system_prompt( + personality.personality, + personality.task, + server_prompts=server_prompts, + important_guidelines=important_guidelines, + ) + ), + handoffs=[], + exclude_from_context=exclude_from_context, + mcp_servers=[e.server for e in entries], + model=model, + model_settings=model_settings, + api_type=api_type, + endpoint=endpoint, + token=token, + run_hooks=run_hooks, + agent_hooks=agent_hooks, ) + handoff_agents.append(ta) + handoffs.append(ta.agent) # Create primary agent primary_name = agent_names[0] @@ -389,11 +446,44 @@ async def _run_streamed() -> None: max_retry = MAX_API_RETRY rate_limit_backoff = RATE_LIMIT_BACKOFF while rate_limit_backoff: + result = None try: result = agent0.run_streamed(prompt, max_turns=max_turns) - async for event in result.stream_events(): - if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): - await render_model_output(event.data.delta, async_task=async_task, task_id=task_id) + stream = None + try: + stream = 
result.stream_events() + async_iter = stream.__aiter__() + while True: + try: + event = await asyncio.wait_for( + async_iter.__anext__(), + timeout=STREAM_IDLE_TIMEOUT, + ) + except StopAsyncIteration: + break + except asyncio.TimeoutError: + logging.error( + f"Stream idle for {STREAM_IDLE_TIMEOUT}s — " + "connection likely dead, raising APITimeoutError" + ) + raise APITimeoutError("Stream idle timeout exceeded") + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + watchdog_ping() + await render_model_output(event.data.delta, async_task=async_task, task_id=task_id) + finally: + if stream is not None: + aclose = getattr(stream, "aclose", None) + if aclose is not None: + try: + await aclose() + except Exception: + logging.exception("Failed to close streamed response") + # Cancel the RunResultStreaming background tasks. + # aclose() on the stream_events() async generator throws + # GeneratorExit which skips _cleanup_tasks(), so we must + # cancel explicitly to avoid leaking _run_impl_task. + if result is not None: + result.cancel() await render_model_output("\n\n", async_task=async_task, task_id=task_id) return except APITimeoutError: @@ -433,6 +523,19 @@ async def _run_streamed() -> None: return complete finally: + # Close all AsyncOpenAI clients to release httpx connection pools. + # Dead CLOSE_WAIT sockets in the pool cause kqueue CPU spin if left open. 
+ watchdog_ping() + for ta in handoff_agents: + try: + await ta.close() + except Exception: + logging.exception("Exception closing handoff agent client") + if agent0 is not None: + try: + await agent0.close() + except Exception: + logging.exception("Exception closing primary agent client") start_cleanup.set() cleanup_attempts_left = len(entries) while cleanup_attempts_left and entries: @@ -443,6 +546,21 @@ async def _run_streamed() -> None: continue except Exception: logging.exception("Exception in mcp server cleanup task") + # Cancel the MCP session task if it's still running to prevent + # the asyncio event loop from spinning on a dangling task. + if not mcp_sessions.done(): + mcp_sessions.cancel() + try: + await asyncio.wait_for(mcp_sessions, timeout=MCP_CLEANUP_TIMEOUT) + except asyncio.TimeoutError: + logging.warning( + "Timed out waiting for MCP session task cancellation after %s seconds", + MCP_CLEANUP_TIMEOUT, + ) + except asyncio.CancelledError: + pass + except Exception: + logging.exception("Exception while waiting for MCP session task cancellation") async def run_main( @@ -465,12 +583,18 @@ async def run_main( """ from .session import TaskflowSession + # Start the watchdog thread — if the process hangs for any reason + # (asyncio spin, dead connections, MCP cleanup), this kills it. 
+ start_watchdog() + last_mcp_tool_results: list[str] = [] async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None: + watchdog_ping() last_mcp_tool_results.append(result) async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: + watchdog_ping() await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n") async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]) -> None: @@ -731,3 +855,9 @@ async def _deploy(ra: dict, pp: str) -> bool: if session is not None and not session.error: session.mark_finished() await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n") + + # Force-exit to prevent asyncio event loop spin on dangling + # tasks/connections from the responses API path. Flush first. + sys.stdout.flush() + sys.stderr.flush() + os._exit(0 if (session is None or session.finished) else 1) diff --git a/tests/test_runner.py b/tests/test_runner.py index a50c0f2..80cf3d2 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -7,6 +7,8 @@ import asyncio import json +import threading +import time from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -24,6 +26,8 @@ _merge_reusable_task, _resolve_model_config, _resolve_task_model, + start_watchdog, + watchdog_ping, ) @@ -441,3 +445,192 @@ def test_raises_type_error_on_non_iterable_result(self): inputs={}, ) ) + + +# =================================================================== +# Watchdog +# =================================================================== + +class TestWatchdog: + """Tests for the watchdog thread and start_watchdog idempotency.""" + + def test_start_watchdog_is_idempotent(self): + """Calling start_watchdog multiple times only spawns one thread.""" + import seclab_taskflow_agent.runner as runner_mod + + # Reset the module-level flag for this test + original = 
runner_mod._watchdog_started + runner_mod._watchdog_started = False + + initial_thread_count = threading.active_count() + start_watchdog(timeout=9999) + after_first = threading.active_count() + start_watchdog(timeout=9999) + start_watchdog(timeout=9999) + after_repeats = threading.active_count() + + assert after_first == initial_thread_count + 1 + assert after_repeats == after_first # no new threads + + # Restore (can't un-start the thread, but flag is what matters) + runner_mod._watchdog_started = original + + def test_start_watchdog_resets_timestamp(self): + """start_watchdog resets the activity timestamp to prevent false triggers.""" + import seclab_taskflow_agent.runner as runner_mod + + # Set the timestamp to something old + runner_mod._watchdog_started = False + runner_mod._watchdog_last_activity = time.monotonic() - 99999 + + start_watchdog(timeout=9999) + + with runner_mod._watchdog_lock: + idle = time.monotonic() - runner_mod._watchdog_last_activity + + assert idle < 2 # should have been reset to ~now + + runner_mod._watchdog_started = True # leave in started state + + def test_watchdog_ping_updates_timestamp(self): + """watchdog_ping updates the activity timestamp.""" + import seclab_taskflow_agent.runner as runner_mod + + old_ts = runner_mod._watchdog_last_activity + time.sleep(0.01) + watchdog_ping() + with runner_mod._watchdog_lock: + new_ts = runner_mod._watchdog_last_activity + assert new_ts > old_ts + + +# =================================================================== +# Cleanup path safety +# =================================================================== + +class TestCleanupSafety: + """Tests for exception-safe cleanup in deploy_task_agents finally block.""" + + def test_aclose_exception_does_not_propagate(self): + """An exception in stream aclose() is logged but doesn't mask the original error.""" + async def _test(): + from seclab_taskflow_agent.runner import APITimeoutError + + stream = MagicMock() + stream.aclose = 
AsyncMock(side_effect=RuntimeError("aclose boom")) + original_error = APITimeoutError("original timeout") + + caught_original = False + try: + raise original_error + except APITimeoutError: + caught_original = True + finally: + aclose = getattr(stream, "aclose", None) + if aclose is not None: + try: + await aclose() + except Exception: + pass # logged in production + + assert caught_original + + asyncio.run(_test()) + + def test_agent_close_exception_does_not_propagate(self): + """An exception in agent0.close() is caught and doesn't prevent MCP cleanup.""" + async def _test(): + agent0 = MagicMock() + agent0.close = AsyncMock(side_effect=RuntimeError("close boom")) + + mcp_cleanup_ran = False + try: + await agent0.close() + except Exception: + pass # logged in production + finally: + mcp_cleanup_ran = True + + assert mcp_cleanup_ran + + asyncio.run(_test()) + + def test_handoff_agents_closed_before_primary(self): + """Handoff agent clients are closed before the primary agent.""" + async def _test(): + close_order = [] + + handoff1 = MagicMock() + handoff1.close = AsyncMock(side_effect=lambda: close_order.append("h1")) + handoff2 = MagicMock() + handoff2.close = AsyncMock(side_effect=lambda: close_order.append("h2")) + primary = MagicMock() + primary.close = AsyncMock(side_effect=lambda: close_order.append("primary")) + + handoff_agents = [handoff1, handoff2] + for ta in handoff_agents: + try: + await ta.close() + except Exception: + pass + try: + await primary.close() + except Exception: + pass + + assert close_order == ["h1", "h2", "primary"] + + asyncio.run(_test()) + + def test_handoff_close_failure_does_not_block_primary(self): + """A failing handoff close doesn't prevent primary agent close.""" + async def _test(): + primary_closed = False + + handoff = MagicMock() + handoff.close = AsyncMock(side_effect=RuntimeError("handoff boom")) + primary = MagicMock() + + async def mark_closed(): + nonlocal primary_closed + primary_closed = True + primary.close = 
mark_closed + + for ta in [handoff]: + try: + await ta.close() + except Exception: + pass + try: + await primary.close() + except Exception: + pass + + assert primary_closed + + asyncio.run(_test()) + + def test_mcp_cancel_with_timeout(self): + """MCP session cancel uses bounded wait, not indefinite await.""" + async def _test(): + async def hanging_task(): + try: + await asyncio.sleep(9999) + except asyncio.CancelledError: + await asyncio.sleep(9999) + + task = asyncio.create_task(hanging_task()) + await asyncio.sleep(0) + + task.cancel() + timed_out = False + try: + await asyncio.wait_for(task, timeout=0.1) + except asyncio.TimeoutError: + timed_out = True + except asyncio.CancelledError: + pass + + assert timed_out + + asyncio.run(_test())