Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ Failed tasks are automatically retried up to 3 times with increasing backoff
before the session is saved. Session checkpoints are stored in the
platform-specific application data directory.

### Auto-save tool log

Set `AUTO_SAVE_DIR` (target directory) and `AUTO_SAVE_INTERVAL` (a positive
integer N) to enable periodic tool-result logging. Every N tool calls, the
runner appends an NDJSON entry to `{AUTO_SAVE_DIR}/auto_save_tool_log.ndjson`.
Disabled by default (`AUTO_SAVE_INTERVAL=0`).

### Error Output

By default, errors are shown as concise one-line messages. Use `--debug` (or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ def list_keys(self) -> str:

def clear_cache(self) -> str:
    """Remove every stored key from the memory cache.

    Interface stub: concrete backends perform the deletion and return a
    human-readable confirmation message.
    NOTE(review): this default returns None despite the ``-> str``
    annotation — confirm subclasses are always expected to override.
    """
    pass

def snapshot_state(self) -> dict[str, Any]:
    """Return all keys with their merged values as a dict. Safe to call externally."""
    # Base-class default: report no state. The concrete backends in this
    # change (sqlite and dictionary_file) override this with real snapshots.
    return {}
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,10 @@ def _clear_cache() -> str:
return "Memory cache was cleared, all previous key lists are invalidated."

return _clear_cache()

def snapshot_state(self) -> dict[str, Any]:
    """Snapshot every key/value currently held by this backend.

    Loads the persisted dictionary into memory first, then hands back a
    deep copy so callers can mutate the result without touching the live
    cache.
    """
    from copy import deepcopy

    self._inflate_memory()
    return deepcopy(self.memcache)
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,15 @@ def clear_cache(self) -> str:
session.query(KeyValue).delete()
session.commit()
return "Cleared all keys in memory cache."

def snapshot_state(self) -> dict[str, Any]:
    """Return every distinct key mapped to its merged value.

    Keys carrying the "_log:" prefix are resolved through ``get_log`` when
    the backend provides it; all other keys are resolved via ``get_state``.
    """
    with Session(self.engine) as session:
        rows = session.query(KeyValue.key).distinct().all()
        return {
            key: (
                self.get_log(key.removeprefix("_log:"))
                if key.startswith("_log:") and hasattr(self, "get_log")
                else self.get_state(key)
            )
            for (key,) in rows
        }
95 changes: 93 additions & 2 deletions src/seclab_taskflow_agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
"MAX_RATE_LIMIT_BACKOFF",
"RATE_LIMIT_BACKOFF",
"deploy_task_agents",
"read_tool_log",
"run_main",
"write_auto_save",
]

import asyncio
import contextlib
import json
import logging
import os
Expand Down Expand Up @@ -52,6 +55,68 @@
TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task
TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between task retries

AUTO_SAVE_LOG_NAME = "auto_save_tool_log.ndjson"

# Maximum number of characters of a tool result persisted per log entry.
AUTO_SAVE_RESULT_PREVIEW_LIMIT = 2000


def write_auto_save(
    auto_save_dir: str,
    turn: int,
    tool_name: str,
    result: str,
) -> None:
    """Append one tool-result entry to the auto-save log.

    The log is NDJSON (one JSON object per line), append-only, located at
    ``{auto_save_dir}/auto_save_tool_log.ndjson``. The directory is created
    if missing. Any failure is logged and swallowed so auto-save never
    interrupts the main run.

    Args:
        auto_save_dir: Directory that holds the log file.
        turn: Tool-call counter value at the time of the call.
        tool_name: Name of the tool that produced ``result``.
        result: Raw tool output; only a bounded preview is stored.
    """
    try:
        os.makedirs(auto_save_dir, exist_ok=True)
        save_path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME)
        entry = json.dumps(
            {
                "turn": turn,
                "tool": tool_name,
                # Guard against None and cap the stored size.
                "result_preview": (result or "")[:AUTO_SAVE_RESULT_PREVIEW_LIMIT],
            }
        )
        with open(save_path, "a", encoding="utf-8") as f:
            f.write(entry + "\n")
    except Exception as e:
        # Auto-save is best-effort by design; lazy %-args avoid building
        # the message unless the warning is actually emitted.
        logging.warning("Auto-save failed: %s", e)


def read_tool_log(auto_save_dir: str) -> list[dict]:
    """Parse the NDJSON auto-save log into a list of entry dicts.

    Malformed lines (e.g. truncated by a crash mid-write) are skipped.
    An empty ``auto_save_dir`` means auto-save is disabled; any I/O
    failure is logged at debug level and yields an empty result.
    """
    if not auto_save_dir:
        return []
    parsed: list[dict] = []
    try:
        log_path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME)
        if os.path.exists(log_path):
            with open(log_path, encoding="utf-8") as fh:
                for raw in fh:
                    candidate = raw.strip()
                    if candidate:
                        with contextlib.suppress(json.JSONDecodeError):
                            parsed.append(json.loads(candidate))
    except Exception:
        logging.debug("Failed to read auto-save tool log", exc_info=True)
    return parsed


def _snapshot_memcache_state() -> dict[str, Any]:
"""Read memcache state via backend's snapshot_state(). Returns {} on failure."""
try:
from .path_utils import mcp_data_dir

state_dir = str(mcp_data_dir("seclab-taskflow-agent", "memcache", "MEMCACHE_STATE_DIR"))
backend_name = os.getenv("MEMCACHE_BACKEND", "sqlite")
if backend_name == "dictionary_file":
from .mcp_servers.memcache.memcache_backend.dictionary_file import MemcacheDictionaryFileBackend

return MemcacheDictionaryFileBackend(state_dir).snapshot_state()
from .mcp_servers.memcache.memcache_backend.sqlite import SqliteBackend

return SqliteBackend(state_dir).snapshot_state()
except Exception:
logging.debug("Failed to snapshot memcache state", exc_info=True)
return {}


def _resolve_model_config(
available_tools: AvailableTools,
Expand Down Expand Up @@ -467,8 +532,22 @@ async def run_main(

last_mcp_tool_results: list[str] = []

# Auto-save scaffolding: periodically persist tool results to disk.
# Disabled by default (interval=0). Set AUTO_SAVE_DIR and
# AUTO_SAVE_INTERVAL to enable.
_tool_call_counter = [0]
try:
_auto_save_interval = int(os.getenv("AUTO_SAVE_INTERVAL", "0"))
except ValueError:
logging.warning("Invalid AUTO_SAVE_INTERVAL value, defaulting to 0 (disabled)")
_auto_save_interval = 0
_auto_save_dir = os.getenv("AUTO_SAVE_DIR", "")

async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None:
last_mcp_tool_results.append(result)
_tool_call_counter[0] += 1
if _auto_save_dir and _auto_save_interval and _tool_call_counter[0] % _auto_save_interval == 0:
write_auto_save(_auto_save_dir, _tool_call_counter[0], tool.name, result)

async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None:
await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n")
Expand Down Expand Up @@ -701,7 +780,13 @@ async def _deploy(ra: dict, pp: str) -> bool:

# If all retries exhausted with an exception, save and re-raise
if last_task_error is not None:
session.mark_failed(f"Task {task_name!r}: {last_task_error}")
snap = _snapshot_memcache_state()
log = read_tool_log(_auto_save_dir)
session.mark_failed(
f"Task {task_name!r}: {last_task_error}",
memcache_snapshot=snap,
tool_log_snapshot=log,
)
await render_model_output(
f"** 🤖💾 Session saved: {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n"
Expand All @@ -711,7 +796,13 @@ async def _deploy(ra: dict, pp: str) -> bool:
if must_complete and not task_complete:
logging.critical("Required task not completed ... aborting!")
await render_model_output("🤖💥 *Required task not completed ...\n")
session.mark_failed(f"Required task {task_name!r} did not complete")
snap = _snapshot_memcache_state()
log = read_tool_log(_auto_save_dir)
session.mark_failed(
f"Required task {task_name!r} did not complete",
memcache_snapshot=snap,
tool_log_snapshot=log,
)
await render_model_output(
f"** 🤖💾 Session saved: {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n"
Expand Down
22 changes: 20 additions & 2 deletions src/seclab_taskflow_agent/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -63,6 +64,10 @@ class TaskflowSession(BaseModel):
# Accumulated tool results carried across tasks (used by repeat_prompt)
last_tool_results: list[str] = Field(default_factory=list)

# Failure forensics: captured at mark_failed time for post-mortem inspection
memcache_snapshot: dict[str, Any] = Field(default_factory=dict)
tool_log_snapshot: list[dict[str, Any]] = Field(default_factory=list)

@property
def next_task_index(self) -> int:
"""Index of the next task to execute."""
Expand Down Expand Up @@ -107,9 +112,22 @@ def mark_finished(self) -> None:
self.finished = True
self.save()

def mark_failed(self, error: str) -> None:
"""Mark the session as failed with an error message and save."""
def mark_failed(
    self,
    error: str,
    memcache_snapshot: dict[str, Any] | None = None,
    tool_log_snapshot: list[dict[str, Any]] | None = None,
) -> None:
    """Record a failure on the session and persist it.

    Args:
        error: Human-readable failure description stored on the session.
        memcache_snapshot: Optional memcache contents captured at failure
            time for post-mortem inspection.
        tool_log_snapshot: Optional tool-call log captured at failure time.
    """
    self.error = error
    # Attach forensic snapshots only when the caller supplied them, so a
    # plain mark_failed(error) leaves the existing field values untouched.
    for field_name, captured in (
        ("memcache_snapshot", memcache_snapshot),
        ("tool_log_snapshot", tool_log_snapshot),
    ):
        if captured is not None:
            setattr(self, field_name, captured)
    self.save()

@classmethod
Expand Down
40 changes: 40 additions & 0 deletions tests/test_memcache_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: GitHub, Inc.
# SPDX-License-Identifier: MIT

"""Tests for memcache backend snapshot_state and append_log."""

from __future__ import annotations

from seclab_taskflow_agent.mcp_servers.memcache.memcache_backend.dictionary_file import (
MemcacheDictionaryFileBackend,
)
from seclab_taskflow_agent.mcp_servers.memcache.memcache_backend.sqlite import SqliteBackend


class TestSnapshotStateSqlite:
    """snapshot_state behavior on the SQLite backend."""

    def test_snapshot_returns_all_keys(self, tmp_path):
        """Every key written via set_state appears in the snapshot."""
        backend = SqliteBackend(str(tmp_path))
        backend.set_state("key1", "val1")
        backend.set_state("key2", [1, 2, 3])
        state = backend.snapshot_state()
        assert state["key1"] == "val1"
        assert state["key2"] == [1, 2, 3]

    def test_snapshot_empty_db(self, tmp_path):
        """A fresh database snapshots to an empty dict."""
        assert SqliteBackend(str(tmp_path)).snapshot_state() == {}


class TestSnapshotStateDictFile:
    """snapshot_state behavior on the dictionary-file backend."""

    def test_snapshot_returns_deep_copy(self, tmp_path):
        """The snapshot is detached: mutating it must not leak into the backend."""
        backend = MemcacheDictionaryFileBackend(str(tmp_path))
        backend.set_state("a", {"nested": True})
        state = backend.snapshot_state()
        assert state["a"] == {"nested": True}
        state["a"]["nested"] = False
        assert backend.get_state("a") == {"nested": True}

    def test_snapshot_empty(self, tmp_path):
        """A fresh backend snapshots to an empty dict."""
        assert MemcacheDictionaryFileBackend(str(tmp_path)).snapshot_state() == {}
64 changes: 64 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
_merge_reusable_task,
_resolve_model_config,
_resolve_task_model,
read_tool_log,
write_auto_save,
)


Expand Down Expand Up @@ -441,3 +443,65 @@ def test_raises_type_error_on_non_iterable_result(self):
inputs={},
)
)


# ===================================================================
# Auto-save scaffolding
# ===================================================================


class TestAutoSave:
    """Exercises the module-level auto-save helpers (write_auto_save / read_tool_log)."""

    def test_disabled_when_no_dir(self):
        """An empty directory string means auto-save is off: no entries."""
        assert read_tool_log("") == []

    def test_write_then_read_roundtrip(self, tmp_path):
        """Entries written by write_auto_save come back via read_tool_log."""
        log_dir = str(tmp_path)
        write_auto_save(log_dir, turn=1, tool_name="search_code", result="found 5")
        write_auto_save(log_dir, turn=2, tool_name="read_file", result="contents")
        loaded = read_tool_log(log_dir)
        assert len(loaded) == 2
        assert loaded[0]["turn"] == 1
        assert loaded[0]["tool"] == "search_code"
        assert loaded[1]["turn"] == 2

    def test_log_format_has_turn_tool_preview(self, tmp_path):
        """Each NDJSON record carries exactly turn/tool/result_preview."""
        log_dir = str(tmp_path)
        write_auto_save(log_dir, turn=7, tool_name="search_code", result="found 5 matches")
        loaded = read_tool_log(log_dir)
        assert len(loaded) == 1
        assert loaded[0] == {"turn": 7, "tool": "search_code", "result_preview": "found 5 matches"}

    def test_result_truncated_to_2000(self, tmp_path):
        """Oversized tool results are stored as a 2000-character preview."""
        log_dir = str(tmp_path)
        write_auto_save(log_dir, turn=1, tool_name="big", result="x" * 5000)
        loaded = read_tool_log(log_dir)
        assert len(loaded[0]["result_preview"]) == 2000

    def test_survives_write_failure(self):
        """A failing open() is swallowed; write_auto_save never raises."""
        with patch("builtins.open", side_effect=OSError("disk full")):
            write_auto_save("/tmp/any", turn=1, tool_name="t", result="r")

    def test_read_skips_corrupt_trailing_line(self, tmp_path):
        """A half-written final line is dropped while intact entries survive."""
        import os

        log_dir = str(tmp_path)
        write_auto_save(log_dir, turn=1, tool_name="good", result="ok")
        # Simulate a crash mid-write by appending invalid JSON.
        with open(os.path.join(log_dir, "auto_save_tool_log.ndjson"), "a") as f:
            f.write('{"truncated\n')
        loaded = read_tool_log(log_dir)
        assert len(loaded) == 1
        assert loaded[0]["tool"] == "good"

    def test_read_empty_dir(self, tmp_path):
        """A directory containing no log file yields an empty list."""
        assert read_tool_log(str(tmp_path)) == []
50 changes: 50 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,53 @@ def test_with_results(self):
t = CompletedTask(index=2, name="analyze", result=True, tool_results=["r1", "r2"])
assert t.index == 2
assert t.tool_results == ["r1", "r2"]


class TestSessionForensics:
    """Failure forensics: memcache_snapshot and tool_log_snapshot handling."""

    def test_mark_failed_with_snapshot_roundtrips(self, tmp_path, monkeypatch):
        """Snapshots passed to mark_failed survive a save/load cycle."""
        monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path)
        session = TaskflowSession(taskflow_path="test.flow")
        mem_state = {"findings": ["xss", "sqli"]}
        tool_entries = [{"turn": 1, "tool": "search_code", "result_preview": "found"}]
        session.mark_failed("crash", memcache_snapshot=mem_state, tool_log_snapshot=tool_entries)

        restored = TaskflowSession.load(session.session_id)
        assert restored.memcache_snapshot == mem_state
        assert restored.tool_log_snapshot == tool_entries
        assert restored.error == "crash"

    def test_mark_failed_without_snapshot_backward_compatible(self):
        """Calling mark_failed with only an error keeps snapshot defaults."""
        session = TaskflowSession(taskflow_path="test.flow")
        session.mark_failed("simple error")
        assert session.memcache_snapshot == {}
        assert session.tool_log_snapshot == []
        assert session.error == "simple error"

    def test_old_session_json_without_new_fields_loads(self, tmp_path, monkeypatch):
        """Pre-forensics session JSON (no snapshot fields) still deserializes."""
        monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path)
        import json

        # Minimal legacy payload, written before the forensics fields existed.
        legacy = {
            "session_id": "oldformat123",
            "taskflow_path": "old.flow",
            "cli_globals": {},
            "prompt": "",
            "created_at": "2026-01-01T00:00:00+00:00",
            "updated_at": "",
            "completed_tasks": [],
            "total_tasks": 0,
            "finished": False,
            "error": "old error",
            "last_tool_results": [],
        }
        (tmp_path / "oldformat123.json").write_text(json.dumps(legacy))

        restored = TaskflowSession.load("oldformat123")
        assert restored.error == "old error"
        assert restored.memcache_snapshot == {}
        assert restored.tool_log_snapshot == []