diff --git a/README.md b/README.md index 2600fcc..46d50a6 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,12 @@ Failed tasks are automatically retried up to 3 times with increasing backoff before the session is saved. Session checkpoints are stored in the platform-specific application data directory. +### Auto-save tool log + +Set `AUTO_SAVE_DIR` and `AUTO_SAVE_INTERVAL` to enable periodic tool-result +logging. Every N tool calls, the runner appends an NDJSON entry to +`{AUTO_SAVE_DIR}/auto_save_tool_log.ndjson`. Disabled by default (interval=0). + ### Error Output By default, errors are shown as concise one-line messages. Use `--debug` (or diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 5869385..3c96fc2 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -16,10 +16,13 @@ "MAX_RATE_LIMIT_BACKOFF", "RATE_LIMIT_BACKOFF", "deploy_task_agents", + "read_tool_log", "run_main", + "write_auto_save", ] import asyncio +import contextlib import json import logging import os @@ -52,6 +55,49 @@ TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between task retries +AUTO_SAVE_LOG_NAME = "auto_save_tool_log.ndjson" + + +def write_auto_save( + auto_save_dir: str, + turn: int, + tool_name: str, + result: str, +) -> None: + """Append tool result to auto-save log (NDJSON, append-only).""" + try: + os.makedirs(auto_save_dir, exist_ok=True) + save_path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME) + entry = json.dumps({ + "turn": turn, + "tool": tool_name, + "result_preview": (result or "")[:2000], + }) + with open(save_path, "a", encoding="utf-8") as f: + f.write(entry + "\n") + except Exception as e: + logging.warning(f"Auto-save failed: {e}") + + +def read_tool_log(auto_save_dir: str) -> list[dict]: + """Read NDJSON auto-save log. Skips malformed lines.""" + if not auto_save_dir: + return [] + entries: list[dict] = [] + try: + path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME) + if os.path.exists(path): + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + with contextlib.suppress(json.JSONDecodeError): + entries.append(json.loads(line)) + except Exception: + logging.debug("Failed to read auto-save tool log", exc_info=True) + return entries + def _resolve_model_config( available_tools: AvailableTools, @@ -467,8 +513,22 @@ async def run_main( last_mcp_tool_results: list[str] = [] + # Auto-save scaffolding: periodically persist tool results to disk. + # Disabled by default (interval=0). Set AUTO_SAVE_DIR and + # AUTO_SAVE_INTERVAL to enable. + _tool_call_counter = [0] + try: + _auto_save_interval = int(os.getenv("AUTO_SAVE_INTERVAL", "0")) + except ValueError: + logging.warning("Invalid AUTO_SAVE_INTERVAL value, defaulting to 0 (disabled)") + _auto_save_interval = 0 + _auto_save_dir = os.getenv("AUTO_SAVE_DIR", "") + async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None: last_mcp_tool_results.append(result) + _tool_call_counter[0] += 1 + if _auto_save_dir and _auto_save_interval and _tool_call_counter[0] % _auto_save_interval == 0: + write_auto_save(_auto_save_dir, _tool_call_counter[0], tool.name, result) async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n") diff --git a/tests/test_runner.py b/tests/test_runner.py index a50c0f2..7a9551e 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -24,6 +24,8 @@ _merge_reusable_task, _resolve_model_config, _resolve_task_model, + read_tool_log, + write_auto_save, ) @@ -441,3 +443,65 @@ def test_raises_type_error_on_non_iterable_result(self): inputs={}, ) ) + + +# =================================================================== +# Auto-save scaffolding +# =================================================================== + + +class TestAutoSave: + """Tests for write_auto_save / read_tool_log (module-level functions).""" + + def test_disabled_when_no_dir(self): + """read_tool_log returns [] when dir is empty string.""" + assert read_tool_log("") == [] + + def test_write_then_read_roundtrip(self, tmp_path): + """write_auto_save produces entries that read_tool_log can read back.""" + d = str(tmp_path) + write_auto_save(d, turn=1, tool_name="search_code", result="found 5") + write_auto_save(d, turn=2, tool_name="read_file", result="contents") + entries = read_tool_log(d) + assert len(entries) == 2 + assert entries[0]["turn"] == 1 + assert entries[0]["tool"] == "search_code" + assert entries[1]["turn"] == 2 + + def test_log_format_has_turn_tool_preview(self, tmp_path): + """Each NDJSON entry has the expected keys.""" + d = str(tmp_path) + write_auto_save(d, turn=7, tool_name="search_code", result="found 5 matches") + entries = read_tool_log(d) + assert len(entries) == 1 + assert entries[0] == {"turn": 7, "tool": "search_code", "result_preview": "found 5 matches"} + + def test_result_truncated_to_2000(self, tmp_path): + """Result preview is capped at 2000 characters.""" + d = str(tmp_path) + write_auto_save(d, turn=1, tool_name="big", result="x" * 5000) + entries = read_tool_log(d) + assert len(entries[0]["result_preview"]) == 2000 + + def test_survives_write_failure(self): + """write_auto_save suppresses write errors without crashing.""" + with patch("builtins.open", side_effect=OSError("disk full")): + write_auto_save("/tmp/any", turn=1, tool_name="t", result="r") + + def test_read_skips_corrupt_trailing_line(self, tmp_path): + """read_tool_log skips truncated/corrupt lines without discarding valid ones.""" + import os + + d = str(tmp_path) + write_auto_save(d, turn=1, tool_name="good", result="ok") + # Append a corrupt line simulating crash mid-write + log_path = os.path.join(d, "auto_save_tool_log.ndjson") + with open(log_path, "a") as f: + f.write('{"truncated\n') + entries = read_tool_log(d) + assert len(entries) == 1 + assert entries[0]["tool"] == "good" + + def test_read_empty_dir(self, tmp_path): + """read_tool_log on a dir with no log file returns [].""" + assert read_tool_log(str(tmp_path)) == []