Skip to content
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ Failed tasks are automatically retried up to 3 times with increasing backoff
before the session is saved. Session checkpoints are stored in the
platform-specific application data directory.

### Auto-save tool log

Set `AUTO_SAVE_DIR` and `AUTO_SAVE_INTERVAL` to enable periodic tool-result
logging. Every `AUTO_SAVE_INTERVAL` tool calls, the runner appends an NDJSON
entry to `{AUTO_SAVE_DIR}/auto_save_tool_log.ndjson`. Disabled by default
(`AUTO_SAVE_INTERVAL=0`).

### Error Output

By default, errors are shown as concise one-line messages. Use `--debug` (or
Expand Down Expand Up @@ -424,6 +430,13 @@ confirm:
- memcache_clear_cache
```

### Append-only logs

`memcache_append_log(key, entry)` appends a timestamped entry to an append-only
log. Use this instead of `memcache_set_state` when accumulating findings or
notes that should never be overwritten. Retrieve all entries with
`memcache_get_log(key)`.

## Taskflows

A sequence of interdependent tasks performed by a set of Agents, configured through YAML files with `filetype: taskflow`.
Expand Down
17 changes: 17 additions & 0 deletions src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,22 @@ def memcache_clear_cache():
return backend.clear_cache()


@mcp.tool()
def memcache_append_log(key: str, entry: Any) -> str:
    """Append a timestamped entry to the append-only log stored at *key*.

    Prefer this over memcache_set_state when accumulating findings or notes
    that must never be overwritten. Read all entries back with
    memcache_get_log(key)."""
    return backend.append_log(key, entry)


@mcp.tool()
def memcache_get_log(key: str) -> str:
    """Return, as pretty-printed JSON, every entry appended with memcache_append_log."""
    import json as _json

    return _json.dumps(backend.get_log(key), indent=2)


# Entry point: serve the memcache MCP tools when this module is run directly.
if __name__ == "__main__":
    mcp.run(show_banner=False)
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,13 @@ def list_keys(self) -> str:

def clear_cache(self) -> str:
    """Remove every cached key.

    Interface stub — concrete backends override this. NOTE(review): the
    stub returns None despite the ``-> str`` annotation; assumes the base
    class is never used directly.
    """
    pass

def snapshot_state(self) -> dict[str, Any]:
    """Return all keys with their merged values as a dict. Safe to call externally.

    Base implementation returns an empty dict so callers degrade gracefully
    when a backend does not override it.
    """
    return {}

def append_log(self, key: str, entry: Any) -> str:
    """Append a timestamped *entry* to the append-only log under *key*.

    Interface stub — concrete backends override this. NOTE(review): the
    stub returns None despite the ``-> str`` annotation; confirm all
    subclasses implement it.
    """
    pass

def get_log(self, key: str) -> list:
    """Return all entries of the append-only log under *key*.

    Interface stub — concrete backends override this. NOTE(review): the
    stub returns None despite the ``-> list`` annotation; confirm all
    subclasses implement it.
    """
    pass
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,32 @@ def _clear_cache() -> str:
return "Memory cache was cleared, all previous key lists are invalidated."

return _clear_cache()

def snapshot_state(self) -> dict[str, Any]:
    """Return a deep copy of the current in-memory key/value store.

    The copy means callers can inspect or serialize the snapshot without
    risking mutation of live cache state.
    """
    from copy import deepcopy

    self._inflate_memory()
    return deepcopy(self.memcache)

def append_log(self, key: str, entry: Any) -> str:
    """Append a timestamped *entry* to the append-only log under *key*.

    Entries are stored as ``{"_ts": <UTC ISO timestamp>, "data": entry}``
    in a list under the namespaced key ``_log:{key}``.

    Returns a human-readable status message.
    """
    from datetime import datetime, timezone

    log_key = f"_log:{key}"
    wrapped = {"_ts": datetime.now(timezone.utc).isoformat(), "data": entry}

    @self.with_memory
    def _append(k, v):
        existing = self.memcache.get(k)
        if isinstance(existing, list):
            existing.append(v)
        elif existing is not None:
            # Bug fix: previously a pre-existing non-list value under the
            # log key was silently discarded. Preserve it as the first
            # entry, matching get_log's wrapping of non-list values.
            self.memcache[k] = [existing, v]
        else:
            self.memcache[k] = [v]
        return f"Appended entry to log '{k}'"
    return _append(log_key, wrapped)

def get_log(self, key: str) -> list[dict]:
    """Return all entries logged under *key*, oldest first.

    A non-list stored value is wrapped in a single-element list so the
    return type is always a list.
    """
    self._inflate_memory()
    stored = self.memcache.get(f"_log:{key}", [])
    if isinstance(stored, list):
        return stored
    return [stored]
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def set_state(self, key: str, value: Any) -> str:
kv = KeyValue(key=key, value=json.dumps(value))
session.add(kv)
session.commit()
return 'f"Stored value in memory for `{key}`"'
return f"Stored value in memory for `{key}`"

def get_state(self, key: str) -> Any:
with Session(self.engine) as session:
Expand Down Expand Up @@ -86,3 +86,33 @@ def clear_cache(self) -> str:
session.query(KeyValue).delete()
session.commit()
return "Cleared all keys in memory cache."

def snapshot_state(self) -> dict[str, Any]:
    """Return every distinct key mapped to its merged value.

    Keys namespaced with ``_log:`` are resolved through get_log (as a list
    of entries); all other keys go through get_state.
    """
    with Session(self.engine) as session:
        distinct_keys = [row[0] for row in session.query(KeyValue.key).distinct().all()]
    snapshot: dict[str, Any] = {}
    for name in distinct_keys:
        if name.startswith("_log:"):
            snapshot[name] = self.get_log(name.removeprefix("_log:"))
        else:
            snapshot[name] = self.get_state(name)
    return snapshot

def append_log(self, key: str, entry: Any) -> str:
    """Persist one timestamped log entry as a new row under ``_log:{key}``.

    Each append inserts a fresh KeyValue row, so the log is genuinely
    append-only; get_log reassembles the rows in insertion order.
    """
    from datetime import datetime, timezone

    log_key = f"_log:{key}"
    record = json.dumps({"_ts": datetime.now(timezone.utc).isoformat(), "data": entry})
    with Session(self.engine) as session:
        session.add(KeyValue(key=log_key, value=record))
        session.commit()
    return f"Appended entry to log '{log_key}'"

def get_log(self, key: str) -> list[dict]:
    """Return all log entries for *key* as a flat list, in insertion order."""
    with Session(self.engine) as session:
        matching_rows = (
            session.query(KeyValue)
            .filter_by(key=f"_log:{key}")
            .order_by(KeyValue.id)
            .all()
        )
        return [json.loads(row.value) for row in matching_rows]
95 changes: 93 additions & 2 deletions src/seclab_taskflow_agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
"MAX_RATE_LIMIT_BACKOFF",
"RATE_LIMIT_BACKOFF",
"deploy_task_agents",
"read_tool_log",
"run_main",
"write_auto_save",
]

import asyncio
import contextlib
import json
import logging
import os
Expand Down Expand Up @@ -52,6 +55,68 @@
TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task
TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between task retries

AUTO_SAVE_LOG_NAME = "auto_save_tool_log.ndjson"


def write_auto_save(
    auto_save_dir: str,
    turn: int,
    tool_name: str,
    result: str,
) -> None:
    """Append one tool result to the auto-save log (NDJSON, append-only).

    Each call writes a single JSON line containing the turn counter, the
    tool name, and the first 2000 characters of the result. Any failure is
    logged and swallowed so auto-save can never break the main run loop.

    Args:
        auto_save_dir: Directory for the log file (created if missing).
        turn: Monotonic tool-call counter at the time of the call.
        tool_name: Name of the tool whose result is being saved.
        result: Raw tool result; None is treated as the empty string.
    """
    try:
        os.makedirs(auto_save_dir, exist_ok=True)
        save_path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME)
        entry = json.dumps({
            "turn": turn,
            "tool": tool_name,
            # Truncate so a single huge tool result cannot bloat the log.
            "result_preview": (result or "")[:2000],
        })
        with open(save_path, "a", encoding="utf-8") as f:
            f.write(entry + "\n")
    except Exception as e:
        # Idiom fix: lazy %-style args instead of an f-string, so the
        # message is only formatted when the warning is actually emitted.
        logging.warning("Auto-save failed: %s", e)


def read_tool_log(auto_save_dir: str) -> list[dict]:
    """Read the NDJSON auto-save log, skipping blank and malformed lines.

    Returns an empty list when auto-save is disabled (empty dir), the log
    file does not exist, or reading fails entirely.
    """
    if not auto_save_dir:
        return []
    collected: list[dict] = []
    try:
        log_path = os.path.join(auto_save_dir, AUTO_SAVE_LOG_NAME)
        if os.path.exists(log_path):
            with open(log_path, encoding="utf-8") as fh:
                for raw_line in fh:
                    stripped = raw_line.strip()
                    if stripped:
                        # Tolerate partial/corrupt lines (e.g. a crash
                        # mid-write) rather than losing the whole log.
                        with contextlib.suppress(json.JSONDecodeError):
                            collected.append(json.loads(stripped))
    except Exception:
        logging.debug("Failed to read auto-save tool log", exc_info=True)
    return collected


def _snapshot_memcache_state() -> dict[str, Any]:
    """Read memcache state via the active backend's snapshot_state().

    Selects the backend from the MEMCACHE_BACKEND env var (default
    "sqlite"); returns {} on any failure so forensics capture is always
    best-effort.
    """
    try:
        from .path_utils import mcp_data_dir

        state_dir = str(mcp_data_dir("seclab-taskflow-agent", "memcache", "MEMCACHE_STATE_DIR"))
        if os.getenv("MEMCACHE_BACKEND", "sqlite") == "dictionary_file":
            from .mcp_servers.memcache.memcache_backend.dictionary_file import MemcacheDictionaryFileBackend

            backend_cls = MemcacheDictionaryFileBackend
        else:
            from .mcp_servers.memcache.memcache_backend.sqlite import SqliteBackend

            backend_cls = SqliteBackend
        return backend_cls(state_dir).snapshot_state()
    except Exception:
        logging.debug("Failed to snapshot memcache state", exc_info=True)
        return {}


def _resolve_model_config(
available_tools: AvailableTools,
Expand Down Expand Up @@ -467,8 +532,22 @@ async def run_main(

last_mcp_tool_results: list[str] = []

# Auto-save scaffolding: periodically persist tool results to disk.
# Disabled by default (interval=0). Set AUTO_SAVE_DIR and
# AUTO_SAVE_INTERVAL to enable.
_tool_call_counter = [0]
try:
_auto_save_interval = int(os.getenv("AUTO_SAVE_INTERVAL", "0"))
except ValueError:
logging.warning("Invalid AUTO_SAVE_INTERVAL value, defaulting to 0 (disabled)")
_auto_save_interval = 0
_auto_save_dir = os.getenv("AUTO_SAVE_DIR", "")

async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None:
    # Accumulate every tool result for later reuse across tasks, and —
    # when auto-save is enabled — persist every Nth result to disk.
    last_mcp_tool_results.append(result)
    # Counter lives in a one-element list so this closure can mutate it.
    _tool_call_counter[0] += 1
    if _auto_save_dir and _auto_save_interval and _tool_call_counter[0] % _auto_save_interval == 0:
        write_auto_save(_auto_save_dir, _tool_call_counter[0], tool.name, result)

async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None:
await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n")
Expand Down Expand Up @@ -701,7 +780,13 @@ async def _deploy(ra: dict, pp: str) -> bool:

# If all retries exhausted with an exception, save and re-raise
if last_task_error is not None:
session.mark_failed(f"Task {task_name!r}: {last_task_error}")
snap = _snapshot_memcache_state()
log = read_tool_log(_auto_save_dir)
session.mark_failed(
f"Task {task_name!r}: {last_task_error}",
memcache_snapshot=snap,
tool_log_snapshot=log,
)
await render_model_output(
f"** 🤖💾 Session saved: {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n"
Expand All @@ -711,7 +796,13 @@ async def _deploy(ra: dict, pp: str) -> bool:
if must_complete and not task_complete:
logging.critical("Required task not completed ... aborting!")
await render_model_output("🤖💥 *Required task not completed ...\n")
session.mark_failed(f"Required task {task_name!r} did not complete")
snap = _snapshot_memcache_state()
log = read_tool_log(_auto_save_dir)
session.mark_failed(
f"Required task {task_name!r} did not complete",
memcache_snapshot=snap,
tool_log_snapshot=log,
)
await render_model_output(
f"** 🤖💾 Session saved: {session.session_id}\n"
f"** 🤖💡 Resume with: --resume {session.session_id}\n"
Expand Down
22 changes: 20 additions & 2 deletions src/seclab_taskflow_agent/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -63,6 +64,10 @@ class TaskflowSession(BaseModel):
# Accumulated tool results carried across tasks (used by repeat_prompt)
last_tool_results: list[str] = Field(default_factory=list)

# Failure forensics: captured at mark_failed time for post-mortem inspection
memcache_snapshot: dict[str, Any] = Field(default_factory=dict)
tool_log_snapshot: list[dict[str, Any]] = Field(default_factory=list)

@property
def next_task_index(self) -> int:
"""Index of the next task to execute."""
Expand Down Expand Up @@ -107,9 +112,22 @@ def mark_finished(self) -> None:
self.finished = True
self.save()

def mark_failed(self, error: str) -> None:
"""Mark the session as failed with an error message and save."""
def mark_failed(
    self,
    error: str,
    memcache_snapshot: dict[str, Any] | None = None,
    tool_log_snapshot: list[dict[str, Any]] | None = None,
) -> None:
    """Record a failure on the session and persist it.

    The optional snapshots capture memcache state and the auto-save tool
    log at failure time for post-mortem inspection; passing None leaves
    the corresponding stored field untouched.
    """
    self.error = error
    for field_name, captured in (
        ("memcache_snapshot", memcache_snapshot),
        ("tool_log_snapshot", tool_log_snapshot),
    ):
        if captured is not None:
            setattr(self, field_name, captured)
    self.save()

@classmethod
Expand Down
Loading