diff --git a/README.md b/README.md index 2600fcc..8507c35 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,12 @@ Failed tasks are automatically retried up to 3 times with increasing backoff before the session is saved. Session checkpoints are stored in the platform-specific application data directory. +### Auto-save tool log + +Set `AUTO_SAVE_DIR` and `AUTO_SAVE_INTERVAL` to enable periodic tool-result +logging. Every N tool calls, the runner appends an NDJSON entry to +`{AUTO_SAVE_DIR}/auto_save_tool_log.ndjson`. Disabled by default (interval=0). + ### Error Output By default, errors are shown as concise one-line messages. Use `--debug` (or @@ -424,6 +430,13 @@ confirm: - memcache_clear_cache ``` +### Append-only logs + +`memcache_append_log(key, entry)` appends a timestamped entry to an append-only +log. Use this instead of `memcache_set_state` when accumulating findings or +notes that should never be overwritten. Retrieve all entries with +`memcache_get_log(key)`. + ## Taskflows A sequence of interdependent tasks performed by a set of Agents. Configured through YAML files of `filetype` `taskflow`. diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py index 85b3fca..f44beba 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py @@ -76,5 +76,22 @@ def memcache_clear_cache(): return backend.clear_cache() +@mcp.tool() +def memcache_append_log(key: str, entry: Any) -> str: + """Append a timestamped entry to an append-only log under the given key. + Use this for findings or notes that should accumulate, not be replaced. + Retrieve entries with memcache_get_log(key).""" + return backend.append_log(key, entry) + + +@mcp.tool() +def memcache_get_log(key: str) -> str: + """Retrieve all entries from an append-only log created by memcache_append_log.""" + import json as _json + + entries = backend.get_log(key) + return _json.dumps(entries, indent=2) + + if __name__ == "__main__": mcp.run(show_banner=False) diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/backend.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/backend.py index d43cb0c..491e012 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/backend.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/backend.py @@ -22,3 +22,13 @@ def list_keys(self) -> str: def clear_cache(self) -> str: pass + + def snapshot_state(self) -> dict[str, Any]: + """Return all keys with their merged values as a dict. Safe to call externally.""" + return {} + + def append_log(self, key: str, entry: Any) -> str: + pass + + def get_log(self, key: str) -> list: + pass diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py index 72407af..116e936 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py @@ -111,3 +111,32 @@ def _clear_cache() -> str: return "Memory cache was cleared, all previous key lists are invalidated." return _clear_cache() + + def snapshot_state(self) -> dict[str, Any]: + """Return a deep copy of the in-memory dictionary.""" + import copy + + self._inflate_memory() + return copy.deepcopy(self.memcache) + + def append_log(self, key: str, entry: Any) -> str: + from datetime import datetime, timezone + + log_key = f"_log:{key}" + wrapped = {"_ts": datetime.now(timezone.utc).isoformat(), "data": entry} + + @self.with_memory + def _append(k, v): + existing = self.memcache.get(k) + if isinstance(existing, list): + existing.append(v) + else: + self.memcache[k] = [v] + return f"Appended entry to log '{k}'" + return _append(log_key, wrapped) + + def get_log(self, key: str) -> list[dict]: + log_key = f"_log:{key}" + self._inflate_memory() + val = self.memcache.get(log_key, []) + return val if isinstance(val, list) else [val] diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py index 0b91c4e..7a7ebfe 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py @@ -29,7 +29,7 @@ def set_state(self, key: str, value: Any) -> str: kv = KeyValue(key=key, value=json.dumps(value)) session.add(kv) session.commit() - return 'f"Stored value in memory for `{key}`"' + return f"Stored value in memory for `{key}`" def get_state(self, key: str) -> Any: with Session(self.engine) as session: @@ -86,3 +86,33 @@ def clear_cache(self) -> str: session.query(KeyValue).delete() session.commit() return "Cleared all keys in memory cache." + + def snapshot_state(self) -> dict[str, Any]: + """Return all distinct keys with their merged values.""" + with Session(self.engine) as session: + keys = [k[0] for k in session.query(KeyValue.key).distinct().all()] + result = {} + for k in keys: + if k.startswith("_log:"): + result[k] = self.get_log(k.removeprefix("_log:")) + else: + result[k] = self.get_state(k) + return result + + def append_log(self, key: str, entry: Any) -> str: + from datetime import datetime, timezone + + log_key = f"_log:{key}" + wrapped = {"_ts": datetime.now(timezone.utc).isoformat(), "data": entry} + with Session(self.engine) as session: + kv = KeyValue(key=log_key, value=json.dumps(wrapped)) + session.add(kv) + session.commit() + return f"Appended entry to log '{log_key}'" + + def get_log(self, key: str) -> list[dict]: + """Retrieve all log entries for a key as a flat list.""" + log_key = f"_log:{key}" + with Session(self.engine) as session: + rows = session.query(KeyValue).filter_by(key=log_key).order_by(KeyValue.id).all() + return [json.loads(row.value) for row in rows] diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 5869385..36b8627 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -16,10 +16,13 @@ "MAX_RATE_LIMIT_BACKOFF", "RATE_LIMIT_BACKOFF", "deploy_task_agents", + "read_tool_log", "run_main", + "write_auto_save", ] import asyncio +import contextlib import json import logging import os @@ -52,6 +55,68 @@ TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between task retries +AUTO_SAVE_LOG_NAME = "auto_save_tool_log.ndjson" + + +def write_auto_save( + auto_save_dir: str, + turn: int, + tool_name: str, + result: str, +) -> None: + """Append tool result to auto-save log (NDJSON, append-only).""" + try: + os.makedirs(auto_save_dir, exist_ok=True) + save_path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME) + entry = json.dumps({ + "turn": turn, + "tool": tool_name, + "result_preview": (result or "")[:2000], + }) + with open(save_path, "a", encoding="utf-8") as f: + f.write(entry + "\n") + except Exception as e: + logging.warning(f"Auto-save failed: {e}") + + +def read_tool_log(auto_save_dir: str) -> list[dict]: + """Read NDJSON auto-save log. Skips malformed lines.""" + if not auto_save_dir: + return [] + entries: list[dict] = [] + try: + path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME) + if os.path.exists(path): + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + with contextlib.suppress(json.JSONDecodeError): + entries.append(json.loads(line)) + except Exception: + logging.debug("Failed to read auto-save tool log", exc_info=True) + return entries + + +def _snapshot_memcache_state() -> dict[str, Any]: + """Read memcache state via backend's snapshot_state(). Returns {} on failure.""" + try: + from .path_utils import mcp_data_dir + + state_dir = str(mcp_data_dir("seclab-taskflow-agent", "memcache", "MEMCACHE_STATE_DIR")) + backend_name = os.getenv("MEMCACHE_BACKEND", "sqlite") + if backend_name == "dictionary_file": + from .mcp_servers.memcache.memcache_backend.dictionary_file import MemcacheDictionaryFileBackend + + return MemcacheDictionaryFileBackend(state_dir).snapshot_state() + from .mcp_servers.memcache.memcache_backend.sqlite import SqliteBackend + + return SqliteBackend(state_dir).snapshot_state() + except Exception: + logging.debug("Failed to snapshot memcache state", exc_info=True) + return {} + def _resolve_model_config( available_tools: AvailableTools, @@ -467,8 +532,22 @@ async def run_main( last_mcp_tool_results: list[str] = [] + # Auto-save scaffolding: periodically persist tool results to disk. + # Disabled by default (interval=0). Set AUTO_SAVE_DIR and + # AUTO_SAVE_INTERVAL to enable. + _tool_call_counter = [0] + try: + _auto_save_interval = int(os.getenv("AUTO_SAVE_INTERVAL", "0")) + except ValueError: + logging.warning("Invalid AUTO_SAVE_INTERVAL value, defaulting to 0 (disabled)") + _auto_save_interval = 0 + _auto_save_dir = os.getenv("AUTO_SAVE_DIR", "") + async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None: last_mcp_tool_results.append(result) + _tool_call_counter[0] += 1 + if _auto_save_dir and _auto_save_interval and _tool_call_counter[0] % _auto_save_interval == 0: + write_auto_save(_auto_save_dir, _tool_call_counter[0], tool.name, result) async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n") @@ -701,7 +780,13 @@ async def _deploy(ra: dict, pp: str) -> bool: # If all retries exhausted with an exception, save and re-raise if last_task_error is not None: - session.mark_failed(f"Task {task_name!r}: {last_task_error}") + snap = _snapshot_memcache_state() + log = read_tool_log(_auto_save_dir) + session.mark_failed( + f"Task {task_name!r}: {last_task_error}", + memcache_snapshot=snap, + tool_log_snapshot=log, + ) await render_model_output( f"** 🤖💾 Session saved: {session.session_id}\n" f"** 🤖💡 Resume with: --resume {session.session_id}\n" @@ -711,7 +796,13 @@ async def _deploy(ra: dict, pp: str) -> bool: if must_complete and not task_complete: logging.critical("Required task not completed ... aborting!") await render_model_output("🤖💥 *Required task not completed ...\n") - session.mark_failed(f"Required task {task_name!r} did not complete") + snap = _snapshot_memcache_state() + log = read_tool_log(_auto_save_dir) + session.mark_failed( + f"Required task {task_name!r} did not complete", + memcache_snapshot=snap, + tool_log_snapshot=log, + ) await render_model_output( f"** 🤖💾 Session saved: {session.session_id}\n" f"** 🤖💡 Resume with: --resume {session.session_id}\n" diff --git a/src/seclab_taskflow_agent/session.py b/src/seclab_taskflow_agent/session.py index 9b77151..dd24f83 100644 --- a/src/seclab_taskflow_agent/session.py +++ b/src/seclab_taskflow_agent/session.py @@ -20,6 +20,7 @@ import uuid from datetime import datetime, timezone from pathlib import Path +from typing import Any from pydantic import BaseModel, Field @@ -63,6 +64,10 @@ class TaskflowSession(BaseModel): # Accumulated tool results carried across tasks (used by repeat_prompt) last_tool_results: list[str] = Field(default_factory=list) + # Failure forensics: captured at mark_failed time for post-mortem inspection + memcache_snapshot: dict[str, Any] = Field(default_factory=dict) + tool_log_snapshot: list[dict[str, Any]] = Field(default_factory=list) + @property def next_task_index(self) -> int: """Index of the next task to execute.""" @@ -107,9 +112,22 @@ def mark_finished(self) -> None: self.finished = True self.save() - def mark_failed(self, error: str) -> None: - """Mark the session as failed with an error message and save.""" + def mark_failed( + self, + error: str, + memcache_snapshot: dict[str, Any] | None = None, + tool_log_snapshot: list[dict[str, Any]] | None = None, + ) -> None: + """Mark the session as failed with an error message and save. + + Optionally captures memcache state and tool log at failure time + for post-mortem inspection. + """ self.error = error + if memcache_snapshot is not None: + self.memcache_snapshot = memcache_snapshot + if tool_log_snapshot is not None: + self.tool_log_snapshot = tool_log_snapshot self.save() @classmethod diff --git a/tests/test_memcache_backend.py b/tests/test_memcache_backend.py new file mode 100644 index 0000000..e1025f1 --- /dev/null +++ b/tests/test_memcache_backend.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Tests for memcache backend snapshot_state and append_log.""" + +from __future__ import annotations + +import threading + +from seclab_taskflow_agent.mcp_servers.memcache.memcache_backend.dictionary_file import ( + MemcacheDictionaryFileBackend, +) +from seclab_taskflow_agent.mcp_servers.memcache.memcache_backend.sqlite import SqliteBackend + + +class TestSnapshotStateSqlite: + def test_snapshot_returns_all_keys(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + b.set_state("key1", "val1") + b.set_state("key2", [1, 2, 3]) + snap = b.snapshot_state() + assert snap["key1"] == "val1" + assert snap["key2"] == [1, 2, 3] + + def test_snapshot_empty_db(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + assert b.snapshot_state() == {} + + +class TestSnapshotStateDictFile: + def test_snapshot_returns_deep_copy(self, tmp_path): + b = MemcacheDictionaryFileBackend(str(tmp_path)) + b.set_state("a", {"nested": True}) + snap = b.snapshot_state() + assert snap["a"] == {"nested": True} + # Mutating nested values in the snapshot shouldn't affect the backend + snap["a"]["nested"] = False + assert b.get_state("a") == {"nested": True} + + def test_snapshot_empty(self, tmp_path): + b = MemcacheDictionaryFileBackend(str(tmp_path)) + assert b.snapshot_state() == {} + + +class TestAppendLogSqlite: + def test_creates_list_on_first_append(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + b.append_log("findings", "xss found") + entries = b.get_log("findings") + assert len(entries) == 1 + assert entries[0]["data"] == "xss found" + + def test_appends_to_existing_list(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + b.append_log("findings", "xss") + b.append_log("findings", "sqli") + entries = b.get_log("findings") + assert len(entries) == 2 + assert entries[0]["data"] == "xss" + assert entries[1]["data"] == "sqli" + + def test_entries_have_timestamps(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + b.append_log("k", "v") + entries = b.get_log("k") + assert "_ts" in entries[0] + + def test_get_log_returns_all_entries_in_order(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + for i in range(10): + b.append_log("ordered", f"entry-{i}") + entries = b.get_log("ordered") + assert len(entries) == 10 + for i, entry in enumerate(entries): + assert entry["data"] == f"entry-{i}" + + def test_get_log_empty_key_returns_empty_list(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + assert b.get_log("nonexistent") == [] + + def test_multiple_sequential_appends_no_data_loss(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + n = 50 + for i in range(n): + b.append_log("bulk", f"item-{i}") + assert len(b.get_log("bulk")) == n + + def test_concurrent_appends_no_data_loss(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + n_threads = 5 + n_per_thread = 20 + + def worker(thread_id): + for i in range(n_per_thread): + b.append_log("concurrent", f"t{thread_id}-{i}") + + threads = [threading.Thread(target=worker, args=(t,)) for t in range(n_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + entries = b.get_log("concurrent") + assert len(entries) == n_threads * n_per_thread + + +class TestAppendLogDictFile: + def test_creates_list_on_first_append(self, tmp_path): + b = MemcacheDictionaryFileBackend(str(tmp_path)) + b.append_log("findings", "xss found") + entries = b.get_log("findings") + assert len(entries) == 1 + assert entries[0]["data"] == "xss found" + + def test_appends_to_existing_list(self, tmp_path): + b = MemcacheDictionaryFileBackend(str(tmp_path)) + b.append_log("findings", "xss") + b.append_log("findings", "sqli") + entries = b.get_log("findings") + assert len(entries) == 2 + + def test_get_log_empty_key_returns_empty_list(self, tmp_path): + b = MemcacheDictionaryFileBackend(str(tmp_path)) + assert b.get_log("nonexistent") == [] + + +class TestSetStateReturnFix: + def test_set_state_return_contains_key_name(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + result = b.set_state("my_key", "value") + assert "my_key" in result + # Should NOT be a literal f-string + assert "{key}" not in result + + +class TestSetStateUnchanged: + def test_set_state_replaces_value(self, tmp_path): + b = SqliteBackend(str(tmp_path)) + b.set_state("k", "old") + b.set_state("k", "new") + assert b.get_state("k") == "new" diff --git a/tests/test_runner.py b/tests/test_runner.py index a50c0f2..7a9551e 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -24,6 +24,8 @@ _merge_reusable_task, _resolve_model_config, _resolve_task_model, + read_tool_log, + write_auto_save, ) @@ -441,3 +443,65 @@ def test_raises_type_error_on_non_iterable_result(self): inputs={}, ) ) + + +# =================================================================== +# Auto-save scaffolding +# =================================================================== + + +class TestAutoSave: + """Tests for write_auto_save / read_tool_log (module-level functions).""" + + def test_disabled_when_no_dir(self): + """read_tool_log returns [] when dir is empty string.""" + assert read_tool_log("") == [] + + def test_write_then_read_roundtrip(self, tmp_path): + """write_auto_save produces entries that read_tool_log can read back.""" + d = str(tmp_path) + write_auto_save(d, turn=1, tool_name="search_code", result="found 5") + write_auto_save(d, turn=2, tool_name="read_file", result="contents") + entries = read_tool_log(d) + assert len(entries) == 2 + assert entries[0]["turn"] == 1 + assert entries[0]["tool"] == "search_code" + assert entries[1]["turn"] == 2 + + def test_log_format_has_turn_tool_preview(self, tmp_path): + """Each NDJSON entry has the expected keys.""" + d = str(tmp_path) + write_auto_save(d, turn=7, tool_name="search_code", result="found 5 matches") + entries = read_tool_log(d) + assert len(entries) == 1 + assert entries[0] == {"turn": 7, "tool": "search_code", "result_preview": "found 5 matches"} + + def test_result_truncated_to_2000(self, tmp_path): + """Result preview is capped at 2000 characters.""" + d = str(tmp_path) + write_auto_save(d, turn=1, tool_name="big", result="x" * 5000) + entries = read_tool_log(d) + assert len(entries[0]["result_preview"]) == 2000 + + def test_survives_write_failure(self): + """write_auto_save suppresses write errors without crashing.""" + with patch("builtins.open", side_effect=OSError("disk full")): + write_auto_save("/tmp/any", turn=1, tool_name="t", result="r") + + def test_read_skips_corrupt_trailing_line(self, tmp_path): + """read_tool_log skips truncated/corrupt lines without discarding valid ones.""" + import os + + d = str(tmp_path) + write_auto_save(d, turn=1, tool_name="good", result="ok") + # Append a corrupt line simulating crash mid-write + log_path = os.path.join(d, "auto_save_tool_log.ndjson") + with open(log_path, "a") as f: + f.write('{"truncated\n') + entries = read_tool_log(d) + assert len(entries) == 1 + assert entries[0]["tool"] == "good" + + def test_read_empty_dir(self, tmp_path): + """read_tool_log on a dir with no log file returns [].""" + assert read_tool_log(str(tmp_path)) == [] diff --git a/tests/test_session.py b/tests/test_session.py index f8563f9..60154f9 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -92,3 +92,53 @@ def test_with_results(self): t = CompletedTask(index=2, name="analyze", result=True, tool_results=["r1", "r2"]) assert t.index == 2 assert t.tool_results == ["r1", "r2"] + + +class TestSessionForensics: + """Tests for failure forensics (memcache_snapshot, tool_log_snapshot).""" + + def test_mark_failed_with_snapshot_roundtrips(self, tmp_path, monkeypatch): + monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + snap = {"findings": ["xss", "sqli"]} + log = [{"turn": 1, "tool": "search_code", "result_preview": "found"}] + s.mark_failed("crash", memcache_snapshot=snap, tool_log_snapshot=log) + + loaded = TaskflowSession.load(s.session_id) + assert loaded.memcache_snapshot == snap + assert loaded.tool_log_snapshot == log + assert loaded.error == "crash" + + def test_mark_failed_without_snapshot_backward_compatible(self): + s = TaskflowSession(taskflow_path="test.flow") + s.mark_failed("simple error") + assert s.memcache_snapshot == {} + assert s.tool_log_snapshot == [] + assert s.error == "simple error" + + def test_old_session_json_without_new_fields_loads(self, tmp_path, monkeypatch): + """Session JSON from before forensics fields was added still loads.""" + monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path) + # Write minimal JSON without new fields + import json + + old_data = { + "session_id": "oldformat123", + "taskflow_path": "old.flow", + "cli_globals": {}, + "prompt": "", + "created_at": "2026-01-01T00:00:00+00:00", + "updated_at": "", + "completed_tasks": [], + "total_tasks": 0, + "finished": False, + "error": "old error", + "last_tool_results": [], + } + path = tmp_path / "oldformat123.json" + path.write_text(json.dumps(old_data)) + + loaded = TaskflowSession.load("oldformat123") + assert loaded.error == "old error" + assert loaded.memcache_snapshot == {} + assert loaded.tool_log_snapshot == []