From bea8ac7072fbacd0f28b0f9609d3da51f5955059 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 12:28:56 -0400 Subject: [PATCH 01/20] refactor: decompose engine into modules with Pydantic models and Typer CLI - Extract __main__.py into cli.py, runner.py, mcp_lifecycle.py, models.py - Add Pydantic v2 grammar models for all YAML document types - Replace argparse with Typer CLI, add project.scripts entry point - Reduce ruff ignore list from 59 to 22 rules, fix lint issues - Add 29 model tests against real YAML files (68/68 pass) - Full grammar backwards compatibility preserved --- README.md | 40 +- pyproject.toml | 114 ++-- src/seclab_taskflow_agent/__init__.py | 18 + src/seclab_taskflow_agent/__main__.py | 674 +-------------------- src/seclab_taskflow_agent/agent.py | 2 +- src/seclab_taskflow_agent/capi.py | 6 +- src/seclab_taskflow_agent/cli.py | 175 ++++++ src/seclab_taskflow_agent/mcp_lifecycle.py | 167 +++++ src/seclab_taskflow_agent/models.py | 205 +++++++ src/seclab_taskflow_agent/runner.py | 477 +++++++++++++++ tests/test_cli_parser.py | 12 +- tests/test_models.py | 303 +++++++++ 12 files changed, 1443 insertions(+), 750 deletions(-) create mode 100644 src/seclab_taskflow_agent/cli.py create mode 100644 src/seclab_taskflow_agent/mcp_lifecycle.py create mode 100644 src/seclab_taskflow_agent/models.py create mode 100644 src/seclab_taskflow_agent/runner.py create mode 100644 tests/test_models.py diff --git a/README.md b/README.md index 50d7d534..448bef3d 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,11 @@ # GitHub Security Lab Taskflow Agent -The Security Lab Taskflow Agent is an MCP enabled multi-Agent framework. +The Security Lab Taskflow Agent is an MCP-enabled multi-Agent framework for +declarative, YAML-driven agentic workflows. -The Taskflow Agent is built on top of the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/). 
+Built on top of the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/), +it uses [Pydantic](https://docs.pydantic.dev/) for grammar validation and +[Jinja2](https://jinja.palletsprojects.com/) for template rendering. ## Core Concepts @@ -16,6 +19,39 @@ Agents can cooperate to complete sequences of tasks through so-called [taskflows You can find a detailed overview of the taskflow grammar [here](doc/GRAMMAR.md) and example taskflows [here](examples/taskflows/). +## Architecture + +``` +┌─────────────────────────────────────────────────────┐ +│ CLI (cli.py) │ +│ Typer-based entry point: -p, -t, -l, -g KEY=VALUE │ +└─────────────────────┬───────────────────────────────┘ + │ +┌─────────────────────▼───────────────────────────────┐ +│ Runner (runner.py) │ +│ Taskflow execution loop, model resolution, │ +│ template rendering, repeat-prompt iteration │ +└─────────────────────┬───────────────────────────────┘ + │ +┌─────────────────────▼───────────────────────────────┐ +│ MCP Lifecycle (mcp_lifecycle.py) │ +│ Server connection, cleanup, process management │ +└─────────────────────┬───────────────────────────────┘ + │ +┌─────────────────────▼───────────────────────────────┐ +│ Agent (agent.py) │ +│ TaskAgent wrapper, hooks, OpenAI Agents SDK bridge │ +└─────────────────────────────────────────────────────┘ + +Supporting modules: + models.py — Pydantic v2 grammar models (validation) + available_tools.py — YAML resource loader with caching + template_utils.py — Jinja2 template environment + mcp_utils.py — MCP namespace wrapping, system prompts + capi.py — AI API endpoint management + path_utils.py — Platform-aware data/log directories +``` + ## Use Cases and Examples The Seclab Taskflow Agent framework was primarily designed to fit the iterative feedback loop driven work involved in Agentic security research workflows and vulnerability triage tasks. 
diff --git a/pyproject.toml b/pyproject.toml index e01006fc..0c6b6bef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,9 @@ dependencies = [ "zipp==3.23.0", ] +[project.scripts] +seclab-taskflow-agent = "seclab_taskflow_agent.cli:app" + [project.urls] Source = "https://github.com/GitHubSecurityLab/seclab-taskflow-agent" Issues = "https://github.com/GitHubSecurityLab/seclab-taskflow-agent/issues" @@ -150,86 +153,39 @@ exclude_lines = [ target-version = "py310" [tool.ruff.lint] -# Suppress all current linter errors to establish a baseline +# Project-wide style choices and pragmatic suppressions. +# Rules removed from the original baseline are now enforced. ignore = [ - "A001", # Variable shadows built-in - "A002", # Argument shadows built-in - "A004", # Import shadows built-in - "ARG001", # Unused function argument - "B006", # Mutable default argument - "B007", # Unused loop control variable - "B008", # Function call in argument defaults - "B023", # Function uses loop variable - "B904", # raise-without-from-inside-except - "BLE001", # Blind except - "C405", # Unnecessary literal set - "C416", # Unnecessary comprehension - "E721", # Type comparison using == - "E722", # Bare except - "E741", # Ambiguous variable name - "EM101", # Exception string literal - "EM102", # Exception f-string - "F541", # f-string without placeholders - "F811", # Redefinition of unused name - "F821", # Undefined name - "F841", # Unused variable - "FA100", # Missing from __future__ import annotations - "FA102", # Missing from __future__ import annotations in stub - "FBT001", # Boolean positional arg in function definition - "FBT002", # Boolean default value in function definition - "FURB188", # Prefer removeprefix over conditional slice - "G004", # Logging with f-string - "I001", # Import block unsorted or unformatted - "INP001", # Implicit namespace package - "LOG015", # root logger usage - "N801", # Class name should use CapWords convention - "N802", # Function name should be 
lowercase - "N806", # Variable in function should be lowercase - "N818", # Exception name should end with Error - "PERF102", # Use keys() or values() instead of items() - "PERF401", # Use list comprehension - "PIE790", # Unnecessary pass statement - "PLC0415", # Import should be at top of file - "PLC1802", # Use of len(x) == 0 - "PLR2004", # Magic value used in comparison - "PLW0602", # Global variable not assigned - "PLW0603", # Using global statement - "PLW1508", # Invalid envvar default - "PLW2901", # Outer loop variable overwritten - "PT011", # pytest.raises too broad - "PYI041", # Use float instead of int | float + # Style choices — these are deliberate project conventions + "EM101", # Exception string literals (pragmatic for this codebase) + "EM102", # Exception f-strings (pragmatic for this codebase) + "G004", # Logging f-strings (clearer than % formatting) + "T201", # print() used intentionally for user output + "TRY003", # Raise with inline message strings (pragmatic) + + # Backwards-compatibility suppressions for existing code + "A001", # Variable shadows built-in (existing API names) + "A002", # Argument shadows built-in (existing API signatures) + "B006", # Mutable default argument (existing signatures, would break API) + "FBT001", # Boolean positional arg (existing API) + "FBT002", # Boolean default value (existing API) + "N802", # Function name casing (existing API: get_AI_endpoint etc.) 
+ "N806", # Variable casing (existing code conventions) + "SLF001", # Private member access (needed for MCP wrapper internals) + + # Framework / ecosystem constraints + "ARG001", # Unused function argument (required by hook/callback signatures) + "B023", # Function uses loop variable (async closures in runner) + "INP001", # Implicit namespace package (project uses src layout) + "PLW2901", # Outer loop variable overwritten (iteration patterns) + "S701", # Jinja2 autoescape=False (YAML context, not HTML) + + # Low-signal rules for this project + "PLR2004", # Magic value comparisons "RET503", # Missing explicit return - "RET504", # Unnecessary assignment before return "RET505", # Unnecessary else after return - "RET506", # Unnecessary else after raise - "RUF005", # Unpack instead of concatenation - "RUF010", # Use explicit conversion flag - "RUF015", # Prefer next() over single element slice - "RUF059", # Use of private function/attribute - "RUF100", # Unused noqa directive - "S108", # Hardcoded temp file/directory - "S607", # Starting process with partial path - "S701", # Using jinja2 templates with autoescape=False is dangerous and can lead to XSS - "SIM102", # Use single if statement - "SIM115", # Use context handler for file - "SIM210", # Use ternary operator - "SLF001", # Private member access - "T201", # print found - "TID252", # Relative imports from parent modules - "TRY003", # Raise vanilla args - "TRY004", # Prefer TypeError for wrong type - "TRY300", # Consider moving statement to else - "TRY301", # Abstract raise to inner function - "TRY400", # Use logging.exception instead of logging.error - "UP004", # Use X | Y for union types - "UP006", # Use X | Y for union types in isinstance - "UP009", # UTF-8 encoding declaration - "UP015", # Unnecessary mode argument - "UP020", # Use builtin open - "UP024", # Replace aliased errors with OSError - "UP032", # Use f-string - "UP035", # Import from collections.abc - "UP045", # Use X | None for type annotations - 
"W291", # Trailing whitespace - "W293", # Blank line contains whitespace + "SIM102", # Collapsible if (readability preference) ] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101", "PLR2004"] diff --git a/src/seclab_taskflow_agent/__init__.py b/src/seclab_taskflow_agent/__init__.py index 306cb0f3..5cf7c001 100644 --- a/src/seclab_taskflow_agent/__init__.py +++ b/src/seclab_taskflow_agent/__init__.py @@ -1,2 +1,20 @@ # SPDX-FileCopyrightText: GitHub, Inc. # SPDX-License-Identifier: MIT + +"""SecLab Taskflow Agent — secure and automated workflow execution. + +This package provides the engine for running declarative YAML-based taskflows +that orchestrate AI agents with MCP (Model Context Protocol) tool servers +for security analysis, code auditing, and vulnerability triage. + +Architecture +~~~~~~~~~~~~ +- :mod:`~seclab_taskflow_agent.models` — Pydantic grammar models +- :mod:`~seclab_taskflow_agent.cli` — CLI entry point (Typer) +- :mod:`~seclab_taskflow_agent.runner` — Taskflow execution engine +- :mod:`~seclab_taskflow_agent.agent` — Agent wrapper classes +- :mod:`~seclab_taskflow_agent.mcp_lifecycle` — MCP server lifecycle +- :mod:`~seclab_taskflow_agent.mcp_utils` — MCP utilities +- :mod:`~seclab_taskflow_agent.template_utils` — Jinja2 template rendering +- :mod:`~seclab_taskflow_agent.available_tools` — YAML resource loader +""" diff --git a/src/seclab_taskflow_agent/__main__.py b/src/seclab_taskflow_agent/__main__.py index 96a3b2c2..0f010138 100644 --- a/src/seclab_taskflow_agent/__main__.py +++ b/src/seclab_taskflow_agent/__main__.py @@ -1,671 +1,27 @@ # SPDX-FileCopyrightText: GitHub, Inc. # SPDX-License-Identifier: MIT -import argparse -import asyncio -import json -import logging -from logging.handlers import RotatingFileHandler -import os -import pathlib -import sys -import uuid -from logging.handlers import RotatingFileHandler +"""Entry point for ``python -m seclab_taskflow_agent``. 
-from agents import Agent, RunContextWrapper, TContext, Tool -from agents.agent import ModelSettings +This module serves as the package entry point. The actual implementation +is split across focused modules: -# from agents.run import DEFAULT_MAX_TURNS # XXX: this is 10, we need more than that -from agents.exceptions import AgentsException, MaxTurnsExceeded -from agents.extensions.handoff_prompt import prompt_with_handoff_instructions -from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter -from dotenv import find_dotenv, load_dotenv -from openai import APITimeoutError, BadRequestError, RateLimitError -from openai.types.responses import ResponseTextDeltaEvent +- :mod:`~seclab_taskflow_agent.cli` — CLI argument parsing (Typer) +- :mod:`~seclab_taskflow_agent.runner` — Taskflow execution engine +- :mod:`~seclab_taskflow_agent.mcp_lifecycle` — MCP server lifecycle +- :mod:`~seclab_taskflow_agent.models` — Pydantic grammar models +- :mod:`~seclab_taskflow_agent.agent` — Agent wrapper classes +""" -from .agent import DEFAULT_MODEL, TaskAgent, TaskAgentHooks, TaskRunHooks -from .available_tools import AvailableTools -from .banner import get_banner -from .capi import get_AI_token, list_tool_call_models -from .env_utils import TmpEnv -from .mcp_utils import ( - DEFAULT_MCP_CLIENT_SESSION_TIMEOUT, - MCPNamespaceWrap, - ReconnectingMCPServerStdio, - StreamableMCPThread, - compress_name, - mcp_client_params, - mcp_system_prompt, -) -from .path_utils import log_file_name -from .template_utils import render_template -import jinja2 -from .render_utils import flush_async_output, render_model_output -from .shell_utils import shell_tool_call +from dotenv import find_dotenv, load_dotenv load_dotenv(find_dotenv(usecwd=True)) -# only model output or help message should go to stdout, everything else goes to log -logging.getLogger("").setLevel(logging.NOTSET) -log_file_handler = RotatingFileHandler(log_file_name("task_agent.log"), 
maxBytes=1024 * 1024 * 10, backupCount=10) -log_file_handler.setLevel(os.getenv("TASK_AGENT_LOGLEVEL", default="DEBUG")) -log_file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) -logging.getLogger("").addHandler(log_file_handler) - -console_handler = logging.StreamHandler() -console_handler.setLevel(logging.ERROR) # log only ERROR and above to console -console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) -logging.getLogger("").addHandler(console_handler) - -DEFAULT_MAX_TURNS = 50 -RATE_LIMIT_BACKOFF = 5 -MAX_RATE_LIMIT_BACKOFF = 120 -MAX_API_RETRY = 5 -MCP_CLEANUP_TIMEOUT = 5 - - -def parse_prompt_args(available_tools: AvailableTools, user_prompt: str | None = None): - parser = argparse.ArgumentParser(add_help=False, description="SecLab Taskflow Agent") - parser.prog = "" - group = parser.add_mutually_exclusive_group() - group.add_argument("-p", help="The personality to use (mutex with -t)", required=False) - group.add_argument("-t", help="The taskflow to use (mutex with -p)", required=False) - group.add_argument("-l", help="List available tool call models and exit", action="store_true", required=False) - parser.add_argument( - "-g", - "--global", - dest="globals", - action="append", - help="Set global variable (KEY=VALUE). 
Can be used multiple times.", - required=False, - ) - parser.add_argument("prompt", nargs=argparse.REMAINDER) - # parser.add_argument('remainder', nargs=argparse.REMAINDER, help="Remaining args") - help_msg = parser.format_help() - help_msg += "\nExamples:\n\n" - help_msg += "`-p seclab_taskflow_agent.personalities.assistant explain modems to me please`\n" - help_msg += "`-t examples.taskflows.example_globals -g fruit=apples`\n" - try: - args = parser.parse_known_args(user_prompt.split(" ") if user_prompt else None) - except SystemExit as e: - if e.code == 2: - logging.exception(f"User provided incomplete prompt: {user_prompt}") - return None, None, None, None, help_msg - p = args[0].p.strip() if args[0].p else None - t = args[0].t.strip() if args[0].t else None - l = args[0].l - - # Parse global variables from command line - cli_globals = {} - if args[0].globals: - for g in args[0].globals: - if "=" not in g: - logging.error(f"Invalid global variable format: {g}. Expected KEY=VALUE") - return None, None, None, None, None, help_msg - key, value = g.split("=", 1) - cli_globals[key.strip()] = value.strip() - - return p, t, l, cli_globals, " ".join(args[0].prompt), help_msg - - -async def deploy_task_agents( - available_tools: AvailableTools, - agents: dict, - prompt: str, - async_task: bool = False, - toolboxes_override: list = [], - blocked_tools: list = [], - headless: bool = False, - exclude_from_context: bool = False, - max_turns: int = DEFAULT_MAX_TURNS, - model: str = DEFAULT_MODEL, - model_par: dict = {}, - run_hooks: TaskRunHooks | None = None, - agent_hooks: TaskAgentHooks | None = None, -): - task_id = str(uuid.uuid4()) - await render_model_output(f"** 🤖💪 Deploying Task Flow Agent(s): {list(agents.keys())}\n") - await render_model_output(f"** 🤖💪 Task ID: {task_id}\n") - await render_model_output(f"** 🤖💪 Model : {model}{', params: ' + str(model_par) if model_par else ''}\n") - - mcp_servers = [] - server_prompts = [] - toolboxes = [] - - if 
toolboxes_override: - # limit tools to task specified tools if set - toolboxes = toolboxes_override - else: - # otherwise all agents have the disjunction of all their tools available - for k, v in agents.items(): - if v.get("toolboxes", []): - toolboxes += [tb for tb in v["toolboxes"] if tb not in toolboxes] - - # https://openai.github.io/openai-agents-python/ref/model_settings/ - parallel_tool_calls = True if os.getenv("MODEL_PARALLEL_TOOL_CALLS") else False - model_params = { - "temperature": os.getenv("MODEL_TEMP", default=0.0), - "tool_choice": ("auto" if toolboxes else None), - "parallel_tool_calls": (parallel_tool_calls if toolboxes else None), - } - model_params.update(model_par) - model_settings = ModelSettings(**model_params) - - # block tools if requested - tool_filter = create_static_tool_filter(blocked_tool_names=blocked_tools) if blocked_tools else None - - # fetch mcp params - mcp_params = mcp_client_params(available_tools, toolboxes) - for tb, (params, confirms, server_prompt, client_session_timeout) in mcp_params.items(): - server_prompts.append(server_prompt) - # https://openai.github.io/openai-agents-python/mcp/ - # allowed_tool_names will allow list - # blocked_tool_names will block list - if headless: - # XXX: auto-allow all tools if task is headless by clearing confirms - confirms = [] - client_session_timeout = client_session_timeout or DEFAULT_MCP_CLIENT_SESSION_TIMEOUT - server_proc = None - match params["kind"]: - # since we spawn stdio servers each time we do not expect - # new tools to appear over time so cache the tools list - case "stdio": - if params.get("reconnecting", False): - mcp_server = ReconnectingMCPServerStdio( - name=tb, - params=params, - tool_filter=tool_filter, - client_session_timeout_seconds=client_session_timeout, - cache_tools_list=True, - ) - else: - mcp_server = MCPServerStdio( - name=tb, - params=params, - tool_filter=tool_filter, - client_session_timeout_seconds=client_session_timeout, - cache_tools_list=True, - ) 
- case "sse": - mcp_server = MCPServerSse( - name=tb, - params=params, - tool_filter=tool_filter, - client_session_timeout_seconds=client_session_timeout, - ) - case "streamable": - # check if we need to start this server locally as well - if "command" in params: - - def _print_out(line): - msg = f"Streamable MCP Server stdout: {line}" - logging.info(msg) - # print(msg) - - def _print_err(line): - msg = f"Streamable MCP Server stderr: {line}" - logging.info(msg) - # print(msg) - - server_proc = StreamableMCPThread( - params["command"], - url=params["url"], - env=params["env"], - on_output=_print_out, - on_error=_print_err, - ) - mcp_server = MCPServerStreamableHttp( - name=tb, - params=params, - tool_filter=tool_filter, - client_session_timeout_seconds=client_session_timeout, - ) - case _: - raise ValueError(f"Unsupported MCP transport {params['kind']}") - # provide namespace and confirmation control through wrapper class - mcp_servers.append((MCPNamespaceWrap(confirms, mcp_server), server_proc)) - - # connect mcp servers - # https://openai.github.io/openai-agents-python/ref/mcp/server/ - async def mcp_session_task(mcp_servers: list, connected: asyncio.Event, cleanup: asyncio.Event) -> None: - try: - # connects/cleanups have to happen in the same task - # but we also want to use wait_for to set a timeout - # so we use a dedicated session task to accomplish both - for s in mcp_servers: - server, server_proc = s - logging.debug(f"Connecting mcp server: {server._name}") - if server_proc is not None: - server_proc.start() - await server_proc.async_wait_for_connection(poll_interval=0.1) - await server.connect() - # signal that we're connected - connected.set() - # wait until we're told to clean up - await cleanup.wait() - for s in reversed(mcp_servers): - server, server_proc = s - try: - logging.debug(f"Starting cleanup for mcp server: {server._name}") - await server.cleanup() - logging.debug(f"Cleaned up mcp server: {server._name}") - if server_proc is not None: - 
server_proc.stop() - try: - await asyncio.to_thread(server_proc.join_and_raise) - except Exception as e: - print(f"Streamable mcp server process exception: {e}") - except asyncio.CancelledError: - logging.exception(f"Timeout on cleanup for mcp server: {server._name}") - finally: - mcp_servers.remove(s) - except RuntimeError as e: - logging.exception("RuntimeError in mcp session task") - except asyncio.CancelledError as e: - logging.exception("Timeout on main session task") - finally: - mcp_servers.clear() - - servers_connected = asyncio.Event() - start_cleanup = asyncio.Event() - mcp_sessions = asyncio.create_task(mcp_session_task(mcp_servers, servers_connected, start_cleanup)) - - # wait for the servers to be connected - await servers_connected.wait() - logging.debug("All mcp servers are connected!") - - try: - # any important general guidelines go here - important_guidelines = [ - "Do not prompt the user with questions.", - "Run tasks until a final result is available.", - "Ensure responses are based on the latest information from available tools.", - "Run tools sequentially, wait until one tool has completed before calling the next.", - ] - - # create one layer of handoff agents if any additional agents are listed - # https://openai.github.io/openai-agents-python/handoffs/ - handoffs = [] - for handoff_agent in list(agents.keys())[1:]: - handoffs.append( - TaskAgent( - # XXX: name has to be descriptive for an effective handoff - name=compress_name(handoff_agent), - instructions=prompt_with_handoff_instructions( - mcp_system_prompt( - agents[handoff_agent]["personality"], - agents[handoff_agent]["task"], - server_prompts=server_prompts, - important_guidelines=important_guidelines, - ) - ), - handoffs=[], - exclude_from_context=exclude_from_context, - mcp_servers=[s[0] for s in mcp_servers], - model=model, - model_settings=model_settings, - run_hooks=run_hooks, - agent_hooks=agent_hooks, - ).agent - ) - - # create the primary task agent - primary_agent = 
list(agents.keys())[0] - system_prompt = mcp_system_prompt( - agents[primary_agent]["personality"], - agents[primary_agent]["task"], - server_prompts=server_prompts, - important_guidelines=important_guidelines, - ) - agent0 = TaskAgent( - name=primary_agent, - # only add the handoff prompt if we have handoffs defined - instructions=prompt_with_handoff_instructions(system_prompt) if handoffs else system_prompt, - handoffs=handoffs, - exclude_from_context=exclude_from_context, - mcp_servers=[s[0] for s in mcp_servers], - model=model, - model_settings=model_settings, - run_hooks=run_hooks, - agent_hooks=agent_hooks, - ) - - try: - complete = False - - async def _run_streamed(): - max_retry = MAX_API_RETRY - rate_limit_backoff = RATE_LIMIT_BACKOFF - while rate_limit_backoff: - try: - result = agent0.run_streamed(prompt, max_turns=max_turns) - # render result events - # https://openai.github.io/openai-agents-python/ref/stream_events/ - # https://openai.github.io/openai-agents-python/ref/run/ - # https://openai.github.io/openai-agents-python/results/ - async for event in result.stream_events(): - if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): - await render_model_output(event.data.delta, async_task=async_task, task_id=task_id) - await render_model_output("\n\n", async_task=async_task, task_id=task_id) - return - except APITimeoutError: - if not max_retry: - logging.exception("Max retries for APITimeoutError reached") - raise - max_retry -= 1 - except RateLimitError: - if rate_limit_backoff == MAX_RATE_LIMIT_BACKOFF: - raise APITimeoutError("Max rate limit backoff reached") - if rate_limit_backoff > MAX_RATE_LIMIT_BACKOFF: - rate_limit_backoff = MAX_RATE_LIMIT_BACKOFF - else: - rate_limit_backoff += rate_limit_backoff - logging.exception(f"Hit rate limit ... 
holding for {rate_limit_backoff}") - await asyncio.sleep(rate_limit_backoff) - - await _run_streamed() - complete = True - - # raise exceptions up to here for anything that indicates a task failure - except MaxTurnsExceeded as e: - await render_model_output(f"** 🤖❗ Max Turns Reached: {e}\n", async_task=async_task, task_id=task_id) - logging.exception(f"Exceeded max_turns: {max_turns}") - except AgentsException as e: - await render_model_output(f"** 🤖❗ Agent Exception: {e}\n", async_task=async_task, task_id=task_id) - logging.exception("Agent Exception") - except BadRequestError as e: - await render_model_output(f"** 🤖❗ Request Error: {e}\n", async_task=async_task, task_id=task_id) - logging.exception("Bad Request") - except APITimeoutError as e: - await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", async_task=async_task, task_id=task_id) - logging.exception("Bad Request") - - if async_task: - await flush_async_output(task_id) - - return complete - - finally: - # signal mcp sessions task that it can disconnect our servers - start_cleanup.set() - cleanup_attempts_left = len(mcp_servers) - while cleanup_attempts_left and mcp_servers: - try: - cleanup_attempts_left -= 1 - await asyncio.wait_for(mcp_sessions, timeout=MCP_CLEANUP_TIMEOUT) - except asyncio.TimeoutError: - continue - except Exception as e: - logging.exception("Exception in mcp server cleanup task") - - -async def main(available_tools: AvailableTools, p: str | None, t: str | None, cli_globals: dict, prompt: str | None): - last_mcp_tool_results = [] # XXX: memleaky - - async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str): - last_mcp_tool_results.append(result) - - async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool): - await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n") - - async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], source: 
Agent[TContext]): - await render_model_output(f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n") - - if p: - personality = available_tools.get_personality(p) - - await deploy_task_agents( - available_tools, - {p: personality}, - prompt, - run_hooks=TaskRunHooks(on_tool_end=on_tool_end_hook, on_tool_start=on_tool_start_hook), - ) - - if t: - taskflow = available_tools.get_taskflow(t) - - await render_model_output(f"** 🤖💪 Running Task Flow: {t}\n") - - # optional global vars available for the taskflow tasks - # Start with globals from taskflow file, then override with CLI globals - global_variables = taskflow.get("globals", {}) - if cli_globals: - global_variables.update(cli_globals) - model_config = taskflow.get("model_config", {}) - model_keys = [] - models_params = {} - if model_config: - m_config = available_tools.get_model_config(model_config) - model_dict = m_config.get("models", {}) - if model_dict: - if not isinstance(model_dict, dict): - raise ValueError(f"Models section of the model_config file {model_config} must be a dictionary") - model_keys = model_dict.keys() - models_params = m_config.get("model_settings", {}) - if models_params and not isinstance(models_params, dict): - raise ValueError(f"Settings section of model_config file {model_config} must be a dictionary") - if not set(models_params.keys()).difference(model_keys).issubset(set([])): - raise ValueError( - f"Settings section of model_config file {model_config} contains models that are not in the model section" - ) - for k, v in models_params.items(): - if not isinstance(v, dict): - raise ValueError(f"Settings for model {k} in model_config file {model_config} is not a dictionary") - - for task in taskflow["taskflow"]: - task_body = task["task"] - - # reusable taskflow support (they have to be single step taskflows) - # if uses: is set, swap in the appropriate task_body values from child - # child values can NOT overwrite existing parent values, so parents - # can tweak reusable task 
configurations as they see fit - uses = task_body.get("uses", "") - if uses: - reusable_taskflow = available_tools.get_taskflow(uses) - if reusable_taskflow is None: - raise ValueError(f"No such reusable taskflow: {uses}") - if len(reusable_taskflow["taskflow"]) > 1: - raise ValueError("Reusable taskflows can only contain 1 task") - for k, v in reusable_taskflow["taskflow"][0]["task"].items(): - if k not in task_body: - task_body[k] = v - model = task_body.get("model", DEFAULT_MODEL) - model_settings = {} - if model in model_keys: - if model in models_params: - model_settings = models_params[model].copy() - model = model_dict[model] - task_model_settings = task_body.get("model_settings", {}) - if not isinstance(task_model_settings, dict): - name = task.get("name", "") - raise ValueError(f"model_settings in task {name} needs to be a dictionary") - model_settings.update(task_model_settings) - - # parse our taskflow grammar - name = task_body.get("name", "taskflow") # placeholder, not used yet - description = task_body.get("description", "taskflow") # placeholder not used yet - agents = task_body.get("agents", []) - headless = task_body.get("headless", False) - blocked_tools = task_body.get("blocked_tools", []) - run = task_body.get("run", "") - inputs = task_body.get("inputs", {}) - prompt = task_body.get("user_prompt", "") - if run and prompt: - raise ValueError("shell task and prompt task are mutually exclusive!") - must_complete = task_body.get("must_complete", False) - max_turns = task_body.get("max_steps", DEFAULT_MAX_TURNS) - toolboxes_override = task_body.get("toolboxes", []) - env = task_body.get("env", {}) - repeat_prompt = task_body.get("repeat_prompt", False) - # this will set Agent 'stop_on_first_tool' tool use behavior, which prevents output back to llm - exclude_from_context = task_body.get("exclude_from_context", False) - # this allows you to run repeated prompts concurrently with a limit - async_task = task_body.get("async", False) - 
max_concurrent_tasks = task_body.get("async_limit", 5) - - # Render prompt template with Jinja2 (skip if repeat_prompt since result is not yet available) - if prompt and not repeat_prompt: - try: - prompt = render_template( - template_str=prompt, - available_tools=available_tools, - globals_dict=global_variables, - inputs_dict=inputs, - ) - except jinja2.TemplateError as e: - logging.error(f"Template rendering error: {e}") - raise ValueError(f"Failed to render prompt template: {e}") - - with TmpEnv(env): - prompts_to_run = [] - if repeat_prompt: - # Check if prompt contains result template variable - if 'result' not in prompt.lower(): - logging.warning("repeat_prompt enabled but no {{ result }} in prompt") - - try: - # Get last MCP tool result - last_result = json.loads(last_mcp_tool_results.pop()) - text = last_result.get("text", "") - try: - iterable_result = json.loads(text) - except json.decoder.JSONDecodeError as exc: - logging.critical(f"Could not parse result text: {text}") - raise ValueError(f"Result text is not valid JSON") from exc - - # Verify iterable - try: - iter(iterable_result) - except TypeError: - logging.critical("Last MCP tool result is not iterable") - raise - except IndexError: - logging.critical("No last MCP tool result available") - raise - - if not iterable_result: - await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") - else: - logging.debug(f"Rendering templated prompts for results: {iterable_result}") - - # Render template for each result value - for value in iterable_result: - try: - rendered_prompt = render_template( - template_str=prompt, - available_tools=available_tools, - globals_dict=global_variables, - inputs_dict=inputs, - result_value=value, - ) - prompts_to_run.append(rendered_prompt) - except jinja2.TemplateError as e: - logging.error(f"Error rendering template for result {value}: {e}") - raise ValueError(f"Template rendering failed: {e}") - else: - prompts_to_run.append(prompt) - - async def 
run_prompts(async_task=False, max_concurrent_tasks=5): - # if this is a shell task, execute that and append the results - if run: - await render_model_output("** 🤖🐚 Executing Shell Task\n") - # this allows e.g. shell based jq output to become available for repeat prompts - try: - result = shell_tool_call(run).content[0].model_dump_json() - last_mcp_tool_results.append(result) - return True - except RuntimeError as e: - await render_model_output(f"** 🤖❗ Shell Task Exception: {e}\n") - logging.exception("Shell task error") - return False - - tasks = [] - task_results = [] - semaphore = asyncio.Semaphore(max_concurrent_tasks) - for prompt in prompts_to_run: - # run a task prompt - resolved_agents = {} - if not agents: - # XXX: deprecate the -p parser for taskflows entirely? - # XXX: probably just adds unneeded parsing complexity - p, _, _, prompt, _ = parse_prompt_args(available_tools, prompt) - agents.append(p) - for p in agents: - personality = available_tools.get_personality(p) - if personality is None: - raise ValueError(f"No such personality: {p}") - resolved_agents[p] = personality - - # limit the max concurrent tasks via a semaphore - async def _deploy_task_agents(resolved_agents, prompt): - async with semaphore: - result = await deploy_task_agents( - available_tools, - # pass agents and prompt by assignment, they change in-loop - resolved_agents, - prompt, - async_task=async_task, - toolboxes_override=toolboxes_override, - blocked_tools=blocked_tools, - headless=headless, - exclude_from_context=exclude_from_context, - max_turns=max_turns, - run_hooks=TaskRunHooks( - on_tool_end=on_tool_end_hook, on_tool_start=on_tool_start_hook - ), - model=model, - model_par=model_settings, - agent_hooks=TaskAgentHooks(on_handoff=on_handoff_hook), - ) - return result - - task_coroutine = _deploy_task_agents(resolved_agents, prompt) - - if not async_task: - # wait for the task - result = await task_coroutine - task_results.append(result) - else: - # stack the task - 
tasks.append(task_coroutine) - - if async_task: - # gather results - task_results = await asyncio.gather(*tasks, return_exceptions=True) - - complete = True - # if any prompt in a must_complete task is not complete the entire task is incomplete - for result in task_results: - if isinstance(result, Exception): - logging.error(f"Caught exception in Gather: {result}") - result = False - complete = result and complete - return complete - - # an async tasks runs prompts concurrently - task_complete = await run_prompts(async_task=async_task, max_concurrent_tasks=max_concurrent_tasks) - - if must_complete and not task_complete: - logging.critical("Required task not completed ... aborting!") - await render_model_output("🤖💥 *Required task not completed ...\n") - break - +# Re-export for backwards compatibility — some tests import from __main__ +from .cli import parse_prompt_args # noqa: E402, F401 +from .runner import deploy_task_agents, run_main # noqa: E402, F401 if __name__ == "__main__": - cwd = pathlib.Path.cwd() - available_tools = AvailableTools() - - p, t, l, cli_globals, user_prompt, help_msg = parse_prompt_args(available_tools) - - if l: - tool_models = list_tool_call_models(get_AI_token()) - for model in tool_models: - print(model) - sys.exit(0) - - if p is None and t is None: - print(help_msg) - sys.exit(1) + from .cli import app - print(get_banner()) # print banner only before starting main event loop - asyncio.run(main(available_tools, p, t, cli_globals, user_prompt), debug=True) + app() diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 2a269183..5e788864 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -23,7 +23,7 @@ set_tracing_disabled, ) from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult -from agents.run import DEFAULT_MAX_TURNS, RunHooks +from agents.run import DEFAULT_MAX_TURNS from dotenv import find_dotenv, load_dotenv from openai 
# Typer application object; ``main`` below is registered on it as the sole command.
app = typer.Typer(
    name="seclab-taskflow-agent",
    help="SecLab Taskflow Agent — secure and automated workflow execution.",
    add_completion=False,
    no_args_is_help=True,
)


def _parse_global(value: str) -> tuple[str, str]:
    """Parse a ``KEY=VALUE`` string into a (key, value) pair.

    Only the first ``=`` separates key from value, so the value itself may
    contain ``=``. Both parts are stripped of surrounding whitespace.

    Raises:
        typer.BadParameter: If *value* contains no ``=`` at all.
    """
    if "=" not in value:
        raise typer.BadParameter(f"Invalid global variable format: {value!r}. Expected KEY=VALUE.")
    key, _, val = value.partition("=")
    return key.strip(), val.strip()


def _setup_logging() -> None:
    """Configure root logger: file (DEBUG) + console (ERROR).

    The file handler level can be overridden through the
    ``TASK_AGENT_LOGLEVEL`` environment variable.
    """
    import os
    from logging.handlers import RotatingFileHandler

    root = logging.getLogger("")
    # Let the individual handlers decide what to keep.
    root.setLevel(logging.NOTSET)

    file_handler = RotatingFileHandler(
        log_file_name("task_agent.log"), maxBytes=10 * 1024 * 1024, backupCount=10
    )
    file_handler.setLevel(os.getenv("TASK_AGENT_LOGLEVEL", "DEBUG"))
    file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    root.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.ERROR)
    console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    root.addHandler(console_handler)


@app.command()
def main(
    personality: Annotated[
        str | None,
        typer.Option("-p", "--personality", help="Personality module path (mutually exclusive with -t)."),
    ] = None,
    taskflow: Annotated[
        str | None,
        typer.Option("-t", "--taskflow", help="Taskflow module path (mutually exclusive with -p)."),
    ] = None,
    list_models: Annotated[
        bool,
        typer.Option("-l", "--list-models", help="List available tool-call models and exit."),
    ] = False,
    globals_: Annotated[
        list[str] | None,
        typer.Option("-g", "--global", help="Global variable as KEY=VALUE. Repeatable."),
    ] = None,
    prompt: Annotated[
        list[str] | None,
        typer.Argument(help="Remaining prompt text."),
    ] = None,
) -> None:
    """Run a taskflow or personality-based agent session."""
    # Validate mutual exclusivity
    specified = sum(bool(x) for x in [personality, taskflow, list_models])
    if specified > 1:
        typer.echo("Error: -p, -t, and -l are mutually exclusive.", err=True)
        raise typer.Exit(code=1)

    _setup_logging()

    available_tools = AvailableTools()

    # List models mode
    if list_models:
        tool_models = list_tool_call_models(get_AI_token())
        for model in tool_models:
            typer.echo(model)
        raise typer.Exit()

    if personality is None and taskflow is None:
        typer.echo("Error: one of -p or -t is required.", err=True)
        raise typer.Exit(code=1)

    # Parse global variables
    cli_globals: dict[str, str] = {}
    for g in globals_ or []:
        key, val = _parse_global(g)
        cli_globals[key] = val

    user_prompt = " ".join(prompt) if prompt else ""

    typer.echo(get_banner())

    # Imported here rather than at module top, so the runner is only loaded
    # when a session is actually started.
    from .runner import run_main

    asyncio.run(
        run_main(available_tools, personality, taskflow, cli_globals, user_prompt),
        debug=True,
    )


# ---------------------------------------------------------------------------
# Legacy compatibility shim
# ---------------------------------------------------------------------------

def parse_prompt_args(available_tools: AvailableTools, user_prompt: str | None = None):
    """Legacy CLI parser kept for backwards compatibility with tests.

    Args:
        available_tools: Unused by the parser itself; kept for signature
            compatibility.
        user_prompt: Argument string to parse (split on single spaces). When
            empty or ``None``, argparse falls back to ``sys.argv``.

    Returns:
        Tuple of (personality, taskflow, list_models, cli_globals, prompt, help_msg).
        On a parse failure or a malformed ``-g`` value the first five items
        are ``None`` and only ``help_msg`` is populated.

    Raises:
        SystemExit: Re-raised when argparse exits with a code other than 2
            (2 is argparse's usage-error code and is reported via the
            failure tuple instead).
    """
    import argparse

    parser = argparse.ArgumentParser(add_help=False, description="SecLab Taskflow Agent")
    parser.prog = ""
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-p", help="The personality to use (mutex with -t)", required=False)
    group.add_argument("-t", help="The taskflow to use (mutex with -p)", required=False)
    group.add_argument("-l", help="List available tool call models and exit", action="store_true", required=False)
    parser.add_argument(
        "-g",
        "--global",
        dest="globals",
        action="append",
        help="Set global variable (KEY=VALUE). Can be used multiple times.",
        required=False,
    )
    parser.add_argument("prompt", nargs=argparse.REMAINDER)

    help_msg = parser.format_help()
    help_msg += "\nExamples:\n\n"
    help_msg += "`-p seclab_taskflow_agent.personalities.assistant explain modems to me please`\n"
    help_msg += "`-t examples.taskflows.example_globals -g fruit=apples`\n"

    # Every failure path returns the same 6-tuple shape as the success path so
    # callers can always unpack six values.
    failure = (None, None, None, None, None, help_msg)

    try:
        args = parser.parse_known_args(user_prompt.split(" ") if user_prompt else None)
    except SystemExit as e:
        if e.code != 2:
            # Not a usage error: there is no parsed namespace to continue with.
            raise
        logging.exception(f"User provided incomplete prompt: {user_prompt}")
        return failure

    parsed = args[0]
    p = parsed.p.strip() if parsed.p else None
    t = parsed.t.strip() if parsed.t else None
    list_models = parsed.l

    cli_globals: dict[str, str] = {}
    for g in parsed.globals or []:
        if "=" not in g:
            logging.error(f"Invalid global variable format: {g}. Expected KEY=VALUE")
            return failure
        key, value = g.split("=", 1)
        cli_globals[key.strip()] = value.strip()

    return p, t, list_models, cli_globals, " ".join(parsed.prompt), help_msg
# Seconds; exported for callers of this module (not referenced inside this section).
MCP_CLEANUP_TIMEOUT = 5


class MCPServerEntry:
    """Pair an MCP server wrapper with the local process backing it, if any."""

    __slots__ = ("server", "process")

    def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None = None):
        self.server = server
        self.process = process


def build_mcp_servers(
    available_tools: AvailableTools,
    toolboxes: list[str],
    blocked_tools: list[str] | None = None,
    headless: bool = False,
) -> list[MCPServerEntry]:
    """Create (but do not connect) one MCP server entry per toolbox.

    Args:
        available_tools: Tool registry for loading toolbox configs.
        toolboxes: List of toolbox module paths.
        blocked_tools: Tool names to block.
        headless: If True, the per-toolbox confirmation list is emptied.

    Returns:
        List of MCPServerEntry instances ready for connection.

    Raises:
        ValueError: If a toolbox declares a transport kind other than
            ``stdio``, ``sse`` or ``streamable``.
    """

    def _log_stdout(line: str) -> None:
        logging.info(f"Streamable MCP Server stdout: {line}")

    def _log_stderr(line: str) -> None:
        logging.info(f"Streamable MCP Server stderr: {line}")

    tool_filter = None
    if blocked_tools:
        tool_filter = create_static_tool_filter(blocked_tool_names=blocked_tools)

    built: list[MCPServerEntry] = []
    all_params = mcp_client_params(available_tools, toolboxes)

    for tb, (params, confirms, server_prompt, client_session_timeout) in all_params.items():
        confirms = [] if headless else confirms
        timeout = client_session_timeout or DEFAULT_MCP_CLIENT_SESSION_TIMEOUT
        local_proc = None

        # Keyword arguments shared by every transport.
        shared = {
            "name": tb,
            "params": params,
            "tool_filter": tool_filter,
            "client_session_timeout_seconds": timeout,
        }

        kind = params["kind"]
        if kind == "stdio":
            stdio_cls = ReconnectingMCPServerStdio if params.get("reconnecting", False) else MCPServerStdio
            mcp_server = stdio_cls(**shared, cache_tools_list=True)
        elif kind == "sse":
            mcp_server = MCPServerSse(**shared)
        elif kind == "streamable":
            # A ``command`` entry means the server is launched locally first.
            if "command" in params:
                local_proc = StreamableMCPThread(
                    params["command"],
                    url=params["url"],
                    env=params["env"],
                    on_output=_log_stdout,
                    on_error=_log_stderr,
                )
            mcp_server = MCPServerStreamableHttp(**shared)
        else:
            raise ValueError(f"Unsupported MCP transport: {params['kind']}")

        built.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), local_proc))

    return built


async def mcp_session_task(
    entries: list[MCPServerEntry],
    connected: asyncio.Event,
    cleanup: asyncio.Event,
) -> None:
    """Connect every entry, wait for the cleanup signal, then tear them down.

    Entries are connected in list order and cleaned up in reverse order; each
    entry is removed from *entries* once its cleanup attempt has finished, and
    the list is always emptied on exit.

    Args:
        entries: MCP server entries to manage.
        connected: Event to signal when all servers are connected.
        cleanup: Event to wait on before cleaning up.
    """
    try:
        for item in entries:
            logging.debug(f"Connecting mcp server: {item.server._name}")
            proc = item.process
            if proc is not None:
                proc.start()
                await proc.async_wait_for_connection(poll_interval=0.1)
            await item.server.connect()

        connected.set()
        await cleanup.wait()

        for item in reversed(entries):
            try:
                logging.debug(f"Starting cleanup for mcp server: {item.server._name}")
                await item.server.cleanup()
                logging.debug(f"Cleaned up mcp server: {item.server._name}")
                proc = item.process
                if proc is not None:
                    proc.stop()
                    try:
                        # join_and_raise is run in a worker thread via asyncio.to_thread.
                        await asyncio.to_thread(proc.join_and_raise)
                    except Exception as e:
                        logging.warning(f"Streamable mcp server process exception: {e}")
            except asyncio.CancelledError:
                logging.exception(f"Timeout on cleanup for mcp server: {item.server._name}")
            finally:
                entries.remove(item)
    except RuntimeError:
        logging.exception("RuntimeError in mcp session task")
    except asyncio.CancelledError:
        logging.exception("Timeout on main session task")
    finally:
        entries.clear()
# ---------------------------------------------------------------------------
# Header
# ---------------------------------------------------------------------------

# The only header version accepted by TaskflowHeader's validator.
SUPPORTED_VERSION = "1.0"


class TaskflowHeader(BaseModel):
    """The ``seclab-taskflow-agent`` header block present in every YAML file.

    ``version`` is normalised to a string before validation and must equal
    ``SUPPORTED_VERSION``; ``filetype`` is stored as given.
    """

    model_config = ConfigDict(populate_by_name=True)

    version: str
    filetype: str

    @field_validator("version", mode="before")
    @classmethod
    def _normalise_version(cls, v: Any) -> str:
        """Accept int/float/str versions and normalise to ``"1.0"`` format.

        Integers gain a ``.0`` suffix; every other value is passed through
        ``str()`` unchanged (so the string ``"1"`` stays ``"1"``).
        """
        if isinstance(v, int):
            return f"{v}.0"
        # The float branch returns the same thing as the fall-through below;
        # it is kept only to make the float case explicit.
        if isinstance(v, float):
            return str(v)
        return str(v)

    @field_validator("version", mode="after")
    @classmethod
    def _validate_version(cls, v: str) -> str:
        """Reject any normalised version other than ``SUPPORTED_VERSION``."""
        if v != SUPPORTED_VERSION:
            raise ValueError(
                f"Unsupported version: {v}. Only version {SUPPORTED_VERSION} is supported."
            )
        return v


# ---------------------------------------------------------------------------
# Task definition (a single step inside a taskflow)
# ---------------------------------------------------------------------------

class TaskDefinition(BaseModel):
    """A single task within a taskflow.

    This captures every field the engine currently recognises in a task block.
    Extra fields are allowed for forward-compatibility.

    Raises a validation error when both ``run`` and ``user_prompt`` are
    non-empty, since a task is either a shell task or a prompt task.
    """

    model_config = ConfigDict(extra="allow")

    name: str = "taskflow"
    description: str = "taskflow"
    agents: list[str] = Field(default_factory=list)
    user_prompt: str = ""
    run: str = ""
    model: str = ""
    model_settings: dict[str, Any] = Field(default_factory=dict)
    must_complete: bool = False
    headless: bool = False
    repeat_prompt: bool = False
    exclude_from_context: bool = False
    blocked_tools: list[str] = Field(default_factory=list)
    toolboxes: list[str] = Field(default_factory=list)
    env: dict[str, str] = Field(default_factory=dict)
    inputs: dict[str, Any] = Field(default_factory=dict)
    max_steps: int = 0  # 0 means use the runner default
    uses: str = ""

    # async settings (``async`` is a reserved word, aliased)
    async_task: bool = Field(default=False, alias="async")
    async_limit: int = 5

    @model_validator(mode="after")
    def _run_xor_prompt(self) -> TaskDefinition:
        """Ensure ``run`` and ``user_prompt`` are not both set."""
        if self.run and self.user_prompt:
            raise ValueError("shell task ('run') and prompt task ('user_prompt') are mutually exclusive")
        return self


class TaskWrapper(BaseModel):
    """Wraps the ``- task:`` YAML list entry."""

    task: TaskDefinition


# ---------------------------------------------------------------------------
# Top-level document types
# ---------------------------------------------------------------------------

class TaskflowDocument(BaseModel):
    """A complete taskflow YAML document.

    Example::

        seclab-taskflow-agent:
          version: "1.0"
          filetype: taskflow
        globals:
          fruit: bananas
        model_config_ref: examples.model_configs.model_config
        taskflow:
          - task:
              ...
    """

    # NOTE(review): the example above shows the key ``model_config_ref`` while
    # the field below is aliased to ``model_config`` — confirm which spelling
    # the YAML files actually use.
    model_config = ConfigDict(extra="allow")

    header: TaskflowHeader = Field(alias="seclab-taskflow-agent")
    globals: dict[str, Any] = Field(default_factory=dict)
    # ``model_config`` clashes with Pydantic's own ConfigDict, so we use an alias
    model_config_ref: str = Field(default="", alias="model_config")
    taskflow: list[TaskWrapper] = Field(default_factory=list)

    @field_validator("taskflow", mode="before")
    @classmethod
    def _coerce_taskflow_list(cls, v: Any) -> list[Any]:
        """Treat a missing/``None`` ``taskflow`` value as an empty list."""
        if v is None:
            return []
        return v


class PersonalityDocument(BaseModel):
    """A personality YAML document."""

    model_config = ConfigDict(extra="allow")

    header: TaskflowHeader = Field(alias="seclab-taskflow-agent")
    personality: str = ""
    task: str = ""
    toolboxes: list[str] = Field(default_factory=list)


class ServerParams(BaseModel):
    """MCP server connection parameters inside a toolbox.

    Only ``kind`` is required; every other field is optional and unknown
    extra keys are accepted.
    """

    model_config = ConfigDict(extra="allow")

    kind: str
    command: str | None = None
    args: list[str] | None = None
    env: dict[str, str] | None = None
    url: str | None = None
    headers: dict[str, str] | None = None
    optional_headers: dict[str, str] | None = None
    timeout: float | None = None
    reconnecting: bool = False


class ToolboxDocument(BaseModel):
    """A toolbox YAML document defining an MCP server configuration."""

    model_config = ConfigDict(extra="allow")

    header: TaskflowHeader = Field(alias="seclab-taskflow-agent")
    server_params: ServerParams
    server_prompt: str = ""
    confirm: list[str] = Field(default_factory=list)
    client_session_timeout: float = 0


class ModelConfigDocument(BaseModel):
    """A model_config YAML document mapping logical model names to provider IDs."""

    model_config = ConfigDict(extra="allow")

    header: TaskflowHeader = Field(alias="seclab-taskflow-agent")
    models: dict[str, str] = Field(default_factory=dict)
    # Settings dictionaries keyed by name, one dict of settings per entry.
    model_settings: dict[str, dict[str, Any]] = Field(default_factory=dict)


class PromptDocument(BaseModel):
    """A reusable prompt YAML document."""

    model_config = ConfigDict(extra="allow")

    header: TaskflowHeader = Field(alias="seclab-taskflow-agent")
    prompt: str = ""


# ---------------------------------------------------------------------------
# Mapping from filetype string → Pydantic model
# ---------------------------------------------------------------------------

DOCUMENT_MODELS: dict[str, type[BaseModel]] = {
    "taskflow": TaskflowDocument,
    "personality": PersonalityDocument,
    "toolbox": ToolboxDocument,
    "model_config": ModelConfigDocument,
    "prompt": PromptDocument,
}
# Default cap on agent turns per task.
DEFAULT_MAX_TURNS = 50
# Initial rate-limit backoff in seconds; doubled on each RateLimitError.
RATE_LIMIT_BACKOFF = 5
# Ceiling for the rate-limit backoff in seconds.
MAX_RATE_LIMIT_BACKOFF = 120
# Number of retries allowed after an APITimeoutError.
MAX_API_RETRY = 5


async def deploy_task_agents(
    available_tools: AvailableTools,
    agents: dict[str, Any],
    prompt: str,
    *,
    async_task: bool = False,
    toolboxes_override: list[str] | None = None,
    blocked_tools: list[str] | None = None,
    headless: bool = False,
    exclude_from_context: bool = False,
    max_turns: int = DEFAULT_MAX_TURNS,
    model: str = DEFAULT_MODEL,
    model_par: dict[str, Any] | None = None,
    run_hooks: TaskRunHooks | None = None,
    agent_hooks: TaskAgentHooks | None = None,
) -> bool:
    """Deploy and run task agents with MCP servers.

    The first key of *agents* becomes the primary agent; every further key is
    created as a handoff agent. MCP servers are connected in a background
    task and are always asked to clean up before this coroutine returns.

    Args:
        available_tools: Tool registry.
        agents: Mapping of agent name → personality config.
        prompt: User prompt to execute.
        async_task: Whether this is an async (concurrent) task.
        toolboxes_override: Override personality toolboxes with these.
        blocked_tools: Tool names to block.
        headless: Skip confirmation prompts.
        exclude_from_context: Exclude tool results from LLM context.
        max_turns: Maximum agent turns.
        model: Model identifier.
        model_par: Additional model parameters.
        run_hooks: Custom run hooks.
        agent_hooks: Custom agent hooks.

    Returns:
        True if the task completed successfully. False if the MCP servers
        could not be connected or the agent run ended with one of the handled
        errors (max turns, agent exception, bad request, API timeout).
    """
    model_par = model_par or {}
    toolboxes_override = toolboxes_override or []
    blocked_tools = blocked_tools or []

    task_id = str(uuid.uuid4())
    await render_model_output(f"** 🤖💪 Deploying Task Flow Agent(s): {list(agents.keys())}\n")
    await render_model_output(f"** 🤖💪 Task ID: {task_id}\n")
    await render_model_output(f"** 🤖💪 Model  : {model}{', params: ' + str(model_par) if model_par else ''}\n")

    # Resolve toolboxes
    toolboxes: list[str] = []
    if toolboxes_override:
        toolboxes = toolboxes_override
    else:
        # Union of all personalities' toolboxes, first occurrence wins the order.
        for v in agents.values():
            for tb in v.get("toolboxes", []):
                if tb not in toolboxes:
                    toolboxes.append(tb)

    # Model settings
    # NOTE(review): any non-empty MODEL_PARALLEL_TOOL_CALLS value (including
    # "0" or "false") enables parallel tool calls — confirm this is intended.
    parallel_tool_calls = bool(os.getenv("MODEL_PARALLEL_TOOL_CALLS"))
    model_params: dict[str, Any] = {
        "temperature": os.getenv("MODEL_TEMP", default=0.0),
        "tool_choice": "auto" if toolboxes else None,
        "parallel_tool_calls": parallel_tool_calls if toolboxes else None,
    }
    # Task/model-config supplied parameters override the environment defaults.
    model_params.update(model_par)
    model_settings = ModelSettings(**model_params)

    # Build MCP servers and collect server prompts
    entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless)
    mcp_params = mcp_client_params(available_tools, toolboxes)
    server_prompts = [sp for _, (_, _, sp, _) in mcp_params.items()]

    # Connect MCP servers
    servers_connected = asyncio.Event()
    start_cleanup = asyncio.Event()
    mcp_sessions = asyncio.create_task(mcp_session_task(entries, servers_connected, start_cleanup))

    # Wait for whichever happens first: all servers connected, or the session
    # task ending early. The session task does not set the event when a
    # connection attempt fails, so waiting on the event alone could block
    # forever.
    connect_waiter = asyncio.ensure_future(servers_connected.wait())
    try:
        await asyncio.wait({connect_waiter, mcp_sessions}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        connect_waiter.cancel()

    if not servers_connected.is_set():
        # The session task is finished here; retrieve its exception (if any)
        # so it is logged rather than reported as never retrieved.
        failure = None if mcp_sessions.cancelled() else mcp_sessions.exception()
        logging.error("MCP session task ended before all servers were connected", exc_info=failure)
        await render_model_output(
            "** 🤖❗ MCP Server Connection Failed\n", async_task=async_task, task_id=task_id
        )
        if async_task:
            await flush_async_output(task_id)
        return False

    logging.debug("All mcp servers are connected!")

    try:
        important_guidelines = [
            "Do not prompt the user with questions.",
            "Run tasks until a final result is available.",
            "Ensure responses are based on the latest information from available tools.",
            "Run tools sequentially, wait until one tool has completed before calling the next.",
        ]

        # Create handoff agents
        handoffs = []
        agent_names = list(agents.keys())
        for handoff_agent in agent_names[1:]:
            handoffs.append(
                TaskAgent(
                    name=compress_name(handoff_agent),
                    instructions=prompt_with_handoff_instructions(
                        mcp_system_prompt(
                            agents[handoff_agent]["personality"],
                            agents[handoff_agent]["task"],
                            server_prompts=server_prompts,
                            important_guidelines=important_guidelines,
                        )
                    ),
                    handoffs=[],
                    exclude_from_context=exclude_from_context,
                    mcp_servers=[e.server for e in entries],
                    model=model,
                    model_settings=model_settings,
                    run_hooks=run_hooks,
                    agent_hooks=agent_hooks,
                ).agent
            )

        # Create primary agent
        primary_agent = agent_names[0]
        system_prompt = mcp_system_prompt(
            agents[primary_agent]["personality"],
            agents[primary_agent]["task"],
            server_prompts=server_prompts,
            important_guidelines=important_guidelines,
        )
        agent0 = TaskAgent(
            name=primary_agent,
            # Handoff instructions are only added when there is someone to hand off to.
            instructions=prompt_with_handoff_instructions(system_prompt) if handoffs else system_prompt,
            handoffs=handoffs,
            exclude_from_context=exclude_from_context,
            mcp_servers=[e.server for e in entries],
            model=model,
            model_settings=model_settings,
            run_hooks=run_hooks,
            agent_hooks=agent_hooks,
        )

        try:
            complete = False

            async def _run_streamed() -> None:
                """Stream one agent run, retrying on timeouts and rate limits."""
                max_retry = MAX_API_RETRY
                rate_limit_backoff = RATE_LIMIT_BACKOFF
                while rate_limit_backoff:
                    try:
                        result = agent0.run_streamed(prompt, max_turns=max_turns)
                        async for event in result.stream_events():
                            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
                                await render_model_output(event.data.delta, async_task=async_task, task_id=task_id)
                        await render_model_output("\n\n", async_task=async_task, task_id=task_id)
                        return
                    except APITimeoutError:
                        # Retry immediately until the retry budget is used up.
                        if not max_retry:
                            logging.exception("Max retries for APITimeoutError reached")
                            raise
                        max_retry -= 1
                    except RateLimitError:
                        # Double the backoff each time; once it has been
                        # clamped to the ceiling, the next hit gives up.
                        if rate_limit_backoff == MAX_RATE_LIMIT_BACKOFF:
                            raise APITimeoutError("Max rate limit backoff reached")
                        if rate_limit_backoff > MAX_RATE_LIMIT_BACKOFF:
                            rate_limit_backoff = MAX_RATE_LIMIT_BACKOFF
                        else:
                            rate_limit_backoff += rate_limit_backoff
                        logging.exception(f"Hit rate limit ... holding for {rate_limit_backoff}")
                        await asyncio.sleep(rate_limit_backoff)

            await _run_streamed()
            complete = True

        except MaxTurnsExceeded as e:
            await render_model_output(f"** 🤖❗ Max Turns Reached: {e}\n", async_task=async_task, task_id=task_id)
            logging.exception(f"Exceeded max_turns: {max_turns}")
        except AgentsException as e:
            await render_model_output(f"** 🤖❗ Agent Exception: {e}\n", async_task=async_task, task_id=task_id)
            logging.exception("Agent Exception")
        except BadRequestError as e:
            await render_model_output(f"** 🤖❗ Request Error: {e}\n", async_task=async_task, task_id=task_id)
            logging.exception("Bad Request")
        except APITimeoutError as e:
            await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", async_task=async_task, task_id=task_id)
            logging.exception("API Timeout")

        if async_task:
            await flush_async_output(task_id)

        return complete

    finally:
        # Signal the session task to tear the servers down, then wait for it
        # in bounded slices: at most one attempt per server entry.
        start_cleanup.set()
        cleanup_attempts_left = len(entries)
        while cleanup_attempts_left and entries:
            try:
                cleanup_attempts_left -= 1
                await asyncio.wait_for(mcp_sessions, timeout=MCP_CLEANUP_TIMEOUT)
            except asyncio.TimeoutError:
                continue
            except Exception:
                logging.exception("Exception in mcp server cleanup task")
+ cli_globals: Global variables from CLI. + prompt: User prompt text. + """ + last_mcp_tool_results: list[str] = [] + + async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None: + last_mcp_tool_results.append(result) + + async def on_tool_start_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: + await render_model_output(f"\n** 🤖🛠️ Tool Call: {tool.name}\n") + + async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]) -> None: + await render_model_output(f"\n** 🤖🤝 Agent Handoff: {source.name} -> {agent.name}\n") + + if p: + personality = available_tools.get_personality(p) + await deploy_task_agents( + available_tools, + {p: personality}, + prompt or "", + run_hooks=TaskRunHooks(on_tool_end=on_tool_end_hook, on_tool_start=on_tool_start_hook), + ) + + if t: + taskflow = available_tools.get_taskflow(t) + await render_model_output(f"** 🤖💪 Running Task Flow: {t}\n") + + # Resolve global variables (file defaults + CLI overrides) + global_variables = taskflow.get("globals", {}) + if cli_globals: + global_variables.update(cli_globals) + + # Resolve model config + model_config = taskflow.get("model_config", {}) + model_keys: list[str] = [] + models_params: dict[str, dict[str, Any]] = {} + if model_config: + m_config = available_tools.get_model_config(model_config) + model_dict = m_config.get("models", {}) + if model_dict and not isinstance(model_dict, dict): + raise ValueError(f"Models section of the model_config file {model_config} must be a dictionary") + model_keys = list(model_dict.keys()) + models_params = m_config.get("model_settings", {}) + if models_params and not isinstance(models_params, dict): + raise ValueError(f"Settings section of model_config file {model_config} must be a dictionary") + if not set(models_params.keys()).difference(model_keys).issubset(set()): + raise ValueError( + f"Settings section of 
model_config file {model_config} contains models not in the model section" + ) + for k, v in models_params.items(): + if not isinstance(v, dict): + raise ValueError(f"Settings for model {k} in model_config file {model_config} is not a dictionary") + + for task_entry in taskflow["taskflow"]: + task_body = task_entry["task"] + + # Reusable taskflow support + uses = task_body.get("uses", "") + if uses: + reusable_taskflow = available_tools.get_taskflow(uses) + if reusable_taskflow is None: + raise ValueError(f"No such reusable taskflow: {uses}") + if len(reusable_taskflow["taskflow"]) > 1: + raise ValueError("Reusable taskflows can only contain 1 task") + for k, v in reusable_taskflow["taskflow"][0]["task"].items(): + if k not in task_body: + task_body[k] = v + + # Resolve model + model = task_body.get("model", DEFAULT_MODEL) + model_settings: dict[str, Any] = {} + if model in model_keys: + if model in models_params: + model_settings = models_params[model].copy() + model = model_dict[model] + task_model_settings = task_body.get("model_settings", {}) + if not isinstance(task_model_settings, dict): + name = task_body.get("name", "") + raise ValueError(f"model_settings in task {name} needs to be a dictionary") + model_settings.update(task_model_settings) + + # Parse task fields + agents_list = task_body.get("agents", []) + headless = task_body.get("headless", False) + blocked_tools = task_body.get("blocked_tools", []) + run = task_body.get("run", "") + inputs = task_body.get("inputs", {}) + task_prompt = task_body.get("user_prompt", "") + if run and task_prompt: + raise ValueError("shell task and prompt task are mutually exclusive!") + must_complete = task_body.get("must_complete", False) + max_turns = task_body.get("max_steps", DEFAULT_MAX_TURNS) + toolboxes_override = task_body.get("toolboxes", []) + env = task_body.get("env", {}) + repeat_prompt = task_body.get("repeat_prompt", False) + exclude_from_context = task_body.get("exclude_from_context", False) + async_task = 
task_body.get("async", False) + max_concurrent_tasks = task_body.get("async_limit", 5) + + # Render prompt template (skip if repeat_prompt — result not yet available) + if task_prompt and not repeat_prompt: + try: + task_prompt = render_template( + template_str=task_prompt, + available_tools=available_tools, + globals_dict=global_variables, + inputs_dict=inputs, + ) + except jinja2.TemplateError as e: + logging.error(f"Template rendering error: {e}") + raise ValueError(f"Failed to render prompt template: {e}") + + with TmpEnv(env): + prompts_to_run: list[str] = [] + if repeat_prompt: + if "result" not in task_prompt.lower(): + logging.warning("repeat_prompt enabled but no {{ result }} in prompt") + try: + last_result = json.loads(last_mcp_tool_results.pop()) + text = last_result.get("text", "") + try: + iterable_result = json.loads(text) + except json.JSONDecodeError as exc: + logging.critical(f"Could not parse result text: {text}") + raise ValueError("Result text is not valid JSON") from exc + try: + iter(iterable_result) + except TypeError: + logging.critical("Last MCP tool result is not iterable") + raise + except IndexError: + logging.critical("No last MCP tool result available") + raise + + if not iterable_result: + await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") + else: + logging.debug(f"Rendering templated prompts for results: {iterable_result}") + for value in iterable_result: + try: + rendered_prompt = render_template( + template_str=task_prompt, + available_tools=available_tools, + globals_dict=global_variables, + inputs_dict=inputs, + result_value=value, + ) + prompts_to_run.append(rendered_prompt) + except jinja2.TemplateError as e: + logging.error(f"Error rendering template for result {value}: {e}") + raise ValueError(f"Template rendering failed: {e}") + else: + prompts_to_run.append(task_prompt) + + async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) -> bool: + if run: + await render_model_output("** 
🤖🐚 Executing Shell Task\n") + try: + result = shell_tool_call(run).content[0].model_dump_json() + last_mcp_tool_results.append(result) + return True + except RuntimeError as e: + await render_model_output(f"** 🤖❗ Shell Task Exception: {e}\n") + logging.exception("Shell task error") + return False + + tasks: list[Any] = [] + task_results: list[Any] = [] + semaphore = asyncio.Semaphore(max_concurrent_tasks) + for p_prompt in prompts_to_run: + resolved_agents: dict[str, Any] = {} + current_agents = list(agents_list) + if not current_agents: + from .cli import parse_prompt_args + p_val, _, _, _, p_prompt, _ = parse_prompt_args(available_tools, p_prompt) + if p_val: + current_agents.append(p_val) + for agent_name in current_agents: + personality = available_tools.get_personality(agent_name) + if personality is None: + raise ValueError(f"No such personality: {agent_name}") + resolved_agents[agent_name] = personality + + async def _deploy(ra: dict, pp: str) -> bool: + async with semaphore: + return await deploy_task_agents( + available_tools, + ra, + pp, + async_task=async_task, + toolboxes_override=toolboxes_override, + blocked_tools=blocked_tools, + headless=headless, + exclude_from_context=exclude_from_context, + max_turns=max_turns, + run_hooks=TaskRunHooks( + on_tool_end=on_tool_end_hook, on_tool_start=on_tool_start_hook + ), + model=model, + model_par=model_settings, + agent_hooks=TaskAgentHooks(on_handoff=on_handoff_hook), + ) + + task_coroutine = _deploy(resolved_agents, p_prompt) + + if not async_task: + result = await task_coroutine + task_results.append(result) + else: + tasks.append(task_coroutine) + + if async_task: + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + complete = True + for result in task_results: + if isinstance(result, Exception): + logging.error(f"Caught exception in Gather: {result}") + result = False + complete = result and complete + return complete + + task_complete = await run_prompts(async_task=async_task, 
max_concurrent_tasks=max_concurrent_tasks) + + if must_complete and not task_complete: + logging.critical("Required task not completed ... aborting!") + await render_model_output("🤖💥 *Required task not completed ...\n") + break diff --git a/tests/test_cli_parser.py b/tests/test_cli_parser.py index 0dd1bc31..92d40589 100644 --- a/tests/test_cli_parser.py +++ b/tests/test_cli_parser.py @@ -19,12 +19,12 @@ def test_parse_single_global(self): available_tools = AvailableTools() - p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(available_tools, "-t example -g fruit=apples") + p, t, _l, cli_globals, user_prompt, _ = parse_prompt_args(available_tools, "-t example -g fruit=apples") assert t == "example" assert cli_globals == {"fruit": "apples"} assert p is None - assert l is False + assert _l is False def test_parse_multiple_globals(self): """Test parsing multiple global variables from command line.""" @@ -32,14 +32,14 @@ def test_parse_multiple_globals(self): available_tools = AvailableTools() - p, t, l, cli_globals, user_prompt, _ = parse_prompt_args( + p, t, _l, cli_globals, user_prompt, _ = parse_prompt_args( available_tools, "-t example -g fruit=apples -g color=red" ) assert t == "example" assert cli_globals == {"fruit": "apples", "color": "red"} assert p is None - assert l is False + assert _l is False def test_parse_global_with_spaces(self): """Test parsing global variables with spaces in values.""" @@ -47,7 +47,7 @@ def test_parse_global_with_spaces(self): available_tools = AvailableTools() - p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(available_tools, "-t example -g message=hello world") + p, t, _l, cli_globals, user_prompt, _ = parse_prompt_args(available_tools, "-t example -g message=hello world") assert t == "example" # "world" becomes part of the prompt, not the value @@ -60,7 +60,7 @@ def test_parse_global_with_equals_in_value(self): available_tools = AvailableTools() - p, t, l, cli_globals, user_prompt, _ = 
parse_prompt_args(available_tools, "-t example -g equation=x=5") + p, t, _l, cli_globals, user_prompt, _ = parse_prompt_args(available_tools, "-t example -g equation=x=5") assert t == "example" assert cli_globals == {"equation": "x=5"} diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..6db4cc46 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Tests for Pydantic grammar models.""" + +import pytest +from pydantic import ValidationError + +from seclab_taskflow_agent.models import ( + ModelConfigDocument, + PersonalityDocument, + PromptDocument, + ServerParams, + TaskDefinition, + TaskflowDocument, + TaskflowHeader, + ToolboxDocument, +) + + +class TestTaskflowHeader: + """Test the grammar header validation.""" + + def test_string_version(self): + h = TaskflowHeader(version="1.0", filetype="taskflow") + assert h.version == "1.0" + + def test_integer_version_normalised(self): + h = TaskflowHeader(version=1, filetype="taskflow") + assert h.version == "1.0" + + def test_float_version_normalised(self): + h = TaskflowHeader(version=1.0, filetype="taskflow") + assert h.version == "1.0" + + def test_unsupported_version_rejected(self): + with pytest.raises(ValidationError, match="Unsupported version"): + TaskflowHeader(version="2.0", filetype="taskflow") + + def test_filetype_preserved(self): + h = TaskflowHeader(version="1.0", filetype="personality") + assert h.filetype == "personality" + + +class TestTaskDefinition: + """Test single task validation.""" + + def test_defaults(self): + t = TaskDefinition() + assert t.agents == [] + assert t.user_prompt == "" + assert t.must_complete is False + assert t.async_task is False + assert t.async_limit == 5 + assert t.max_steps == 0 + + def test_all_fields(self): + t = TaskDefinition( + name="test-task", + agents=["personality.a"], + user_prompt="Hello {{ globals.x }}", + model="gpt-4o", + 
must_complete=True, + headless=True, + repeat_prompt=True, + toolboxes=["toolbox.a"], + env={"KEY": "val"}, + max_steps=20, + **{"async": True}, + async_limit=3, + ) + assert t.name == "test-task" + assert t.async_task is True + assert t.async_limit == 3 + assert t.max_steps == 20 + + def test_run_and_prompt_mutually_exclusive(self): + with pytest.raises(ValidationError, match="mutually exclusive"): + TaskDefinition(run="echo hi", user_prompt="Hello") + + def test_extra_fields_allowed(self): + t = TaskDefinition(future_field="value") + assert t.model_extra["future_field"] == "value" + + +class TestTaskflowDocument: + """Test complete taskflow document parsing.""" + + def test_minimal(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "taskflow"}, + "taskflow": [ + {"task": {"agents": ["p.a"], "user_prompt": "Hello"}}, + ], + } + doc = TaskflowDocument(**data) + assert doc.header.filetype == "taskflow" + assert len(doc.taskflow) == 1 + assert doc.taskflow[0].task.user_prompt == "Hello" + + def test_with_globals_and_model_config(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "taskflow"}, + "globals": {"fruit": "bananas"}, + "model_config": "examples.model_configs.model_config", + "taskflow": [], + } + doc = TaskflowDocument(**data) + assert doc.globals == {"fruit": "bananas"} + assert doc.model_config_ref == "examples.model_configs.model_config" + + def test_null_taskflow(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "taskflow"}, + "taskflow": None, + } + doc = TaskflowDocument(**data) + assert doc.taskflow == [] + + def test_integer_version(self): + data = { + "seclab-taskflow-agent": {"version": 1, "filetype": "taskflow"}, + "taskflow": [], + } + doc = TaskflowDocument(**data) + assert doc.header.version == "1.0" + + +class TestPersonalityDocument: + """Test personality document parsing.""" + + def test_full_personality(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", 
"filetype": "personality"}, + "personality": "You are a helpful assistant.\n", + "task": "Answer any question.\n", + "toolboxes": ["seclab_taskflow_agent.toolboxes.memcache"], + } + doc = PersonalityDocument(**data) + assert doc.personality == "You are a helpful assistant.\n" + assert len(doc.toolboxes) == 1 + + def test_minimal_personality(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "personality"}, + } + doc = PersonalityDocument(**data) + assert doc.personality == "" + assert doc.toolboxes == [] + + +class TestToolboxDocument: + """Test toolbox document parsing.""" + + def test_stdio_toolbox(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "toolbox"}, + "server_params": { + "kind": "stdio", + "command": "python", + "args": ["-m", "module.server"], + "env": {"KEY": "value"}, + }, + "confirm": ["dangerous_tool"], + } + doc = ToolboxDocument(**data) + assert doc.server_params.kind == "stdio" + assert doc.server_params.command == "python" + assert doc.confirm == ["dangerous_tool"] + + def test_streamable_toolbox(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "toolbox"}, + "server_params": { + "kind": "streamable", + "url": "http://localhost:9999/mcp", + "command": "python", + "args": ["-m", "module.server"], + }, + "server_prompt": "Use this server for queries.", + } + doc = ToolboxDocument(**data) + assert doc.server_params.kind == "streamable" + assert doc.server_params.url == "http://localhost:9999/mcp" + assert doc.server_prompt == "Use this server for queries." 
+ + +class TestModelConfigDocument: + """Test model config document parsing.""" + + def test_full_config(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "model_config"}, + "models": {"gpt_default": "gpt-4.1", "gpt_latest": "gpt-5"}, + "model_settings": { + "gpt_default": {"temperature": 0.7}, + }, + } + doc = ModelConfigDocument(**data) + assert doc.models["gpt_default"] == "gpt-4.1" + assert doc.model_settings["gpt_default"]["temperature"] == 0.7 + + +class TestPromptDocument: + """Test prompt document parsing.""" + + def test_prompt(self): + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "prompt"}, + "prompt": "Tell me about bananas.\n", + } + doc = PromptDocument(**data) + assert doc.prompt == "Tell me about bananas.\n" + + +class TestServerParams: + """Test server params validation.""" + + def test_extra_fields_allowed(self): + sp = ServerParams(kind="stdio", custom_field="hello") + assert sp.model_extra["custom_field"] == "hello" + + def test_minimal(self): + sp = ServerParams(kind="sse", url="http://localhost:8080") + assert sp.kind == "sse" + assert sp.command is None + + +class TestRealYAMLFiles: + """Test parsing actual project YAML files through Pydantic models.""" + + def test_parse_example_taskflow(self): + import yaml + + with open("examples/taskflows/example.yaml") as f: + data = yaml.safe_load(f) + doc = TaskflowDocument(**data) + assert len(doc.taskflow) == 4 + assert doc.model_config_ref == "examples.model_configs.model_config" + + def test_parse_echo_taskflow(self): + import yaml + + with open("examples/taskflows/echo.yaml") as f: + data = yaml.safe_load(f) + doc = TaskflowDocument(**data) + assert len(doc.taskflow) == 2 + assert doc.taskflow[0].task.must_complete is True + assert doc.taskflow[0].task.max_steps == 5 + + def test_parse_example_globals(self): + import yaml + + with open("examples/taskflows/example_globals.yaml") as f: + data = yaml.safe_load(f) + doc = TaskflowDocument(**data) + assert 
"fruit" in doc.globals + + def test_parse_personality(self): + import yaml + + with open("src/seclab_taskflow_agent/personalities/assistant.yaml") as f: + data = yaml.safe_load(f) + doc = PersonalityDocument(**data) + assert doc.personality != "" + + def test_parse_toolbox_memcache(self): + import yaml + + with open("src/seclab_taskflow_agent/toolboxes/memcache.yaml") as f: + data = yaml.safe_load(f) + doc = ToolboxDocument(**data) + assert doc.server_params.kind == "stdio" + assert "memcache_clear_cache" in doc.confirm + + def test_parse_toolbox_codeql(self): + import yaml + + with open("src/seclab_taskflow_agent/toolboxes/codeql.yaml") as f: + data = yaml.safe_load(f) + doc = ToolboxDocument(**data) + assert doc.server_params.kind == "streamable" + assert doc.server_prompt != "" + + def test_parse_model_config(self): + import yaml + + with open("examples/model_configs/model_config.yaml") as f: + data = yaml.safe_load(f) + doc = ModelConfigDocument(**data) + assert "gpt_default" in doc.models + + def test_parse_prompt(self): + import yaml + + with open("examples/prompts/example_prompt.yaml") as f: + data = yaml.safe_load(f) + doc = PromptDocument(**data) + assert "bananas" in doc.prompt.lower() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 915939ad6dfc32f9935ece8e58cb01f7a85ec116 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 12:39:13 -0400 Subject: [PATCH 02/20] refactor: wire Pydantic models into parser, runner, and tests Replace all raw dict access with typed model attributes throughout the engine. AvailableTools now returns Pydantic model instances, runner.py uses typed fields, and tests validate via model attributes. 
--- src/seclab_taskflow_agent/available_tools.py | 188 ++++++++++++------- src/seclab_taskflow_agent/mcp_utils.py | 123 ++++++------ src/seclab_taskflow_agent/runner.py | 130 ++++++------- src/seclab_taskflow_agent/template_utils.py | 2 +- tests/test_cli_parser.py | 4 +- tests/test_template_utils.py | 2 +- tests/test_yaml_parser.py | 35 ++-- 7 files changed, 264 insertions(+), 220 deletions(-) diff --git a/src/seclab_taskflow_agent/available_tools.py b/src/seclab_taskflow_agent/available_tools.py index fa2111d2..3966c756 100644 --- a/src/seclab_taskflow_agent/available_tools.py +++ b/src/seclab_taskflow_agent/available_tools.py @@ -1,10 +1,29 @@ # SPDX-FileCopyrightText: GitHub, Inc. # SPDX-License-Identifier: MIT +"""YAML resource loader for taskflow grammar files. + +Loads and caches personality, taskflow, toolbox, model_config, and prompt +YAML files, validating them against Pydantic grammar models at parse time. +""" + +from __future__ import annotations + import importlib.resources from enum import Enum +from typing import Union import yaml +from pydantic import ValidationError + +from .models import ( + DOCUMENT_MODELS, + ModelConfigDocument, + PersonalityDocument, + PromptDocument, + TaskflowDocument, + ToolboxDocument, +) class BadToolNameError(Exception): @@ -27,85 +46,120 @@ class AvailableToolType(Enum): ModelConfig = "model_config" +# Union of all document model types returned by AvailableTools +DocumentModel = Union[ + TaskflowDocument, PersonalityDocument, ToolboxDocument, + ModelConfigDocument, PromptDocument, +] + + class AvailableTools: - """ - This class is used for storing dictionaries of all the available - personalities, taskflows, and prompts. 
- """ + """Loads, validates, and caches YAML grammar files as Pydantic models.""" + + def __init__(self) -> None: + self._cache: dict[AvailableToolType, dict[str, DocumentModel]] = {} + + def get_personality(self, name: str) -> PersonalityDocument: + """Load a personality YAML and return a validated PersonalityDocument.""" + return self._load(AvailableToolType.Personality, name) - def __init__(self): - self.__yamlcache = {} + def get_taskflow(self, name: str) -> TaskflowDocument: + """Load a taskflow YAML and return a validated TaskflowDocument.""" + return self._load(AvailableToolType.Taskflow, name) - def get_personality(self, name: str): - return self.get_tool(AvailableToolType.Personality, name) + def get_prompt(self, name: str) -> PromptDocument: + """Load a prompt YAML and return a validated PromptDocument.""" + return self._load(AvailableToolType.Prompt, name) - def get_taskflow(self, name: str): - return self.get_tool(AvailableToolType.Taskflow, name) + def get_toolbox(self, name: str) -> ToolboxDocument: + """Load a toolbox YAML and return a validated ToolboxDocument.""" + return self._load(AvailableToolType.Toolbox, name) - def get_prompt(self, name: str): - return self.get_tool(AvailableToolType.Prompt, name) + def get_model_config(self, name: str) -> ModelConfigDocument: + """Load a model_config YAML and return a validated ModelConfigDocument.""" + return self._load(AvailableToolType.ModelConfig, name) - def get_toolbox(self, name: str): - return self.get_tool(AvailableToolType.Toolbox, name) + # Keep legacy alias for code that uses the generic accessor + def get_tool(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: + """Generic loader — prefer the typed ``get_*()`` methods.""" + return self._load(tooltype, toolname) - def get_model_config(self, name: str): - return self.get_tool(AvailableToolType.ModelConfig, name) + def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: + """Load, validate, and cache a YAML 
grammar file. - def get_tool(self, tooltype: AvailableToolType, toolname: str): - """for example: available_tools.get_tool("personality", "personalities/fruit_expert") - This method first checks whether the tool has already been loaded. If not, it - finds the yaml file and parses it. It also checks that the filetype in the header - matches the expected tooltype. + Args: + tooltype: Expected file type (personality, taskflow, etc.). + toolname: Dotted module path, e.g. ``"examples.taskflows.echo"``. + + Returns: + A validated Pydantic document model instance. + + Raises: + BadToolNameError: If the tool cannot be found or loaded. + VersionException: If the grammar version is unsupported. + FileTypeException: If the filetype doesn't match expectations. """ - try: - return self.__yamlcache[tooltype][toolname] - except KeyError: - pass - # Split the string to get the package and filename. + # Check cache first + if tooltype in self._cache and toolname in self._cache[tooltype]: + return self._cache[tooltype][toolname] + + # Resolve package and filename from dotted path components = toolname.rsplit(".", 1) if len(components) != 2: raise BadToolNameError( - f'Not a valid toolname: "{toolname}". It should be something like: "packagename.filename"' + f'Not a valid toolname: "{toolname}". 
' + f'Expected format: "packagename.filename"' ) - package = components[0] - filename = components[1] + package, filename = components + try: - d = importlib.resources.files(package) - if not d.is_dir(): - raise BadToolNameError(f"Cannot load {toolname} because {d} is not a valid directory.") - f = d.joinpath(filename + ".yaml") - with open(f) as s: - y = yaml.safe_load(s) - header = y['seclab-taskflow-agent'] - version = header['version'] - - # Normalize version to string format for backwards compatibility - if isinstance(version, int): - # Convert integer 1 to "1.0" for semver compatibility - version_str = f"{version}.0" - elif isinstance(version, float): - # Convert float 1.2 to "1.2" - version_str = str(version) - else: - # Already a string, use as-is - version_str = str(version) - - # Validate version is 1.0 - if version_str != "1.0": - raise VersionException( - f"Unsupported version: {version}. Only version 1.0 is supported." - ) - filetype = header['filetype'] - if filetype != tooltype.value: - raise FileTypeException(f"Error in {f}: expected filetype to be {tooltype}, but it's {filetype}.") - if tooltype not in self.__yamlcache: - self.__yamlcache[tooltype] = {} - self.__yamlcache[tooltype][toolname] = y - return y - except ModuleNotFoundError as e: - raise BadToolNameError(f"Cannot load {toolname}: {e}") + pkg_dir = importlib.resources.files(package) + if not pkg_dir.is_dir(): + raise BadToolNameError( + f"Cannot load {toolname} because {pkg_dir} is not a valid directory." + ) + filepath = pkg_dir.joinpath(filename + ".yaml") + with open(filepath) as fh: + raw = yaml.safe_load(fh) + + # Validate header before full parse + header = raw.get("seclab-taskflow-agent", {}) + filetype = header.get("filetype", "") + if filetype != tooltype.value: + raise FileTypeException( + f"Error in {filepath}: expected filetype {tooltype.value!r}, " + f"got {filetype!r}." 
+ ) + + # Parse into the appropriate Pydantic model + model_cls = DOCUMENT_MODELS.get(filetype) + if model_cls is None: + raise BadToolNameError( + f"Unknown filetype {filetype!r} in {toolname}" + ) + + try: + doc = model_cls(**raw) + except ValidationError as exc: + # Surface version errors as VersionException for compat + for err in exc.errors(): + if "Unsupported version" in str(err.get("msg", "")): + raise VersionException(str(err["msg"])) from exc + raise BadToolNameError( + f"Validation error loading {toolname}: {exc}" + ) from exc + + # Cache and return + if tooltype not in self._cache: + self._cache[tooltype] = {} + self._cache[tooltype][toolname] = doc + return doc + + except ModuleNotFoundError as exc: + raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc except FileNotFoundError: - # deal with editor temp files etc. that might have disappeared - raise BadToolNameError(f"Cannot load {toolname} because {f} is not a valid file.") - except ValueError as e: - raise BadToolNameError(f"Cannot load {toolname}: {e}") + raise BadToolNameError( + f"Cannot load {toolname} because {filepath} is not a valid file." 
+ ) + except ValueError as exc: + raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc diff --git a/src/seclab_taskflow_agent/mcp_utils.py b/src/seclab_taskflow_agent/mcp_utils.py index 1fa7f6f5..f2f9f11b 100644 --- a/src/seclab_taskflow_agent/mcp_utils.py +++ b/src/seclab_taskflow_agent/mcp_utils.py @@ -17,7 +17,7 @@ from agents.mcp import MCPServerStdio from mcp.types import CallToolResult, TextContent -from .available_tools import AvailableTools, AvailableToolType +from .available_tools import AvailableTools from .env_utils import swap_env DEFAULT_MCP_CLIENT_SESSION_TIMEOUT = 120 @@ -284,20 +284,22 @@ async def call_tool(self, *args, **kwargs): def mcp_client_params(available_tools: AvailableTools, requested_toolboxes: list): - """Return all the data needed to initialize an mcp server client""" + """Return all the data needed to initialize an mcp server client.""" client_params = {} for tb in requested_toolboxes: - toolbox = available_tools.get_tool(AvailableToolType.Toolbox, tb) - kind = toolbox["server_params"].get("kind") - reconnecting = toolbox["server_params"].get("reconnecting", False) + toolbox = available_tools.get_toolbox(tb) + sp = toolbox.server_params + kind = sp.kind + reconnecting = sp.reconnecting server_params = {"kind": kind, "reconnecting": reconnecting} + match kind: case "stdio": - env = toolbox["server_params"].get("env") - args = toolbox["server_params"].get("args") + env = dict(sp.env) if sp.env else None + args = list(sp.args) if sp.args else None logging.debug(f"Initializing toolbox: {tb}\nargs:\n{args}\nenv:\n{env}\n") - if env and isinstance(env, dict): - for k, v in dict(env).items(): + if env: + for k, v in list(env.items()): try: env[k] = swap_env(v) except LookupError as e: @@ -305,85 +307,63 @@ def mcp_client_params(available_tools: AvailableTools, requested_toolboxes: list logging.info("Assuming toolbox has default configuration available") del env[k] logging.debug(f"Tool call environment: {env}") - if args and 
isinstance(args, list): + if args: for i, v in enumerate(args): args[i] = swap_env(v) logging.debug(f"Tool call args: {args}") - server_params["command"] = toolbox["server_params"].get("command") + server_params["command"] = sp.command server_params["args"] = args server_params["env"] = env - # XXX: SSE is deprecated in the MCP spec, but keep it around for now + case "sse": - headers = toolbox["server_params"].get("headers") - # support {{ env SOMETHING }} for header values as well for e.g. tokens - if headers and isinstance(headers, dict): + headers = dict(sp.headers) if sp.headers else None + if headers: for k, v in headers.items(): headers[k] = swap_env(v) - optional_headers = toolbox["server_params"].get("optional_headers") - # support {{ env SOMETHING }} for header values as well for e.g. tokens - if optional_headers and isinstance(optional_headers, dict): - for k, v in dict(optional_headers).items(): + optional_headers = dict(sp.optional_headers) if sp.optional_headers else None + if optional_headers: + for k, v in list(optional_headers.items()): try: optional_headers[k] = swap_env(v) except LookupError: del optional_headers[k] - if isinstance(headers, dict): - if isinstance(optional_headers, dict): - headers.update(optional_headers) - elif isinstance(optional_headers, dict): - headers = optional_headers - # if None will default to float(5) in client code - timeout = toolbox["server_params"].get("timeout") - server_params["url"] = toolbox["server_params"].get("url") + headers = _merge_headers(headers, optional_headers) + server_params["url"] = sp.url server_params["headers"] = headers - server_params["timeout"] = timeout - # for more involved local MCP servers, jsonrpc over stdio seems less than reliable - # as an alternative you can configure local toolboxes to use the streamable transport - # but still be started/stopped on demand similar to stdio mcp servers - # all it requires is a streamable config that also has cmd/args/env set + 
server_params["timeout"] = sp.timeout + case "streamable": - headers = toolbox["server_params"].get("headers") - # support {{ env SOMETHING }} for header values as well for e.g. tokens - if headers and isinstance(headers, dict): + headers = dict(sp.headers) if sp.headers else None + if headers: for k, v in headers.items(): headers[k] = swap_env(v) - optional_headers = toolbox["server_params"].get("optional_headers") - # support {{ env SOMETHING }} for header values as well for e.g. tokens - if optional_headers and isinstance(optional_headers, dict): - for k, v in dict(optional_headers).items(): + optional_headers = dict(sp.optional_headers) if sp.optional_headers else None + if optional_headers: + for k, v in list(optional_headers.items()): try: optional_headers[k] = swap_env(v) except LookupError: del optional_headers[k] - if isinstance(headers, dict): - if isinstance(optional_headers, dict): - headers.update(optional_headers) - elif isinstance(optional_headers, dict): - headers = optional_headers - # if None will default to float(5) in client code - timeout = toolbox["server_params"].get("timeout") - server_params["url"] = toolbox["server_params"].get("url") + headers = _merge_headers(headers, optional_headers) + server_params["url"] = sp.url server_params["headers"] = headers - server_params["timeout"] = timeout - # if command/args/env is set, we also need to start this MCP server ourselves - # this way we can use the streamable transport for MCP servers that get fussy - # over stdio jsonrpc polling - env = toolbox["server_params"].get("env") - args = toolbox["server_params"].get("args") - cmd = toolbox["server_params"].get("command") - if cmd is not None: + server_params["timeout"] = sp.timeout + + if sp.command is not None: + env = dict(sp.env) if sp.env else None + args = list(sp.args) if sp.args else None logging.debug(f"Initializing streamable toolbox: {tb}\nargs:\n{args}\nenv:\n{env}\n") - exe = shutil.which(cmd) + exe = shutil.which(sp.command) if exe is 
None: - raise FileNotFoundError(f"Could not resolve path to {cmd}") + raise FileNotFoundError(f"Could not resolve path to {sp.command}") start_cmd = [exe] - if args is not None and isinstance(args, list): + if args: for i, v in enumerate(args): args[i] = swap_env(v) start_cmd += args server_params["command"] = start_cmd - if env is not None and isinstance(env, dict): - for k, v in dict(env).items(): + if env: + for k, v in list(env.items()): try: env[k] = swap_env(v) except LookupError as e: @@ -391,15 +371,30 @@ def mcp_client_params(available_tools: AvailableTools, requested_toolboxes: list logging.info("Assuming toolbox has default configuration available") del env[k] server_params["env"] = env + case _: raise ValueError(f"Unsupported MCP transport {kind}") - confirms = toolbox.get("confirm", []) - server_prompt = toolbox.get("server_prompt", "") - client_session_timeout = float(toolbox.get("client_session_timeout", 0)) - client_params[tb] = (server_params, confirms, server_prompt, client_session_timeout) + + client_params[tb] = ( + server_params, + list(toolbox.confirm), + toolbox.server_prompt, + toolbox.client_session_timeout, + ) return client_params +def _merge_headers( + headers: dict[str, str] | None, + optional_headers: dict[str, str] | None, +) -> dict[str, str] | None: + """Merge required and optional headers.""" + if headers and optional_headers: + headers.update(optional_headers) + return headers + return headers or optional_headers + + def mcp_system_prompt( system_prompt: str, task: str, diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index c4c5f66a..330f49ef 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -29,6 +29,7 @@ from .available_tools import AvailableTools from .env_utils import TmpEnv from .mcp_lifecycle import MCP_CLEANUP_TIMEOUT, build_mcp_servers, mcp_session_task +from .models import PersonalityDocument, TaskDefinition from .mcp_utils import 
compress_name, mcp_client_params, mcp_system_prompt from .render_utils import flush_async_output, render_model_output from .shell_utils import shell_tool_call @@ -42,7 +43,7 @@ async def deploy_task_agents( available_tools: AvailableTools, - agents: dict[str, Any], + agents: dict[str, PersonalityDocument], prompt: str, *, async_task: bool = False, @@ -60,18 +61,8 @@ async def deploy_task_agents( Args: available_tools: Tool registry. - agents: Mapping of agent name → personality config. + agents: Mapping of agent name → PersonalityDocument. prompt: User prompt to execute. - async_task: Whether this is an async (concurrent) task. - toolboxes_override: Override personality toolboxes with these. - blocked_tools: Tool names to block. - headless: Skip confirmation prompts. - exclude_from_context: Exclude tool results from LLM context. - max_turns: Maximum agent turns. - model: Model identifier. - model_par: Additional model parameters. - run_hooks: Custom run hooks. - agent_hooks: Custom agent hooks. Returns: True if the task completed successfully. 
@@ -85,13 +76,13 @@ async def deploy_task_agents( await render_model_output(f"** 🤖💪 Task ID: {task_id}\n") await render_model_output(f"** 🤖💪 Model : {model}{', params: ' + str(model_par) if model_par else ''}\n") - # Resolve toolboxes + # Resolve toolboxes from personality definitions or override toolboxes: list[str] = [] if toolboxes_override: toolboxes = toolboxes_override else: - for v in agents.values(): - for tb in v.get("toolboxes", []): + for personality in agents.values(): + for tb in personality.toolboxes: if tb not in toolboxes: toolboxes.append(tb) @@ -126,17 +117,18 @@ async def deploy_task_agents( "Run tools sequentially, wait until one tool has completed before calling the next.", ] - # Create handoff agents + # Create handoff agents from additional personalities handoffs = [] agent_names = list(agents.keys()) - for handoff_agent in agent_names[1:]: + for handoff_name in agent_names[1:]: + personality = agents[handoff_name] handoffs.append( TaskAgent( - name=compress_name(handoff_agent), + name=compress_name(handoff_name), instructions=prompt_with_handoff_instructions( mcp_system_prompt( - agents[handoff_agent]["personality"], - agents[handoff_agent]["task"], + personality.personality, + personality.task, server_prompts=server_prompts, important_guidelines=important_guidelines, ) @@ -152,15 +144,16 @@ async def deploy_task_agents( ) # Create primary agent - primary_agent = agent_names[0] + primary_name = agent_names[0] + primary_personality = agents[primary_name] system_prompt = mcp_system_prompt( - agents[primary_agent]["personality"], - agents[primary_agent]["task"], + primary_personality.personality, + primary_personality.task, server_prompts=server_prompts, important_guidelines=important_guidelines, ) agent0 = TaskAgent( - name=primary_agent, + name=primary_name, instructions=prompt_with_handoff_instructions(system_prompt) if handoffs else system_prompt, handoffs=handoffs, exclude_from_context=exclude_from_context, @@ -271,80 +264,81 @@ async def 
on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo ) if t: - taskflow = available_tools.get_taskflow(t) + taskflow_doc = available_tools.get_taskflow(t) await render_model_output(f"** 🤖💪 Running Task Flow: {t}\n") # Resolve global variables (file defaults + CLI overrides) - global_variables = taskflow.get("globals", {}) + global_variables = dict(taskflow_doc.globals or {}) if cli_globals: global_variables.update(cli_globals) # Resolve model config - model_config = taskflow.get("model_config", {}) + model_config_ref = taskflow_doc.model_config_ref model_keys: list[str] = [] models_params: dict[str, dict[str, Any]] = {} - if model_config: - m_config = available_tools.get_model_config(model_config) - model_dict = m_config.get("models", {}) + if model_config_ref: + m_config = available_tools.get_model_config(model_config_ref) + model_dict = m_config.models or {} if model_dict and not isinstance(model_dict, dict): - raise ValueError(f"Models section of the model_config file {model_config} must be a dictionary") + raise ValueError(f"Models section of the model_config file {model_config_ref} must be a dictionary") model_keys = list(model_dict.keys()) - models_params = m_config.get("model_settings", {}) + models_params = m_config.model_settings or {} if models_params and not isinstance(models_params, dict): - raise ValueError(f"Settings section of model_config file {model_config} must be a dictionary") + raise ValueError(f"Settings section of model_config file {model_config_ref} must be a dictionary") if not set(models_params.keys()).difference(model_keys).issubset(set()): raise ValueError( - f"Settings section of model_config file {model_config} contains models not in the model section" + f"Settings section of model_config file {model_config_ref} contains models not in the model section" ) for k, v in models_params.items(): if not isinstance(v, dict): - raise ValueError(f"Settings for model {k} in model_config file {model_config} is not a dictionary") 
- - for task_entry in taskflow["taskflow"]: - task_body = task_entry["task"] - - # Reusable taskflow support - uses = task_body.get("uses", "") - if uses: - reusable_taskflow = available_tools.get_taskflow(uses) - if reusable_taskflow is None: - raise ValueError(f"No such reusable taskflow: {uses}") - if len(reusable_taskflow["taskflow"]) > 1: + raise ValueError(f"Settings for model {k} in model_config file {model_config_ref} is not a dictionary") + + for task_wrapper in taskflow_doc.taskflow: + task = task_wrapper.task + + # Reusable taskflow support: merge parent defaults into current task + if task.uses: + reusable_doc = available_tools.get_taskflow(task.uses) + if reusable_doc is None: + raise ValueError(f"No such reusable taskflow: {task.uses}") + if len(reusable_doc.taskflow) > 1: raise ValueError("Reusable taskflows can only contain 1 task") - for k, v in reusable_taskflow["taskflow"][0]["task"].items(): - if k not in task_body: - task_body[k] = v + # Merge: parent fields fill in where current task has defaults + parent_task = reusable_doc.taskflow[0].task + merged = parent_task.model_dump(by_alias=True, exclude_defaults=True) + current = task.model_dump(by_alias=True, exclude_defaults=True) + merged.update(current) # current task overrides parent + task = TaskDefinition.model_validate(merged) # Resolve model - model = task_body.get("model", DEFAULT_MODEL) + model = task.model or DEFAULT_MODEL model_settings: dict[str, Any] = {} if model in model_keys: if model in models_params: model_settings = models_params[model].copy() model = model_dict[model] - task_model_settings = task_body.get("model_settings", {}) + task_model_settings = task.model_settings or {} if not isinstance(task_model_settings, dict): - name = task_body.get("name", "") - raise ValueError(f"model_settings in task {name} needs to be a dictionary") + raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary") model_settings.update(task_model_settings) - # Parse task 
fields - agents_list = task_body.get("agents", []) - headless = task_body.get("headless", False) - blocked_tools = task_body.get("blocked_tools", []) - run = task_body.get("run", "") - inputs = task_body.get("inputs", {}) - task_prompt = task_body.get("user_prompt", "") + # Read task fields via typed attributes + agents_list = task.agents or [] + headless = task.headless + blocked_tools = task.blocked_tools or [] + run = task.run or "" + inputs = task.inputs or {} + task_prompt = task.user_prompt or "" if run and task_prompt: raise ValueError("shell task and prompt task are mutually exclusive!") - must_complete = task_body.get("must_complete", False) - max_turns = task_body.get("max_steps", DEFAULT_MAX_TURNS) - toolboxes_override = task_body.get("toolboxes", []) - env = task_body.get("env", {}) - repeat_prompt = task_body.get("repeat_prompt", False) - exclude_from_context = task_body.get("exclude_from_context", False) - async_task = task_body.get("async", False) - max_concurrent_tasks = task_body.get("async_limit", 5) + must_complete = task.must_complete + max_turns = task.max_steps or DEFAULT_MAX_TURNS + toolboxes_override = task.toolboxes or [] + env = task.env or {} + repeat_prompt = task.repeat_prompt + exclude_from_context = task.exclude_from_context + async_task = task.async_task + max_concurrent_tasks = task.async_limit # Render prompt template (skip if repeat_prompt — result not yet available) if task_prompt and not repeat_prompt: diff --git a/src/seclab_taskflow_agent/template_utils.py b/src/seclab_taskflow_agent/template_utils.py index 43b220ca..67652432 100644 --- a/src/seclab_taskflow_agent/template_utils.py +++ b/src/seclab_taskflow_agent/template_utils.py @@ -37,7 +37,7 @@ def get_source(self, environment, template): prompt_data = self.available_tools.get_prompt(template) if not prompt_data: raise jinja2.TemplateNotFound(template) - source = prompt_data.get('prompt', '') + source = prompt_data.prompt or "" # Return: (source, filename, uptodate_func) 
return source, None, lambda: True except Exception: diff --git a/tests/test_cli_parser.py b/tests/test_cli_parser.py index 92d40589..ab00cd1a 100644 --- a/tests/test_cli_parser.py +++ b/tests/test_cli_parser.py @@ -70,8 +70,8 @@ def test_globals_in_taskflow_file(self): available_tools = AvailableTools() taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow") - assert "globals" in taskflow - assert taskflow["globals"]["test_var"] == "default_value" + assert taskflow.globals is not None + assert taskflow.globals["test_var"] == "default_value" if __name__ == "__main__": diff --git a/tests/test_template_utils.py b/tests/test_template_utils.py index d33b7531..f2f812a5 100644 --- a/tests/test_template_utils.py +++ b/tests/test_template_utils.py @@ -282,7 +282,7 @@ def test_reusable_taskflow_prompt_renders_variables(self): reusable_taskflow = available_tools.get_taskflow('tests.data.test_reusable_taskflow_with_variables') # Get the user_prompt from the reusable taskflow's task - user_prompt = reusable_taskflow['taskflow'][0]['task']['user_prompt'] + user_prompt = reusable_taskflow.taskflow[0].task.user_prompt # Render it with parent's globals and inputs (simulating what __main__.py does) result = render_template( diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py index 6cb3bc7c..e3bef4df 100644 --- a/tests/test_yaml_parser.py +++ b/tests/test_yaml_parser.py @@ -18,13 +18,13 @@ class TestYamlParser: def test_yaml_parser_basic_functionality(self): """Test basic YAML parsing functionality.""" available_tools = AvailableTools() - personality000 = available_tools.get_personality( + personality = available_tools.get_personality( "tests.data.test_yaml_parser_personality000") - assert personality000['seclab-taskflow-agent']['version'] == "1.0" - assert personality000['seclab-taskflow-agent']['filetype'] == 'personality' - assert personality000['personality'] == 'You are a helpful assistant.\n' - assert personality000['task'] == 'Answer any 
question.\n' + assert personality.header.version == "1.0" + assert personality.header.filetype == "personality" + assert personality.personality == "You are a helpful assistant.\n" + assert personality.task == "Answer any question.\n" def test_version_integer_format(self): """Test that integer version format is accepted.""" @@ -32,9 +32,10 @@ def test_version_integer_format(self): personality = available_tools.get_personality( "tests.data.test_version_integer") - assert personality['seclab-taskflow-agent']['version'] == 1 - assert personality['seclab-taskflow-agent']['filetype'] == 'personality' - assert personality['personality'] == 'Test personality with integer version.\n' + # Version is normalized to "1.0" by the model + assert personality.header.version == "1.0" + assert personality.header.filetype == "personality" + assert personality.personality == "Test personality with integer version.\n" def test_version_float_format(self): """Test that float version format is accepted.""" @@ -42,24 +43,24 @@ def test_version_float_format(self): personality = available_tools.get_personality( "tests.data.test_version_float") - assert personality['seclab-taskflow-agent']['version'] == 1.0 - assert personality['seclab-taskflow-agent']['filetype'] == 'personality' - assert personality['personality'] == 'Test personality with float version.\n' + # Version is normalized to "1.0" by the model + assert personality.header.version == "1.0" + assert personality.header.filetype == "personality" + assert personality.personality == "Test personality with float version.\n" class TestRealTaskflowFiles: """Test parsing of actual taskflow files in the project.""" def test_parse_example_taskflows(self): """Test parsing example taskflow files.""" - # this test uses the actual taskflows in the project available_tools = AvailableTools() # check that example.yaml is parsed correctly - example_task_flow = available_tools.get_taskflow("examples.taskflows.example") - assert "taskflow" in 
example_task_flow - assert isinstance(example_task_flow["taskflow"], list) - assert len(example_task_flow["taskflow"]) == 4 # 4 tasks in taskflow - assert example_task_flow["taskflow"][0]["task"]["max_steps"] == 20 + example = available_tools.get_taskflow("examples.taskflows.example") + assert example.taskflow is not None + assert isinstance(example.taskflow, list) + assert len(example.taskflow) == 4 # 4 tasks in taskflow + assert example.taskflow[0].task.max_steps == 20 if __name__ == "__main__": From 58c39c9738f796c985f96106426034b1467a0d6b Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 13:59:55 -0400 Subject: [PATCH 03/20] refactor: decompose modules, add type hints and docstrings Split mcp_utils into mcp_transport and mcp_prompt. Extract prompt_parser to break cli/runner circular import. Decompose run_main into focused helpers. Add type hints and docstrings across all modules. Add __all__ exports, pyproject keywords and classifiers. --- pyproject.toml | 14 +- src/seclab_taskflow_agent/__init__.py | 37 +- src/seclab_taskflow_agent/__main__.py | 2 +- src/seclab_taskflow_agent/agent.py | 27 +- src/seclab_taskflow_agent/capi.py | 25 +- src/seclab_taskflow_agent/cli.py | 54 +-- src/seclab_taskflow_agent/env_utils.py | 10 +- src/seclab_taskflow_agent/mcp_lifecycle.py | 3 +- src/seclab_taskflow_agent/mcp_prompt.py | 103 +++++ src/seclab_taskflow_agent/mcp_transport.py | 291 +++++++++++++ src/seclab_taskflow_agent/mcp_utils.py | 452 ++++++-------------- src/seclab_taskflow_agent/prompt_parser.py | 66 +++ src/seclab_taskflow_agent/render_utils.py | 8 +- src/seclab_taskflow_agent/runner.py | 290 +++++++++---- src/seclab_taskflow_agent/shell_utils.py | 15 +- src/seclab_taskflow_agent/template_utils.py | 16 +- 16 files changed, 905 insertions(+), 508 deletions(-) create mode 100644 src/seclab_taskflow_agent/mcp_prompt.py create mode 100644 src/seclab_taskflow_agent/mcp_transport.py create mode 100644 src/seclab_taskflow_agent/prompt_parser.py diff --git 
a/pyproject.toml b/pyproject.toml index 0c6b6bef..894503de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,15 @@ description = "A taskflow agent for the SecLab project, enabling secure and auto readme = "README.md" requires-python = ">=3.10" license = "MIT" -keywords = [] +keywords = [ + "agentic-workflows", + "mcp", + "security-research", + "taskflow", + "yaml", + "pydantic", + "openai-agents", +] authors = [ { name = "GitHub Security Lab", email = "securitylab@github.com" }, ] @@ -19,8 +27,12 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", + "Intended Audience :: Developers", + "Topic :: Security", + "Topic :: Software Development :: Libraries :: Application Frameworks", ] dependencies = [ "aiofiles==24.1.0", diff --git a/src/seclab_taskflow_agent/__init__.py b/src/seclab_taskflow_agent/__init__.py index 5cf7c001..e10fb131 100644 --- a/src/seclab_taskflow_agent/__init__.py +++ b/src/seclab_taskflow_agent/__init__.py @@ -9,12 +9,39 @@ Architecture ~~~~~~~~~~~~ -- :mod:`~seclab_taskflow_agent.models` — Pydantic grammar models +- :mod:`~seclab_taskflow_agent.models` — Pydantic v2 grammar models +- :mod:`~seclab_taskflow_agent.available_tools` — YAML resource loader & cache - :mod:`~seclab_taskflow_agent.cli` — CLI entry point (Typer) - :mod:`~seclab_taskflow_agent.runner` — Taskflow execution engine -- :mod:`~seclab_taskflow_agent.agent` — Agent wrapper classes -- :mod:`~seclab_taskflow_agent.mcp_lifecycle` — MCP server lifecycle -- :mod:`~seclab_taskflow_agent.mcp_utils` — MCP utilities +- :mod:`~seclab_taskflow_agent.agent` — Agent / hooks wrapper classes +- :mod:`~seclab_taskflow_agent.mcp_lifecycle` — MCP server connect / cleanup +- :mod:`~seclab_taskflow_agent.mcp_utils` — MCP client 
parameter resolution +- :mod:`~seclab_taskflow_agent.mcp_transport` — MCP transport implementations +- :mod:`~seclab_taskflow_agent.mcp_prompt` — System prompt construction - :mod:`~seclab_taskflow_agent.template_utils` — Jinja2 template rendering -- :mod:`~seclab_taskflow_agent.available_tools` — YAML resource loader +- :mod:`~seclab_taskflow_agent.prompt_parser` — Legacy prompt argument parser """ + +__all__ = [ + "AvailableTools", + "TaskAgent", + "TaskRunHooks", + "TaskAgentHooks", + "PersonalityDocument", + "TaskflowDocument", + "ToolboxDocument", + "ModelConfigDocument", + "PromptDocument", + "TaskDefinition", +] + +from .agent import TaskAgent, TaskAgentHooks, TaskRunHooks +from .available_tools import AvailableTools +from .models import ( + ModelConfigDocument, + PersonalityDocument, + PromptDocument, + TaskDefinition, + TaskflowDocument, + ToolboxDocument, +) diff --git a/src/seclab_taskflow_agent/__main__.py b/src/seclab_taskflow_agent/__main__.py index 0f010138..1c1147ed 100644 --- a/src/seclab_taskflow_agent/__main__.py +++ b/src/seclab_taskflow_agent/__main__.py @@ -18,7 +18,7 @@ load_dotenv(find_dotenv(usecwd=True)) # Re-export for backwards compatibility — some tests import from __main__ -from .cli import parse_prompt_args # noqa: E402, F401 +from .prompt_parser import parse_prompt_args # noqa: E402, F401 from .runner import deploy_task_agents, run_main # noqa: E402, F401 if __name__ == "__main__": diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 5e788864..67567216 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -46,15 +46,17 @@ DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=default_model) -# Run hooks monitor the entire lifetime of a runner, including across any Agent handoffs class TaskRunHooks(RunHooks): + """RunHooks that monitor the entire lifetime of a runner, including across Agent handoffs.""" + def __init__( self, on_agent_start: Callable | None 
= None, on_agent_end: Callable | None = None, on_tool_start: Callable | None = None, on_tool_end: Callable | None = None, - ): + ) -> None: + """Initialize with optional callback functions for each lifecycle event.""" self._on_agent_start = on_agent_start self._on_agent_end = on_agent_end self._on_tool_start = on_tool_start @@ -83,8 +85,9 @@ async def on_tool_end( await self._on_tool_end(context, agent, tool, result) -# Agent hooks monitor the lifetime of a single agent, not across any Agent handoffs class TaskAgentHooks(AgentHooks): + """AgentHooks that monitor the lifetime of a single agent, not across Agent handoffs.""" + def __init__( self, on_handoff: Callable | None = None, @@ -92,7 +95,8 @@ def __init__( on_end: Callable | None = None, on_tool_start: Callable | None = None, on_tool_end: Callable | None = None, - ): + ) -> None: + """Initialize with optional callback functions for each lifecycle event.""" self._on_handoff = on_handoff self._on_start = on_start self._on_end = on_end @@ -130,18 +134,25 @@ async def on_tool_end( class TaskAgent: + """High-level wrapper around the OpenAI Agent SDK. + + Configures the OpenAI client, creates an Agent with the given tools and + model, and exposes ``run`` / ``run_streamed`` entry points. 
+ """ + def __init__( self, name: str = "TaskAgent", instructions: str = "", - handoffs: list = [], + handoffs: list[Any] = [], exclude_from_context: bool = False, - mcp_servers: dict = [], + mcp_servers: list[Any] = [], model: str = DEFAULT_MODEL, model_settings: ModelSettings | None = None, run_hooks: TaskRunHooks | None = None, agent_hooks: TaskAgentHooks | None = None, - ): + ) -> None: + """Create a TaskAgent with the specified configuration.""" client = AsyncOpenAI( base_url=api_endpoint, api_key=get_AI_token(), @@ -174,7 +185,9 @@ def _ToolsToFinalOutputFunction( ) async def run(self, prompt: str, max_turns: int = DEFAULT_MAX_TURNS) -> result.RunResult: + """Run the agent to completion and return the result.""" return await Runner.run(starting_agent=self.agent, input=prompt, max_turns=max_turns, hooks=self.run_hooks) def run_streamed(self, prompt: str, max_turns: int = DEFAULT_MAX_TURNS) -> result.RunResultStreaming: + """Run the agent with streaming output.""" return Runner.run_streamed(starting_agent=self.agent, input=prompt, max_turns=max_turns, hooks=self.run_hooks) diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 171900cf..a43333d1 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -17,10 +17,8 @@ class AI_API_ENDPOINT_ENUM(StrEnum): AI_API_GITHUBCOPILOT = "api.githubcopilot.com" AI_API_OPENAI = "api.openai.com" - def to_url(self): - """ - Convert the endpoint to its full URL. 
- """ + def to_url(self) -> str: + """Convert the endpoint to its full URL.""" match self: case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT: return f"https://{self}" @@ -39,17 +37,14 @@ def to_url(self): # but beware that your taskflows need to reference the correct model id # since different APIs use their own id schema, use -l with your desired # endpoint to retrieve the correct id names to use for your taskflow -def get_AI_endpoint(): +def get_AI_endpoint() -> str: + """Return the configured AI API endpoint URL.""" return os.getenv("AI_API_ENDPOINT", default="https://models.github.ai/inference") -def get_AI_token(): - """ - Get the token for the AI API from the environment. - The environment variable can be named either AI_API_TOKEN - or COPILOT_TOKEN. - """ - token = os.getenv("AI_API_TOKEN") +def get_AI_token() -> str: + """Get the AI API token from AI_API_TOKEN or COPILOT_TOKEN env vars.""" + token: str | None = os.getenv("AI_API_TOKEN") if token: return token token = os.getenv("COPILOT_TOKEN") @@ -105,7 +100,8 @@ def list_capi_models(token: str) -> dict[str, dict]: return models -def supports_tool_calls(model: str, models: dict) -> bool: +def supports_tool_calls(model: str, models: dict[str, dict]) -> bool: + """Check whether the given model supports tool calls.""" api_endpoint = get_AI_endpoint() match urlparse(api_endpoint).netloc: case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT: @@ -129,8 +125,9 @@ def supports_tool_calls(model: str, models: dict) -> bool: def list_tool_call_models(token: str) -> dict[str, dict]: + """Return only models that support tool calls.""" models = list_capi_models(token) - tool_models = {} + tool_models: dict[str, dict] = {} for model in models: if supports_tool_calls(model, models) is True: tool_models[model] = models[model] diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 743a5992..31cf626e 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -12,6 +12,7 @@ 
import asyncio import logging +import os from typing import Annotated import typer @@ -117,59 +118,12 @@ def main( asyncio.run( run_main(available_tools, personality, taskflow, cli_globals, user_prompt), - debug=True, + debug=os.getenv("TASK_AGENT_LOGLEVEL", "").upper() == "DEBUG", ) # --------------------------------------------------------------------------- -# Legacy compatibility shim +# Legacy compatibility shim — implementation moved to prompt_parser.py # --------------------------------------------------------------------------- -def parse_prompt_args(available_tools: AvailableTools, user_prompt: str | None = None): - """Legacy CLI parser kept for backwards compatibility with tests. - - Returns: - Tuple of (personality, taskflow, list_models, cli_globals, prompt, help_msg). - """ - import argparse - - parser = argparse.ArgumentParser(add_help=False, description="SecLab Taskflow Agent") - parser.prog = "" - group = parser.add_mutually_exclusive_group() - group.add_argument("-p", help="The personality to use (mutex with -t)", required=False) - group.add_argument("-t", help="The taskflow to use (mutex with -p)", required=False) - group.add_argument("-l", help="List available tool call models and exit", action="store_true", required=False) - parser.add_argument( - "-g", - "--global", - dest="globals", - action="append", - help="Set global variable (KEY=VALUE). 
Can be used multiple times.", - required=False, - ) - parser.add_argument("prompt", nargs=argparse.REMAINDER) - - help_msg = parser.format_help() - help_msg += "\nExamples:\n\n" - help_msg += "`-p seclab_taskflow_agent.personalities.assistant explain modems to me please`\n" - help_msg += "`-t examples.taskflows.example_globals -g fruit=apples`\n" - try: - args = parser.parse_known_args(user_prompt.split(" ") if user_prompt else None) - except SystemExit as e: - if e.code == 2: - logging.exception(f"User provided incomplete prompt: {user_prompt}") - return None, None, None, None, help_msg - p = args[0].p.strip() if args[0].p else None - t = args[0].t.strip() if args[0].t else None - list_models = args[0].l - - cli_globals: dict[str, str] = {} - if args[0].globals: - for g in args[0].globals: - if "=" not in g: - logging.error(f"Invalid global variable format: {g}. Expected KEY=VALUE") - return None, None, None, None, None, help_msg - key, value = g.split("=", 1) - cli_globals[key.strip()] = value.strip() - - return p, t, list_models, cli_globals, " ".join(args[0].prompt), help_msg +from .prompt_parser import parse_prompt_args # noqa: F401, E402 diff --git a/src/seclab_taskflow_agent/env_utils.py b/src/seclab_taskflow_agent/env_utils.py index daea31b0..597d9ce2 100644 --- a/src/seclab_taskflow_agent/env_utils.py +++ b/src/seclab_taskflow_agent/env_utils.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: MIT import os +from typing import Any + import jinja2 @@ -40,15 +42,17 @@ def swap_env(s: str) -> str: class TmpEnv: - def __init__(self, env): + """Context manager that temporarily sets environment variables.""" + + def __init__(self, env: dict[str, str]) -> None: self.env = dict(env) self.restore_env = dict(os.environ) - def __enter__(self): + def __enter__(self) -> None: for k, v in self.env.items(): os.environ[k] = swap_env(v) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: type | None, exc_val: BaseException | None, exc_tb: Any | None) 
-> None: for k, v in self.env.items(): del os.environ[k] if k in self.restore_env: diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index 72082060..365a85aa 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -15,11 +15,10 @@ from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter +from .mcp_transport import ReconnectingMCPServerStdio, StreamableMCPThread from .mcp_utils import ( DEFAULT_MCP_CLIENT_SESSION_TIMEOUT, MCPNamespaceWrap, - ReconnectingMCPServerStdio, - StreamableMCPThread, mcp_client_params, ) diff --git a/src/seclab_taskflow_agent/mcp_prompt.py b/src/seclab_taskflow_agent/mcp_prompt.py new file mode 100644 index 00000000..64b5b3ab --- /dev/null +++ b/src/seclab_taskflow_agent/mcp_prompt.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""MCP system-prompt construction. + +Assembles the full system prompt from a base personality prompt, available +tools/resources, important guidelines, and server-supplied instructions. +""" + +from __future__ import annotations + + +def mcp_system_prompt( + system_prompt: str, + task: str, + tools: list[str] | None = None, + resources: list[str] | None = None, + resource_templates: list[str] | None = None, + important_guidelines: list[str] | None = None, + server_prompts: list[str] | None = None, +) -> str: + """Build a well-structured system prompt for an MCP agent. + + Each optional section is appended only when its list is non-empty. + + Args: + system_prompt: Base personality / instruction text. + task: The primary task description for the agent. + tools: Human-readable tool descriptions. + resources: Human-readable resource descriptions. + resource_templates: Human-readable resource-template descriptions. + important_guidelines: Critical behavioural constraints. 
+ server_prompts: Additional guidance supplied by MCP servers. + + Returns: + The fully assembled system prompt string. + """ + if tools is None: + tools = [] + if resources is None: + resources = [] + if resource_templates is None: + resource_templates = [] + if important_guidelines is None: + important_guidelines = [] + if server_prompts is None: + server_prompts = [] + + prompt = f""" +{system_prompt} +""" + + if tools: + prompt += """ + +# Available Tools + +- {tools} +""".format(tools="\n- ".join(tools)) + + if resources: + prompt += """ + +# Available Resources + +- {resources} +""".format(resources="\n- ".join(resources)) + + if resource_templates: + prompt += """ + +# Available Resource Templates + +- {resource_templates} +""".format(resource_templates="\n- ".join(resource_templates)) + + if important_guidelines: + prompt += """ + +# Important Guidelines + +- IMPORTANT: {guidelines} +""".format(guidelines="\n- IMPORTANT: ".join(important_guidelines)) + + if server_prompts: + prompt += """ + +# Additional Guidelines + +{server_prompts} + +""".format(server_prompts="\n\n".join(server_prompts)) + + if task: + prompt += f""" + +# Primary Task to Complete + +{task} + +""" + + return prompt diff --git a/src/seclab_taskflow_agent/mcp_transport.py b/src/seclab_taskflow_agent/mcp_transport.py new file mode 100644 index 00000000..22ea1abe --- /dev/null +++ b/src/seclab_taskflow_agent/mcp_transport.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""MCP transport-layer implementations. + +Provides thread-based local MCP server process management and specialised +stdio wrappers that work around asyncio and JSON-RPC edge cases. + +Classes: + StreamableMCPThread: Manages a local streamable MCP server process. + AsyncDebugMCPServerStdio: Debug wrapper that isolates the asyncio loop. + ReconnectingMCPServerStdio: Reconnecting wrapper for flaky stdio I/O. 
# Exit codes that are considered normal termination: a clean exit, or death
# by the SIGTERM that ``stop()`` / ``run()`` deliver on shutdown.
_EXPECTED_EXIT_CODES: frozenset[int] = frozenset({0, -signal.SIGTERM})

# Upper bound, in seconds, on one TCP connect attempt made by the connection
# probes. Attempts are retried until the caller's overall deadline expires.
_CONNECT_ATTEMPT_TIMEOUT: float = 2.0


class StreamableMCPThread(Thread):
    """Thread that manages a local streamable MCP server subprocess.

    The thread starts the server, drains its stdout/stderr (forwarding each
    line to the optional callbacks), and terminates the process when
    :meth:`stop` is called.

    Args:
        cmd: Command-line tokens to launch the server.
        url: URL the server will listen on (used for connection probes).
        on_output: Callback invoked with each stdout line.
        on_error: Callback invoked with each stderr line.
        poll_interval: Seconds between process-alive checks.
        env: Extra environment variables merged into the current env.
    """

    def __init__(
        self,
        cmd: list[str],
        url: str = "",
        on_output: Callable[[str], None] | None = None,
        on_error: Callable[[str], None] | None = None,
        poll_interval: float = 0.5,
        env: dict[str, str] | None = None,
    ) -> None:
        super().__init__(daemon=True)
        self.url: str = url
        self.cmd: list[str] = cmd
        self.on_output: Callable[[str], None] | None = on_output
        self.on_error: Callable[[str], None] | None = on_error
        self.poll_interval: float = poll_interval
        self.env: dict[str, str] = os.environ.copy()  # XXX: potential for environment leak to MCP
        if env:
            self.env.update(env)
        self._stop_event: Event = Event()
        self.process: subprocess.Popen[str] | None = None
        self.exit_code: int | None = None
        self.exception: BaseException | None = None

    def _probe_target(self) -> tuple[str, int]:
        """Return the ``(host, port)`` pair parsed from :attr:`url`.

        Raises:
            ValueError: If *url* is missing host or port.
        """
        parsed = urlparse(self.url)
        host, port = parsed.hostname, parsed.port
        if host is None or port is None:
            raise ValueError(f"URL must include a host and port: {self.url}")
        return host, port

    async def async_wait_for_connection(
        self, timeout: float = 30.0, poll_interval: float = 0.5
    ) -> None:
        """Asynchronously poll until the server accepts TCP connections.

        Args:
            timeout: Maximum seconds to wait.
            poll_interval: Seconds between connection attempts.

        Raises:
            ValueError: If *url* is missing host or port.
            TimeoutError: If the server is not reachable within *timeout*.
        """
        host, port = self._probe_target()
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout  # loop.time() is a monotonic clock
        while True:
            try:
                # Bound each attempt, mirroring the synchronous probe, so one
                # hung connect cannot overshoot the overall deadline unbounded.
                _reader, writer = await asyncio.wait_for(
                    asyncio.open_connection(host, port),
                    timeout=_CONNECT_ATTEMPT_TIMEOUT,
                )
            except (OSError, asyncio.TimeoutError) as exc:
                if loop.time() > deadline:
                    raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds") from exc
                await asyncio.sleep(poll_interval)
            else:
                writer.close()
                try:
                    await writer.wait_closed()
                except OSError:
                    # The connect already succeeded; a reset while closing
                    # the probe socket does not change the outcome.
                    pass
                return

    def wait_for_connection(
        self, timeout: float = 30.0, poll_interval: float = 0.5
    ) -> None:
        """Synchronously poll until the server accepts TCP connections.

        Args:
            timeout: Maximum seconds to wait.
            poll_interval: Seconds between connection attempts.

        Raises:
            ValueError: If *url* is missing host or port.
            TimeoutError: If the server is not reachable within *timeout*.
        """
        host, port = self._probe_target()
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + timeout
        while True:
            try:
                with socket.create_connection((host, port), timeout=_CONNECT_ATTEMPT_TIMEOUT):
                    return
            except OSError as exc:
                if time.monotonic() > deadline:
                    raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds") from exc
                time.sleep(poll_interval)

    def run(self) -> None:
        """Execute the subprocess and monitor it until stopped.

        Any failure (including an unexpected exit code) is stored in
        :attr:`exception` rather than raised, so that :meth:`join_and_raise`
        can surface it on the caller's thread.
        """
        try:
            self.process = subprocess.Popen(
                self.cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
                env=self.env,
            )

            readers = [
                Thread(target=self._read_stream, args=(self.process.stdout, self.on_output)),
                Thread(target=self._read_stream, args=(self.process.stderr, self.on_error)),
            ]
            for reader_thread in readers:
                reader_thread.start()

            # Event.wait() doubles as the poll sleep and returns early (True)
            # as soon as stop() is requested.
            while self.process.poll() is None and not self._stop_event.wait(self.poll_interval):
                pass

            if self.process.poll() is None:
                self.process.terminate()
                self.process.wait()
            self.exit_code = self.process.returncode

            for reader_thread in readers:
                reader_thread.join()

            if self.exit_code not in _EXPECTED_EXIT_CODES:
                self.exception = subprocess.CalledProcessError(self.exit_code, self.cmd)

        except BaseException as e:
            self.exception = e

    def _read_stream(
        self, stream: IO[str] | None, callback: Callable[[str], None] | None
    ) -> None:
        """Drain *stream* line-by-line, forwarding to *callback* when given.

        The stream is drained even without a callback: both pipes are always
        created, and an unread pipe would block the child once its buffer
        fills.
        """
        if stream is None:
            return
        with stream:  # close the pipe even if the callback raises
            for line in stream:
                if callback is not None:
                    callback(line.rstrip("\n"))

    def stop(self) -> None:
        """Request the subprocess to terminate."""
        self._stop_event.set()
        if self.process and self.process.poll() is None:
            self.process.terminate()

    def is_running(self) -> bool:
        """Return whether the subprocess is still alive."""
        return self.process is not None and self.process.poll() is None

    def join_and_raise(self, timeout: float | None = None) -> None:
        """Join the thread and re-raise any captured exception.

        Args:
            timeout: Maximum seconds to wait for the thread to finish.

        Raises:
            RuntimeError: If the thread is still alive after *timeout*.
            BaseException: Whatever :meth:`run` captured in :attr:`exception`.
        """
        self.join(timeout)
        if self.is_alive():
            raise RuntimeError("Process thread did not exit within timeout.")
        if self.exception is not None:
            raise self.exception
class ReconnectingMCPServerStdio(MCPServerStdio):
    """Stdio MCP server wrapper that opens a fresh connection per operation.

    Works around buggy JSON-RPC stdio behaviour in FastMCP 1.0 where
    long-running, high-volume processes miss I/O and results never arrive
    on the client side. Enable via ``reconnecting: true`` in your
    toolbox config.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Guards each connect -> operate -> cleanup cycle so cycles never
        # interleave on the same server object.
        self.reconnecting_lock: asyncio.Lock = asyncio.Lock()

    async def connect(self) -> None:
        """Do nothing; the real connection is opened inside each operation."""
        logging.debug("Ignoring mcp connect request on purpose")

    async def cleanup(self) -> None:
        """Do nothing; the real cleanup runs at the end of each operation."""
        logging.debug("Ignoring mcp cleanup request on purpose")

    async def list_tools(self, *args: Any, **kwargs: Any) -> Any:
        """Open a connection, list the tools, and always clean up afterwards."""
        async with self.reconnecting_lock:
            await super().connect()
            try:
                return await super().list_tools(*args, **kwargs)
            finally:
                await super().cleanup()

    async def call_tool(self, *args: Any, **kwargs: Any) -> Any:
        """Open a connection, call the tool, and always clean up afterwards."""
        logging.debug("Using reconnecting call_tool for stdio mcp")
        async with self.reconnecting_lock:
            await super().connect()
            try:
                return await super().call_tool(*args, **kwargs)
            finally:
                await super().cleanup()
DEFAULT_MCP_CLIENT_SESSION_TIMEOUT: int = 120

# The OpenAI API rejects tool names longer than 64 characters.
# We hash long names down to this many hex characters.
COMPRESSED_NAME_LENGTH: int = 12


def compress_name(name: str) -> str:
    """Return a short, stable hash of *name* for use as a tool-name prefix.

    Args:
        name: The original tool / toolbox name.

    Returns:
        The first ``COMPRESSED_NAME_LENGTH`` characters of the lowercase
        hexadecimal SHA-256 digest of the UTF-8 encoded *name*.
    """
    digest = hashlib.sha256(name.encode("utf-8")).hexdigest()
    return digest[:COMPRESSED_NAME_LENGTH]
+ """ while True: yn = input( f"** 🤖❗ Allow tool call?: {tool_name}({','.join([json.dumps(arg) for arg in args])}) (yes/no): " @@ -266,9 +107,10 @@ def confirm_tool(self, tool_name, args): if yn in ["no", "n"]: return False - async def call_tool(self, *args, **kwargs): + async def call_tool(self, *args: Any, **kwargs: Any) -> Any: + """Call a tool, stripping the namespace prefix and optionally confirming.""" _args = list(args) - tool_name = _args[0] + tool_name: str = _args[0] tool_name = tool_name.removeprefix(self.namespace) # to run headless, just make confirms an empty list if self.confirms and tool_name in self.confirms: @@ -283,15 +125,34 @@ async def call_tool(self, *args, **kwargs): return result -def mcp_client_params(available_tools: AvailableTools, requested_toolboxes: list): - """Return all the data needed to initialize an mcp server client.""" - client_params = {} +ClientParamsMap = dict[str, tuple[dict[str, Any], list[str], str | None, int | None]] + + +def mcp_client_params( + available_tools: AvailableTools, + requested_toolboxes: list[str], +) -> ClientParamsMap: + """Resolve toolbox configs into MCP server connection parameters. + + Args: + available_tools: The tool registry that can look up toolbox configs. + requested_toolboxes: Module paths of the toolboxes to resolve. + + Returns: + A mapping from toolbox name to a tuple of + ``(server_params, confirms, server_prompt, client_session_timeout)``. + + Raises: + ValueError: If the transport kind is not supported. + FileNotFoundError: If a streamable command cannot be found on ``$PATH``. 
def _resolve_headers(
    headers: dict[str, str] | None,
    optional_headers: dict[str, str] | None,
) -> dict[str, str] | None:
    """Expand env references in headers and merge required + optional.

    Required headers propagate whatever ``swap_env`` raises; optional
    headers whose expansion raises ``LookupError`` are silently dropped.
    Neither input dict is modified.

    Args:
        headers: Required header dict whose values are passed through
            ``swap_env``.
        optional_headers: Like *headers*, but a ``LookupError`` from
            ``swap_env`` drops that entry instead of propagating.

    Returns:
        The merged header dict (optional values win on key clashes), or
        ``None`` when there are no required headers and *optional_headers*
        is ``None``. If optional headers were supplied but every one was
        dropped and there are no required headers, the result is an empty
        dict.
    """
    resolved: dict[str, str] | None = None
    if headers:
        resolved = {key: swap_env(value) for key, value in headers.items()}

    resolved_optional: dict[str, str] | None = None
    if optional_headers:
        resolved_optional = {}
        for key, value in optional_headers.items():
            try:
                resolved_optional[key] = swap_env(value)
            except LookupError:
                # Referenced value is unavailable: omit this optional header.
                continue

    return _merge_headers(resolved, resolved_optional)


def _merge_headers(
    headers: dict[str, str] | None,
    optional_headers: dict[str, str] | None,
) -> dict[str, str] | None:
    """Merge required and optional header dicts without mutating either.

    Args:
        headers: Required headers (may be ``None``).
        optional_headers: Optional headers (may be ``None``).

    Returns:
        A new dict combining both when both are non-empty (optional values
        win on key clashes); otherwise whichever argument is truthy, falling
        back to *optional_headers* as given (``None`` or empty).
    """
    if headers and optional_headers:
        # Build a fresh dict so the caller's ``headers`` is left untouched.
        return {**headers, **optional_headers}
    return headers or optional_headers
def parse_prompt_args(
    available_tools: AvailableTools, user_prompt: str | None = None
) -> tuple[str | None, str | None, bool | None, dict[str, str] | None, str | None, str]:
    """Legacy CLI parser kept for backwards compatibility with tests.

    Splits *user_prompt* on single spaces and parses the ``-p``, ``-t``,
    ``-l`` and ``-g KEY=VALUE`` flags; everything after them is the prompt.
    When *user_prompt* is empty or ``None``, argparse falls back to
    ``sys.argv``.

    Args:
        available_tools: Unused here; kept for interface compatibility.
        user_prompt: Raw prompt text that may carry embedded flags.

    Returns:
        Tuple of (personality, taskflow, list_models, cli_globals, prompt,
        help_msg). On any parse failure the first five elements are ``None``
        and only ``help_msg`` is populated; the tuple always has six
        elements.
    """
    parser = argparse.ArgumentParser(add_help=False, description="SecLab Taskflow Agent")
    parser.prog = ""
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-p", help="The personality to use (mutex with -t)", required=False)
    group.add_argument("-t", help="The taskflow to use (mutex with -p)", required=False)
    group.add_argument("-l", help="List available tool call models and exit", action="store_true", required=False)
    parser.add_argument(
        "-g",
        "--global",
        dest="globals",
        action="append",
        help="Set global variable (KEY=VALUE). Can be used multiple times.",
        required=False,
    )
    parser.add_argument("prompt", nargs=argparse.REMAINDER)

    help_msg = parser.format_help()
    help_msg += "\nExamples:\n\n"
    help_msg += "`-p seclab_taskflow_agent.personalities.assistant explain modems to me please`\n"
    help_msg += "`-t examples.taskflows.example_globals -g fruit=apples`\n"

    # Single failure shape shared by every error path, so callers can always
    # unpack six values.
    failure = (None, None, None, None, None, help_msg)

    try:
        parsed, _unknown = parser.parse_known_args(user_prompt.split(" ") if user_prompt else None)
    except SystemExit as e:
        # argparse reports usage errors by exiting with status 2.
        if e.code == 2:
            logging.exception(f"User provided incomplete prompt: {user_prompt}")
        # Return for any exit status: nothing was parsed, so there is no
        # namespace to continue with.
        return failure

    personality = parsed.p.strip() if parsed.p else None
    taskflow = parsed.t.strip() if parsed.t else None

    cli_globals: dict[str, str] = {}
    for item in parsed.globals or []:
        if "=" not in item:
            logging.error(f"Invalid global variable format: {item}. Expected KEY=VALUE")
            return failure
        key, value = item.split("=", 1)
        cli_globals[key.strip()] = value.strip()

    return personality, taskflow, parsed.l, cli_globals, " ".join(parsed.prompt or []), help_msg
DEFAULT_MAX_TURNS = 50  # Maximum agent turns before forced termination
RATE_LIMIT_BACKOFF = 5  # Initial backoff in seconds after a rate-limit response
MAX_RATE_LIMIT_BACKOFF = 120  # Maximum backoff cap in seconds for rate-limit retries
MAX_API_RETRY = 5  # Maximum number of consecutive API error retries


def _resolve_model_config(
    available_tools: AvailableTools,
    model_config_ref: str,
) -> tuple[list[str], dict[str, str], dict[str, dict[str, Any]]]:
    """Load and validate the model configuration file.

    Args:
        available_tools: Tool registry used to load the config file.
        model_config_ref: Reference name for the model config document.

    Returns:
        A tuple of (model_keys, model_dict, models_params) where model_keys is
        the list of keys of the config's ``models`` mapping, model_dict is that
        mapping itself, and models_params is the config's ``model_settings``
        mapping (each value a dict). Missing sections come back empty.

    Raises:
        ValueError: If a section is not a dictionary, if settings exist for
            a model that is not in the models section, or if any per-model
            settings value is not a dictionary.
    """
    m_config = available_tools.get_model_config(model_config_ref)

    model_dict: dict[str, str] = m_config.models or {}
    if model_dict and not isinstance(model_dict, dict):
        raise ValueError(f"Models section of the model_config file {model_config_ref} must be a dictionary")
    model_keys: list[str] = list(model_dict.keys())

    models_params: dict[str, dict[str, Any]] = m_config.model_settings or {}
    if models_params and not isinstance(models_params, dict):
        raise ValueError(f"Settings section of model_config file {model_config_ref} must be a dictionary")

    # Every model with settings must also be declared in the models section.
    if set(models_params) - set(model_keys):
        raise ValueError(
            f"Settings section of model_config file {model_config_ref} contains models not in the model section"
        )
    for k, v in models_params.items():
        if not isinstance(v, dict):
            raise ValueError(f"Settings for model {k} in model_config file {model_config_ref} is not a dictionary")
    return model_keys, model_dict, models_params


def _merge_reusable_task(
    available_tools: AvailableTools,
    task: TaskDefinition,
) -> TaskDefinition:
    """Merge a reusable taskflow into the current task definition.

    Args:
        available_tools: Tool registry used to load the reusable taskflow.
        task: Current task whose ``uses`` field references a reusable taskflow.

    Returns:
        A new TaskDefinition validated from the reusable task's non-default
        fields overlaid with the current task's non-default fields (the
        current task wins on clashes).

    Raises:
        ValueError: If the reusable taskflow is missing or has more than 1 task.
    """
    reusable_doc = available_tools.get_taskflow(task.uses)
    if reusable_doc is None:
        raise ValueError(f"No such reusable taskflow: {task.uses}")
    if len(reusable_doc.taskflow) > 1:
        raise ValueError("Reusable taskflows can only contain 1 task")
    parent_task = reusable_doc.taskflow[0].task
    # exclude_defaults keeps only explicitly-set fields, so fields the
    # current task left at their defaults do not mask the parent's values.
    merged: dict[str, Any] = parent_task.model_dump(by_alias=True, exclude_defaults=True)
    merged.update(task.model_dump(by_alias=True, exclude_defaults=True))
    return TaskDefinition.model_validate(merged)


def _resolve_task_model(
    task: TaskDefinition,
    model_keys: list[str],
    model_dict: dict[str, str],
    models_params: dict[str, dict[str, Any]],
) -> tuple[str, dict[str, Any]]:
    """Resolve the final model name and settings for a task.

    Args:
        task: The task definition containing optional model/model_settings.
        model_keys: Logical model names from the model config.
        model_dict: Mapping of logical model names to provider model IDs.
        models_params: Per-model settings from the model config.

    Returns:
        A tuple of (resolved_model_name, merged_model_settings). Config-level
        settings for the logical model are copied first, then overridden by
        the task's own ``model_settings``.

    Raises:
        ValueError: If task-level model_settings is not a dictionary.
    """
    model: str = task.model or DEFAULT_MODEL
    model_settings: dict[str, Any] = {}
    if model in model_keys:
        if model in models_params:
            # Copy so task-level overrides never leak into the shared config.
            model_settings = models_params[model].copy()
        model = model_dict[model]
    task_model_settings: dict[str, Any] | Any = task.model_settings or {}
    if not isinstance(task_model_settings, dict):
        raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary")
    model_settings.update(task_model_settings)
    return model, model_settings


async def _build_prompts_to_run(
    task_prompt: str,
    repeat_prompt: bool,
    last_mcp_tool_results: list[str],
    available_tools: AvailableTools,
    global_variables: dict[str, Any],
    inputs: dict[str, Any],
) -> list[str]:
    """Build the list of prompts to execute for a task.

    For regular tasks the list contains *task_prompt* unchanged. When
    ``repeat_prompt`` is enabled, the most recent entry is popped from
    *last_mcp_tool_results*, its ``text`` field is parsed as JSON, and one
    prompt is rendered per element of the parsed value.

    Args:
        task_prompt: The raw or pre-rendered prompt template string.
        repeat_prompt: Whether to expand prompts over MCP tool results.
        last_mcp_tool_results: Mutable list of prior MCP tool result strings;
            its last element is consumed when *repeat_prompt* is set.
        available_tools: Tool registry (passed through to template rendering).
        global_variables: Global template variables.
        inputs: Task-level input variables.

    Returns:
        List of rendered prompt strings to execute (empty when the parsed
        result is an empty iterable).

    Raises:
        IndexError: If *repeat_prompt* is set but no MCP tool result exists.
        TypeError: If the parsed result text is not iterable.
        ValueError: If the last result or its text is not valid JSON, if the
            last result is not a JSON object, or if template rendering fails.
    """
    if not repeat_prompt:
        return [task_prompt]

    if "result" not in task_prompt.lower():
        logging.warning("repeat_prompt enabled but no {{ result }} in prompt")

    try:
        raw_result = last_mcp_tool_results.pop()
    except IndexError:
        logging.critical("No last MCP tool result available")
        raise

    last_result = json.loads(raw_result)
    if not isinstance(last_result, dict):
        # Fail with a clear message instead of an AttributeError on ``.get``.
        raise ValueError("Last MCP tool result is not a JSON object")

    text = last_result.get("text", "")
    try:
        iterable_result = json.loads(text)
    except json.JSONDecodeError as exc:
        logging.critical(f"Could not parse result text: {text}")
        raise ValueError("Result text is not valid JSON") from exc

    try:
        iter(iterable_result)
    except TypeError:
        logging.critical("Last MCP tool result is not iterable")
        raise

    prompts_to_run: list[str] = []
    if not iterable_result:
        await render_model_output("** 🤖❗MCP tool result iterable is empty!\n")
        return prompts_to_run

    logging.debug(f"Rendering templated prompts for results: {iterable_result}")
    for value in iterable_result:
        try:
            rendered_prompt = render_template(
                template_str=task_prompt,
                available_tools=available_tools,
                globals_dict=global_variables,
                inputs_dict=inputs,
                result_value=value,
            )
            prompts_to_run.append(rendered_prompt)
        except jinja2.TemplateError as e:
            logging.error(f"Error rendering template for result {value}: {e}")
            # Chain the cause so the original Jinja traceback is preserved.
            raise ValueError(f"Template rendering failed: {e}") from e
    return prompts_to_run
Any]] = {} if model_config_ref: - m_config = available_tools.get_model_config(model_config_ref) - model_dict = m_config.models or {} - if model_dict and not isinstance(model_dict, dict): - raise ValueError(f"Models section of the model_config file {model_config_ref} must be a dictionary") - model_keys = list(model_dict.keys()) - models_params = m_config.model_settings or {} - if models_params and not isinstance(models_params, dict): - raise ValueError(f"Settings section of model_config file {model_config_ref} must be a dictionary") - if not set(models_params.keys()).difference(model_keys).issubset(set()): - raise ValueError( - f"Settings section of model_config file {model_config_ref} contains models not in the model section" - ) - for k, v in models_params.items(): - if not isinstance(v, dict): - raise ValueError(f"Settings for model {k} in model_config file {model_config_ref} is not a dictionary") + model_keys, model_dict, models_params = _resolve_model_config(available_tools, model_config_ref) for task_wrapper in taskflow_doc.taskflow: task = task_wrapper.task # Reusable taskflow support: merge parent defaults into current task if task.uses: - reusable_doc = available_tools.get_taskflow(task.uses) - if reusable_doc is None: - raise ValueError(f"No such reusable taskflow: {task.uses}") - if len(reusable_doc.taskflow) > 1: - raise ValueError("Reusable taskflows can only contain 1 task") - # Merge: parent fields fill in where current task has defaults - parent_task = reusable_doc.taskflow[0].task - merged = parent_task.model_dump(by_alias=True, exclude_defaults=True) - current = task.model_dump(by_alias=True, exclude_defaults=True) - merged.update(current) # current task overrides parent - task = TaskDefinition.model_validate(merged) + task = _merge_reusable_task(available_tools, task) # Resolve model - model = task.model or DEFAULT_MODEL - model_settings: dict[str, Any] = {} - if model in model_keys: - if model in models_params: - model_settings = 
models_params[model].copy() - model = model_dict[model] - task_model_settings = task.model_settings or {} - if not isinstance(task_model_settings, dict): - raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary") - model_settings.update(task_model_settings) + model, model_settings = _resolve_task_model(task, model_keys, model_dict, models_params) # Read task fields via typed attributes agents_list = task.agents or [] @@ -351,49 +489,13 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo ) except jinja2.TemplateError as e: logging.error(f"Template rendering error: {e}") - raise ValueError(f"Failed to render prompt template: {e}") + raise ValueError(f"Failed to render prompt template: {e}") from e with TmpEnv(env): - prompts_to_run: list[str] = [] - if repeat_prompt: - if "result" not in task_prompt.lower(): - logging.warning("repeat_prompt enabled but no {{ result }} in prompt") - try: - last_result = json.loads(last_mcp_tool_results.pop()) - text = last_result.get("text", "") - try: - iterable_result = json.loads(text) - except json.JSONDecodeError as exc: - logging.critical(f"Could not parse result text: {text}") - raise ValueError("Result text is not valid JSON") from exc - try: - iter(iterable_result) - except TypeError: - logging.critical("Last MCP tool result is not iterable") - raise - except IndexError: - logging.critical("No last MCP tool result available") - raise - - if not iterable_result: - await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") - else: - logging.debug(f"Rendering templated prompts for results: {iterable_result}") - for value in iterable_result: - try: - rendered_prompt = render_template( - template_str=task_prompt, - available_tools=available_tools, - globals_dict=global_variables, - inputs_dict=inputs, - result_value=value, - ) - prompts_to_run.append(rendered_prompt) - except jinja2.TemplateError as e: - logging.error(f"Error rendering template for result 
{value}: {e}") - raise ValueError(f"Template rendering failed: {e}") - else: - prompts_to_run.append(task_prompt) + prompts_to_run: list[str] = await _build_prompts_to_run( + task_prompt, repeat_prompt, last_mcp_tool_results, + available_tools, global_variables, inputs, + ) async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) -> bool: if run: @@ -414,7 +516,7 @@ async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) - resolved_agents: dict[str, Any] = {} current_agents = list(agents_list) if not current_agents: - from .cli import parse_prompt_args + from .prompt_parser import parse_prompt_args p_val, _, _, _, p_prompt, _ = parse_prompt_args(available_tools, p_prompt) if p_val: current_agents.append(p_val) diff --git a/src/seclab_taskflow_agent/shell_utils.py b/src/seclab_taskflow_agent/shell_utils.py index cfe413a0..caeb7f26 100644 --- a/src/seclab_taskflow_agent/shell_utils.py +++ b/src/seclab_taskflow_agent/shell_utils.py @@ -8,17 +8,23 @@ from mcp.types import CallToolResult, TextContent -def shell_command_to_string(cmd): +def shell_command_to_string(cmd: list[str]) -> str: + """Execute a shell command and return its stdout. + + Raises: + RuntimeError: If the command exits with a non-zero return code. 
+ """ logging.info(f"Executing: {cmd}") p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") stdout, stderr = p.communicate() p.wait() if p.returncode: - raise RuntimeError(stderr) + raise RuntimeError(f"Command {cmd} failed: {stderr}") return stdout -def shell_exec_with_temporary_file(script, shell="bash"): +def shell_exec_with_temporary_file(script: str, shell: str = "bash") -> str: + """Write *script* to a temp file and execute it with the given shell.""" with tempfile.NamedTemporaryFile(mode="w+", delete=True) as temp_file: temp_file.write(script) temp_file.flush() @@ -26,7 +32,8 @@ def shell_exec_with_temporary_file(script, shell="bash"): return result -def shell_tool_call(run): +def shell_tool_call(run: str) -> CallToolResult: + """Execute a shell script and return the output as a CallToolResult.""" stdout = shell_exec_with_temporary_file(run) # this allows e.g. shell based jq output to become available for repeat prompts result = CallToolResult(content=[TextContent(type="text", text=stdout, annotations=None, meta=None)]) diff --git a/src/seclab_taskflow_agent/template_utils.py b/src/seclab_taskflow_agent/template_utils.py index 67652432..8de477b3 100644 --- a/src/seclab_taskflow_agent/template_utils.py +++ b/src/seclab_taskflow_agent/template_utils.py @@ -4,14 +4,18 @@ """Jinja2 template utilities for taskflow template rendering.""" import os +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + import jinja2 -from typing import Any, Dict, Optional + +if TYPE_CHECKING: + from .available_tools import AvailableTools class PromptLoader(jinja2.BaseLoader): """Custom Jinja2 loader for reusable prompts.""" - def __init__(self, available_tools): + def __init__(self, available_tools: "AvailableTools") -> None: """Initialize the prompt loader. 
Args: @@ -19,7 +23,9 @@ def __init__(self, available_tools): """ self.available_tools = available_tools - def get_source(self, environment, template): + def get_source( + self, environment: jinja2.Environment, template: str + ) -> tuple[str, str | None, Callable[[], bool]]: """Load prompt from available_tools by path. Args: @@ -69,7 +75,7 @@ def env_function(var_name: str, default: Optional[str] = None, required: bool = return value or "" -def create_jinja_environment(available_tools) -> jinja2.Environment: +def create_jinja_environment(available_tools: "AvailableTools") -> jinja2.Environment: """Create configured Jinja2 environment for taskflow templates. Args: @@ -102,7 +108,7 @@ def create_jinja_environment(available_tools) -> jinja2.Environment: def render_template( template_str: str, - available_tools, + available_tools: "AvailableTools", globals_dict: Optional[Dict[str, Any]] = None, inputs_dict: Optional[Dict[str, Any]] = None, result_value: Optional[Any] = None, From 8fa7fb4169f21e04898f657f18df90d8aaa23008 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 14:56:54 -0400 Subject: [PATCH 04/20] feat: add responses API support via model_config api_type Add api_type field to ModelConfigDocument (chat_completions|responses). Thread api_type through runner -> deploy_task_agents -> TaskAgent. TaskAgent switches between OpenAIChatCompletionsModel and OpenAIResponsesModel based on api_type. Default: chat_completions. 
--- src/seclab_taskflow_agent/__init__.py | 2 ++ src/seclab_taskflow_agent/agent.py | 20 +++++++++++++++----- src/seclab_taskflow_agent/models.py | 12 ++++++++++-- src/seclab_taskflow_agent/runner.py | 21 ++++++++++++++------- tests/test_models.py | 21 +++++++++++++++++++++ 5 files changed, 62 insertions(+), 14 deletions(-) diff --git a/src/seclab_taskflow_agent/__init__.py b/src/seclab_taskflow_agent/__init__.py index e10fb131..b87f30ce 100644 --- a/src/seclab_taskflow_agent/__init__.py +++ b/src/seclab_taskflow_agent/__init__.py @@ -23,6 +23,7 @@ """ __all__ = [ + "ApiType", "AvailableTools", "TaskAgent", "TaskRunHooks", @@ -38,6 +39,7 @@ from .agent import TaskAgent, TaskAgentHooks, TaskRunHooks from .available_tools import AvailableTools from .models import ( + ApiType, ModelConfigDocument, PersonalityDocument, PromptDocument, diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 67567216..3e9c09ad 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -12,6 +12,7 @@ Agent, AgentHooks, OpenAIChatCompletionsModel, + OpenAIResponsesModel, RunContextWrapper, RunHooks, Runner, @@ -149,19 +150,22 @@ def __init__( mcp_servers: list[Any] = [], model: str = DEFAULT_MODEL, model_settings: ModelSettings | None = None, + api_type: str = "chat_completions", run_hooks: TaskRunHooks | None = None, agent_hooks: TaskAgentHooks | None = None, ) -> None: - """Create a TaskAgent with the specified configuration.""" + """Create a TaskAgent with the specified configuration. + + Args: + api_type: OpenAI API type -- ``"chat_completions"`` or ``"responses"``. 
+ """ client = AsyncOpenAI( base_url=api_endpoint, api_key=get_AI_token(), default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID}, ) set_default_openai_client(client) - # CAPI does not yet support the Responses API: https://github.com/github/copilot-api/issues/11185 - # as such we are implementing on chat completions for now - set_default_openai_api("chat_completions") + set_default_openai_api(api_type) set_tracing_disabled(True) self.run_hooks = run_hooks or TaskRunHooks() # useful agent patterns: @@ -173,11 +177,17 @@ def _ToolsToFinalOutputFunction( ) -> ToolsToFinalOutputResult: return ToolsToFinalOutputResult(True, "Excluding tool results from LLM context") + # Select model class based on api_type + if api_type == "responses": + model_impl = OpenAIResponsesModel(model=model, openai_client=client) + else: + model_impl = OpenAIChatCompletionsModel(model=model, openai_client=client) + self.agent = Agent( name=name, instructions=instructions, tool_use_behavior=_ToolsToFinalOutputFunction if exclude_from_context else "run_llm_again", - model=OpenAIChatCompletionsModel(model=model, openai_client=client), + model=model_impl, handoffs=handoffs, mcp_servers=mcp_servers, model_settings=model_settings or ModelSettings(), diff --git a/src/seclab_taskflow_agent/models.py b/src/seclab_taskflow_agent/models.py index aaddc079..f5f658a8 100644 --- a/src/seclab_taskflow_agent/models.py +++ b/src/seclab_taskflow_agent/models.py @@ -10,10 +10,13 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +# Valid API type values for model configuration. 
+ApiType = Literal["chat_completions", "responses"] + # --------------------------------------------------------------------------- # Header @@ -174,11 +177,16 @@ class ToolboxDocument(BaseModel): class ModelConfigDocument(BaseModel): - """A model_config YAML document mapping logical model names to provider IDs.""" + """A model_config YAML document mapping logical model names to provider IDs. + + The ``api_type`` field controls which OpenAI API is used for all models + in this config: ``"chat_completions"`` (default) or ``"responses"``. + """ model_config = ConfigDict(extra="allow") header: TaskflowHeader = Field(alias="seclab-taskflow-agent") + api_type: ApiType = "chat_completions" models: dict[str, str] = Field(default_factory=dict) model_settings: dict[str, dict[str, Any]] = Field(default_factory=dict) diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index c85afbd9..65d8af92 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -45,7 +45,7 @@ def _resolve_model_config( available_tools: AvailableTools, model_config_ref: str, -) -> tuple[list[str], dict[str, str], dict[str, dict[str, Any]]]: +) -> tuple[list[str], dict[str, str], dict[str, dict[str, Any]], str]: """Load and validate the model configuration file. Args: @@ -53,9 +53,10 @@ def _resolve_model_config( model_config_ref: Reference name for the model config document. Returns: - A tuple of (model_keys, model_dict, models_params) where model_keys is - the list of logical model names, model_dict maps them to provider model - IDs, and models_params holds per-model settings. + A tuple of (model_keys, model_dict, models_params, api_type) where + model_keys is the list of logical model names, model_dict maps them + to provider model IDs, models_params holds per-model settings, and + api_type is ``"chat_completions"`` or ``"responses"``. Raises: ValueError: If the config file has structural problems. 
@@ -75,7 +76,7 @@ def _resolve_model_config( for k, v in models_params.items(): if not isinstance(v, dict): raise ValueError(f"Settings for model {k} in model_config file {model_config_ref} is not a dictionary") - return model_keys, model_dict, models_params + return model_keys, model_dict, models_params, m_config.api_type def _merge_reusable_task( @@ -224,6 +225,7 @@ async def deploy_task_agents( max_turns: int = DEFAULT_MAX_TURNS, model: str = DEFAULT_MODEL, model_par: dict[str, Any] | None = None, + api_type: str = "chat_completions", run_hooks: TaskRunHooks | None = None, agent_hooks: TaskAgentHooks | None = None, ) -> bool: @@ -231,8 +233,9 @@ async def deploy_task_agents( Args: available_tools: Tool registry. - agents: Mapping of agent name → PersonalityDocument. + agents: Mapping of agent name -> PersonalityDocument. prompt: User prompt to execute. + api_type: OpenAI API type -- ``"chat_completions"`` or ``"responses"``. Returns: True if the task completed successfully. @@ -308,6 +311,7 @@ async def deploy_task_agents( mcp_servers=[e.server for e in entries], model=model, model_settings=model_settings, + api_type=api_type, run_hooks=run_hooks, agent_hooks=agent_hooks, ).agent @@ -330,6 +334,7 @@ async def deploy_task_agents( mcp_servers=[e.server for e in entries], model=model, model_settings=model_settings, + api_type=api_type, run_hooks=run_hooks, agent_hooks=agent_hooks, ) @@ -447,8 +452,9 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo model_keys: list[str] = [] model_dict: dict[str, str] = {} models_params: dict[str, dict[str, Any]] = {} + api_type: str = "chat_completions" if model_config_ref: - model_keys, model_dict, models_params = _resolve_model_config(available_tools, model_config_ref) + model_keys, model_dict, models_params, api_type = _resolve_model_config(available_tools, model_config_ref) for task_wrapper in taskflow_doc.taskflow: task = task_wrapper.task @@ -543,6 +549,7 @@ async def _deploy(ra: dict, pp: 
str) -> bool: ), model=model, model_par=model_settings, + api_type=api_type, agent_hooks=TaskAgentHooks(on_handoff=on_handoff_hook), ) diff --git a/tests/test_models.py b/tests/test_models.py index 6db4cc46..048d8ca3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -199,6 +199,27 @@ def test_full_config(self): doc = ModelConfigDocument(**data) assert doc.models["gpt_default"] == "gpt-4.1" assert doc.model_settings["gpt_default"]["temperature"] == 0.7 + assert doc.api_type == "chat_completions" # default + + def test_api_type_responses(self): + """Test that api_type can be set to 'responses'.""" + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "model_config"}, + "api_type": "responses", + "models": {"o3": "o3"}, + } + doc = ModelConfigDocument(**data) + assert doc.api_type == "responses" + + def test_api_type_invalid(self): + """Test that invalid api_type values are rejected.""" + data = { + "seclab-taskflow-agent": {"version": "1.0", "filetype": "model_config"}, + "api_type": "invalid", + "models": {}, + } + with pytest.raises(ValidationError): + ModelConfigDocument(**data) class TestPromptDocument: From 7fe48e18f1ee09299ab82a84d33b4a7cac7c9418 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 15:15:12 -0400 Subject: [PATCH 05/20] feat: per-model api_type, endpoint, and token overrides model_settings in model_config can now set api_type, endpoint, and token per model. token is an env var name to resolve. Allows mixing chat_completions and responses API across models and routing to different endpoints within a single taskflow. 
--- examples/model_configs/responses_api.yaml | 17 +++++ examples/taskflows/echo_responses_api.yaml | 20 ++++++ examples/taskflows/example_responses_api.yaml | 19 ++++++ src/seclab_taskflow_agent/agent.py | 21 ++++-- src/seclab_taskflow_agent/runner.py | 67 +++++++++++++------ 5 files changed, 120 insertions(+), 24 deletions(-) create mode 100644 examples/model_configs/responses_api.yaml create mode 100644 examples/taskflows/echo_responses_api.yaml create mode 100644 examples/taskflows/example_responses_api.yaml diff --git a/examples/model_configs/responses_api.yaml b/examples/model_configs/responses_api.yaml new file mode 100644 index 00000000..4d47582b --- /dev/null +++ b/examples/model_configs/responses_api.yaml @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +# Example: per-model API type and endpoint configuration. +# gpt_responses uses the Responses API on the CAPI endpoint, +# reading its token from the CAPI_TOKEN env var. + +seclab-taskflow-agent: + version: "1.0" + filetype: model_config +models: + gpt_responses: gpt-5.1 +model_settings: + gpt_responses: + api_type: responses + endpoint: https://api.githubcopilot.com + token: CAPI_TOKEN diff --git a/examples/taskflows/echo_responses_api.yaml b/examples/taskflows/echo_responses_api.yaml new file mode 100644 index 00000000..6d9ef0c4 --- /dev/null +++ b/examples/taskflows/echo_responses_api.yaml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +# Echo taskflow using the Responses API with MCP tool calls. 
+ +seclab-taskflow-agent: + version: "1.0" + filetype: taskflow + +model_config: examples.model_configs.responses_api + +taskflow: + - task: + max_steps: 5 + must_complete: true + agents: + - examples.personalities.echo + model: gpt_responses + user_prompt: | + Hello from the Responses API diff --git a/examples/taskflows/example_responses_api.yaml b/examples/taskflows/example_responses_api.yaml new file mode 100644 index 00000000..598932ff --- /dev/null +++ b/examples/taskflows/example_responses_api.yaml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +# Example taskflow using the Responses API instead of Chat Completions. +# Requires an endpoint that supports the Responses API (e.g. api.githubcopilot.com). + +seclab-taskflow-agent: + version: "1.0" + filetype: taskflow + +model_config: examples.model_configs.responses_api + +taskflow: + - task: + agents: + - seclab_taskflow_agent.personalities.assistant + model: gpt_responses + user_prompt: | + What is the capital of France? Reply in one sentence. diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 3e9c09ad..65640448 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -151,25 +151,36 @@ def __init__( model: str = DEFAULT_MODEL, model_settings: ModelSettings | None = None, api_type: str = "chat_completions", + endpoint: str | None = None, + token: str | None = None, run_hooks: TaskRunHooks | None = None, agent_hooks: TaskAgentHooks | None = None, ) -> None: """Create a TaskAgent with the specified configuration. Args: - api_type: OpenAI API type -- ``"chat_completions"`` or ``"responses"``. + api_type: ``"chat_completions"`` or ``"responses"``. + endpoint: Optional API endpoint URL override for this model. + token: Optional env var name whose value is used as the API key. 
""" + # Resolve per-model endpoint and token, falling back to defaults + resolved_endpoint = endpoint or api_endpoint + if token: + resolved_token = os.getenv(token, "") + if not resolved_token: + raise RuntimeError(f"Token env var {token!r} is not set") + else: + resolved_token = get_AI_token() + client = AsyncOpenAI( - base_url=api_endpoint, - api_key=get_AI_token(), + base_url=resolved_endpoint, + api_key=resolved_token, default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID}, ) set_default_openai_client(client) set_default_openai_api(api_type) set_tracing_disabled(True) self.run_hooks = run_hooks or TaskRunHooks() - # useful agent patterns: - # openai/openai-agents-python/blob/main/examples/agent_patterns # when we want to exclude tool results from context, we receive results here instead of sending to LLM def _ToolsToFinalOutputFunction( diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 65d8af92..1c5e74fc 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -108,37 +108,54 @@ def _merge_reusable_task( return TaskDefinition.model_validate(merged) +# Keys in model_settings that are handled by the engine, not ModelSettings. +_ENGINE_SETTING_KEYS = {"api_type", "endpoint", "token"} + + def _resolve_task_model( task: TaskDefinition, model_keys: list[str], model_dict: dict[str, str], models_params: dict[str, dict[str, Any]], -) -> tuple[str, dict[str, Any]]: - """Resolve the final model name and settings for a task. - - Args: - task: The task definition containing optional model/model_settings. - model_keys: Logical model names from the model config. - model_dict: Mapping of logical model names to provider model IDs. - models_params: Per-model settings from the model config. + default_api_type: str = "chat_completions", +) -> tuple[str, dict[str, Any], str, str | None, str | None]: + """Resolve the final model name, settings, and per-model overrides. 
Returns: - A tuple of (resolved_model_name, merged_model_settings). + A tuple of ``(model_id, model_settings, api_type, endpoint, token)`` + where *endpoint* and *token* are ``None`` when not overridden. Raises: ValueError: If task-level model_settings is not a dictionary. """ - model: str = task.model or DEFAULT_MODEL + logical_name: str = task.model or DEFAULT_MODEL model_settings: dict[str, Any] = {} - if model in model_keys: - if model in models_params: - model_settings = models_params[model].copy() - model = model_dict[model] + api_type: str = default_api_type + endpoint: str | None = None + token: str | None = None + + if logical_name in model_keys: + if logical_name in models_params: + model_settings = models_params[logical_name].copy() + logical_name = model_dict[logical_name] + + # Extract engine-level keys before merging task settings + api_type = model_settings.pop("api_type", api_type) + endpoint = model_settings.pop("endpoint", None) + token = model_settings.pop("token", None) + task_model_settings: dict[str, Any] | Any = task.model_settings or {} if not isinstance(task_model_settings, dict): raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary") - model_settings.update(task_model_settings) - return model, model_settings + + # Task-level overrides can also set engine keys + task_settings = dict(task_model_settings) + api_type = task_settings.pop("api_type", api_type) + endpoint = task_settings.pop("endpoint", endpoint) + token = task_settings.pop("token", token) + + model_settings.update(task_settings) + return logical_name, model_settings, api_type, endpoint, token async def _build_prompts_to_run( @@ -226,6 +243,8 @@ async def deploy_task_agents( model: str = DEFAULT_MODEL, model_par: dict[str, Any] | None = None, api_type: str = "chat_completions", + endpoint: str | None = None, + token: str | None = None, run_hooks: TaskRunHooks | None = None, agent_hooks: TaskAgentHooks | None = None, ) -> bool: @@ -236,6 +255,8 @@ 
async def deploy_task_agents( agents: Mapping of agent name -> PersonalityDocument. prompt: User prompt to execute. api_type: OpenAI API type -- ``"chat_completions"`` or ``"responses"``. + endpoint: Optional per-model API endpoint URL override. + token: Optional env var name to resolve as the API token. Returns: True if the task completed successfully. @@ -312,6 +333,8 @@ async def deploy_task_agents( model=model, model_settings=model_settings, api_type=api_type, + endpoint=endpoint, + token=token, run_hooks=run_hooks, agent_hooks=agent_hooks, ).agent @@ -335,6 +358,8 @@ async def deploy_task_agents( model=model, model_settings=model_settings, api_type=api_type, + endpoint=endpoint, + token=token, run_hooks=run_hooks, agent_hooks=agent_hooks, ) @@ -463,8 +488,10 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo if task.uses: task = _merge_reusable_task(available_tools, task) - # Resolve model - model, model_settings = _resolve_task_model(task, model_keys, model_dict, models_params) + # Resolve model (name, settings, api_type, optional endpoint/token) + model, model_settings, task_api_type, task_endpoint, task_token = _resolve_task_model( + task, model_keys, model_dict, models_params, default_api_type=api_type, + ) # Read task fields via typed attributes agents_list = task.agents or [] @@ -549,7 +576,9 @@ async def _deploy(ra: dict, pp: str) -> bool: ), model=model, model_par=model_settings, - api_type=api_type, + api_type=task_api_type, + endpoint=task_endpoint, + token=task_token, agent_hooks=TaskAgentHooks(on_handoff=on_handoff_hook), ) From 54d5d2843bc955918b6b07ce0da79ecf2c04f9c8 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 15:25:32 -0400 Subject: [PATCH 06/20] docs: update README and GRAMMAR for responses API and per-model config Document api_type, endpoint, and token per-model overrides. Update architecture module listing in README. Add per-model settings reference table to GRAMMAR.md. 
--- README.md | 34 +++++++++++++++++++++++++++++++--- doc/GRAMMAR.md | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 448bef3d..6ada8981 100644 --- a/README.md +++ b/README.md @@ -44,14 +44,42 @@ You can find a detailed overview of the taskflow grammar [here](doc/GRAMMAR.md) └─────────────────────────────────────────────────────┘ Supporting modules: - models.py — Pydantic v2 grammar models (validation) + models.py — Pydantic v2 grammar models (validation) available_tools.py — YAML resource loader with caching template_utils.py — Jinja2 template environment - mcp_utils.py — MCP namespace wrapping, system prompts - capi.py — AI API endpoint management + mcp_utils.py — MCP client parameter resolution + mcp_transport.py — MCP transport implementations (stdio, streamable) + mcp_prompt.py — System prompt construction + prompt_parser.py — Legacy prompt argument parser + capi.py — AI API endpoint and token management path_utils.py — Platform-aware data/log directories ``` +### API Types + +The agent supports both the **Chat Completions** and **Responses** OpenAI APIs. 
+The API type can be configured globally or per model in a `model_config` file: + +```yaml +seclab-taskflow-agent: + version: "1.0" + filetype: model_config +api_type: chat_completions # default for all models +models: + gpt_default: gpt-4.1 + gpt_responses: gpt-5.1 +model_settings: + gpt_responses: + api_type: responses # override for this model + endpoint: https://api.githubcopilot.com + token: CAPI_TOKEN # env var name containing the API key +``` + +Per-model `model_settings` can include: +- **`api_type`** — `"chat_completions"` (default) or `"responses"` +- **`endpoint`** — API base URL override for this model +- **`token`** — name of an environment variable containing the API key + ## Use Cases and Examples The Seclab Taskflow Agent framework was primarily designed to fit the iterative feedback loop driven work involved in Agentic security research workflows and vulnerability triage tasks. diff --git a/doc/GRAMMAR.md b/doc/GRAMMAR.md index 1ebfaf5e..86c1ac62 100644 --- a/doc/GRAMMAR.md +++ b/doc/GRAMMAR.md @@ -509,4 +509,39 @@ When `gpt_latest` is used in the taskflow to specify a model, the value `gpt-5` ``` -This provides a easy way to update model versions in a taskflow. +This provides an easy way to update model versions in a taskflow. 
+ +#### Per-model settings + +A `model_config` file can include per-model settings via `model_settings` and a +global `api_type` that applies to all models unless overridden: + +```yaml +seclab-taskflow-agent: + version: "1.0" + filetype: model_config +api_type: chat_completions # default for all models +models: + gpt_default: gpt-4.1 + gpt_responses: gpt-5.1 +model_settings: + gpt_default: + temperature: 0.7 + gpt_responses: + api_type: responses # use the Responses API for this model + endpoint: https://api.githubcopilot.com + token: CAPI_TOKEN # env var name containing the API key + temperature: 0.5 +``` + +The following keys in `model_settings` are handled by the engine and are not +passed to the underlying model provider: + +| Key | Description | Default | +|-----|-------------|---------| +| `api_type` | `"chat_completions"` or `"responses"` | Inherited from top-level `api_type`, or `"chat_completions"` | +| `endpoint` | API base URL for this model | The global `AI_API_ENDPOINT` env var | +| `token` | Name of an environment variable containing the API key | Uses `AI_API_TOKEN` / `COPILOT_TOKEN` | + +All other keys (e.g. `temperature`, `top_p`) are passed through as model +parameters to the OpenAI SDK. 
From 6d5f19e5fbbaa891559d376bc27cdcbcd5d8e377 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 15:37:49 -0400 Subject: [PATCH 07/20] refactor: fix mutable defaults, add __all__ exports and docstrings - Fix mutable default args in TaskAgent.__init__ (list -> None) - Replace bare except Exception with specific exception types - Add __all__ exports to all 14 submodules - Add docstrings to all hook methods in agent.py - Add return type to banner.py and prompt_parser.py - Add module docstrings to capi.py, env_utils.py, shell_utils.py - Remove B006 ruff suppression (no longer needed) --- pyproject.toml | 1 - src/seclab_taskflow_agent/agent.py | 24 ++++++++++++++++---- src/seclab_taskflow_agent/available_tools.py | 2 ++ src/seclab_taskflow_agent/banner.py | 8 ++++++- src/seclab_taskflow_agent/capi.py | 13 ++++++++++- src/seclab_taskflow_agent/cli.py | 2 ++ src/seclab_taskflow_agent/env_utils.py | 4 ++++ src/seclab_taskflow_agent/mcp_lifecycle.py | 2 ++ src/seclab_taskflow_agent/mcp_prompt.py | 2 ++ src/seclab_taskflow_agent/mcp_transport.py | 6 +++++ src/seclab_taskflow_agent/mcp_utils.py | 8 +++++++ src/seclab_taskflow_agent/models.py | 15 ++++++++++++ src/seclab_taskflow_agent/prompt_parser.py | 6 ++++- src/seclab_taskflow_agent/render_utils.py | 2 ++ src/seclab_taskflow_agent/runner.py | 9 ++++++++ src/seclab_taskflow_agent/shell_utils.py | 4 ++++ src/seclab_taskflow_agent/template_utils.py | 8 ++++++- 17 files changed, 107 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 894503de..bb75193f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,7 +178,6 @@ ignore = [ # Backwards-compatibility suppressions for existing code "A001", # Variable shadows built-in (existing API names) "A002", # Argument shadows built-in (existing API signatures) - "B006", # Mutable default argument (existing signatures, would break API) "FBT001", # Boolean positional arg (existing API) "FBT002", # Boolean default value (existing API) "N802", 
# Function name casing (existing API: get_AI_endpoint etc.) diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 65640448..d38e7cf2 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -30,6 +30,13 @@ from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token +__all__ = [ + "DEFAULT_MODEL", + "TaskAgent", + "TaskAgentHooks", + "TaskRunHooks", +] + # grab our secrets from .env, this must be in .gitignore load_dotenv(find_dotenv(usecwd=True)) @@ -64,16 +71,19 @@ def __init__( self._on_tool_end = on_tool_end async def on_agent_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None: + """Called when an agent begins execution.""" logging.debug(f"TaskRunHooks on_agent_start: {agent.name}") if self._on_agent_start: await self._on_agent_start(context, agent) async def on_agent_end(self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any) -> None: + """Called when an agent finishes execution.""" logging.debug(f"TaskRunHooks on_agent_end: {agent.name}") if self._on_agent_end: await self._on_agent_end(context, agent, output) async def on_tool_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: + """Called before a tool invocation begins.""" logging.debug(f"TaskRunHooks on_tool_start: {tool.name}") if self._on_tool_start: await self._on_tool_start(context, agent, tool) @@ -81,6 +91,7 @@ async def on_tool_start(self, context: RunContextWrapper[TContext], agent: Agent async def on_tool_end( self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str ) -> None: + """Called after a tool invocation completes.""" logging.debug(f"TaskRunHooks on_tool_end: {tool.name} ") if self._on_tool_end: await self._on_tool_end(context, agent, tool, result) @@ -107,21 +118,25 @@ def __init__( async def on_handoff( self, context: RunContextWrapper[TContext], agent: 
Agent[TContext], source: Agent[TContext] ) -> None: + """Called when control is handed off from one agent to another.""" logging.debug(f"TaskAgentHooks on_handoff: {source.name} -> {agent.name}") if self._on_handoff: await self._on_handoff(context, agent, source) async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None: + """Called when the agent starts processing.""" logging.debug(f"TaskAgentHooks on_start: {agent.name}") if self._on_start: await self._on_start(context, agent) async def on_end(self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any) -> None: + """Called when the agent finishes processing.""" logging.debug(f"TaskAgentHooks on_end: {agent.name}") if self._on_end: await self._on_end(context, agent, output) async def on_tool_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool) -> None: + """Called before a tool invocation begins.""" logging.debug(f"TaskAgentHooks on_tool_start: {tool.name}") if self._on_tool_start: await self._on_tool_start(context, agent, tool) @@ -129,6 +144,7 @@ async def on_tool_start(self, context: RunContextWrapper[TContext], agent: Agent async def on_tool_end( self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str ) -> None: + """Called after a tool invocation completes.""" logging.debug(f"TaskAgentHooks on_tool_end: {tool.name}") if self._on_tool_end: await self._on_tool_end(context, agent, tool, result) @@ -145,9 +161,9 @@ def __init__( self, name: str = "TaskAgent", instructions: str = "", - handoffs: list[Any] = [], + handoffs: list[Any] | None = None, exclude_from_context: bool = False, - mcp_servers: list[Any] = [], + mcp_servers: list[Any] | None = None, model: str = DEFAULT_MODEL, model_settings: ModelSettings | None = None, api_type: str = "chat_completions", @@ -199,8 +215,8 @@ def _ToolsToFinalOutputFunction( instructions=instructions, tool_use_behavior=_ToolsToFinalOutputFunction if 
exclude_from_context else "run_llm_again", model=model_impl, - handoffs=handoffs, - mcp_servers=mcp_servers, + handoffs=handoffs or [], + mcp_servers=mcp_servers or [], model_settings=model_settings or ModelSettings(), hooks=agent_hooks or TaskAgentHooks(), ) diff --git a/src/seclab_taskflow_agent/available_tools.py b/src/seclab_taskflow_agent/available_tools.py index 3966c756..812599a6 100644 --- a/src/seclab_taskflow_agent/available_tools.py +++ b/src/seclab_taskflow_agent/available_tools.py @@ -9,6 +9,8 @@ from __future__ import annotations +__all__ = ["AvailableTools"] + import importlib.resources from enum import Enum from typing import Union diff --git a/src/seclab_taskflow_agent/banner.py b/src/seclab_taskflow_agent/banner.py index cc4150a1..35bf0dd4 100644 --- a/src/seclab_taskflow_agent/banner.py +++ b/src/seclab_taskflow_agent/banner.py @@ -1,9 +1,15 @@ # SPDX-FileCopyrightText: GitHub, Inc. # SPDX-License-Identifier: MIT +"""ASCII banner displayed at agent startup.""" + from .capi import get_AI_endpoint -def get_banner(): +__all__ = ["get_banner"] + + +def get_banner() -> str: + """Return the ASCII art startup banner with the configured endpoint.""" api_endpoint = get_AI_endpoint() banner = f""" ╔══════════════════════════════════════════════════════════════════╗ diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index a43333d1..0461a748 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -1,7 +1,8 @@ # SPDX-FileCopyrightText: GitHub, Inc. 
# SPDX-License-Identifier: MIT -# CAPI specific interactions +"""AI API endpoint and token management (CAPI integration).""" + import json import logging import os @@ -10,6 +11,16 @@ import httpx from strenum import StrEnum +__all__ = [ + "AI_API_ENDPOINT_ENUM", + "COPILOT_INTEGRATION_ID", + "get_AI_endpoint", + "get_AI_token", + "list_capi_models", + "list_tool_call_models", + "supports_tool_calls", +] + # Enumeration of currently supported API endpoints. class AI_API_ENDPOINT_ENUM(StrEnum): diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 31cf626e..96d59447 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -10,6 +10,8 @@ from __future__ import annotations +__all__ = ["app", "main"] + import asyncio import logging import os diff --git a/src/seclab_taskflow_agent/env_utils.py b/src/seclab_taskflow_agent/env_utils.py index 597d9ce2..3ccfd77c 100644 --- a/src/seclab_taskflow_agent/env_utils.py +++ b/src/seclab_taskflow_agent/env_utils.py @@ -1,11 +1,15 @@ # SPDX-FileCopyrightText: GitHub, Inc. 
# SPDX-License-Identifier: MIT +"""Environment variable utilities for taskflow execution.""" + import os from typing import Any import jinja2 +__all__ = ["TmpEnv", "swap_env"] + def swap_env(s: str) -> str: diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index 365a85aa..bee2f9c9 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -9,6 +9,8 @@ from __future__ import annotations +__all__ = ["MCP_CLEANUP_TIMEOUT", "build_mcp_servers", "mcp_session_task"] + import asyncio import logging from typing import TYPE_CHECKING diff --git a/src/seclab_taskflow_agent/mcp_prompt.py b/src/seclab_taskflow_agent/mcp_prompt.py index 64b5b3ab..f17332b8 100644 --- a/src/seclab_taskflow_agent/mcp_prompt.py +++ b/src/seclab_taskflow_agent/mcp_prompt.py @@ -9,6 +9,8 @@ from __future__ import annotations +__all__ = ["mcp_system_prompt"] + def mcp_system_prompt( system_prompt: str, diff --git a/src/seclab_taskflow_agent/mcp_transport.py b/src/seclab_taskflow_agent/mcp_transport.py index 22ea1abe..f2ce5165 100644 --- a/src/seclab_taskflow_agent/mcp_transport.py +++ b/src/seclab_taskflow_agent/mcp_transport.py @@ -14,6 +14,12 @@ from __future__ import annotations +__all__ = [ + "AsyncDebugMCPServerStdio", + "ReconnectingMCPServerStdio", + "StreamableMCPThread", +] + import asyncio import logging import os diff --git a/src/seclab_taskflow_agent/mcp_utils.py b/src/seclab_taskflow_agent/mcp_utils.py index 5fa3bed7..a186bee2 100644 --- a/src/seclab_taskflow_agent/mcp_utils.py +++ b/src/seclab_taskflow_agent/mcp_utils.py @@ -9,6 +9,14 @@ from __future__ import annotations +__all__ = [ + "COMPRESSED_NAME_LENGTH", + "DEFAULT_MCP_CLIENT_SESSION_TIMEOUT", + "MCPNamespaceWrap", + "compress_name", + "mcp_client_params", +] + import hashlib import json import logging diff --git a/src/seclab_taskflow_agent/models.py b/src/seclab_taskflow_agent/models.py index f5f658a8..39affc73 100644 --- 
a/src/seclab_taskflow_agent/models.py +++ b/src/seclab_taskflow_agent/models.py @@ -10,6 +10,21 @@ from __future__ import annotations +__all__ = [ + "ApiType", + "DOCUMENT_MODELS", + "ModelConfigDocument", + "PersonalityDocument", + "PromptDocument", + "SUPPORTED_VERSION", + "ServerParams", + "TaskDefinition", + "TaskWrapper", + "TaskflowDocument", + "TaskflowHeader", + "ToolboxDocument", +] + from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator diff --git a/src/seclab_taskflow_agent/prompt_parser.py b/src/seclab_taskflow_agent/prompt_parser.py index c24c48c5..332c93c0 100644 --- a/src/seclab_taskflow_agent/prompt_parser.py +++ b/src/seclab_taskflow_agent/prompt_parser.py @@ -17,8 +17,12 @@ from .available_tools import AvailableTools +__all__ = ["parse_prompt_args"] -def parse_prompt_args(available_tools: AvailableTools, user_prompt: str | None = None): + +def parse_prompt_args( + available_tools: AvailableTools, user_prompt: str | None = None +) -> tuple[str | None, str | None, bool, dict[str, str], str, str] | tuple[None, None, None, None, str]: """Legacy CLI parser kept for backwards compatibility with tests. 
Returns: diff --git a/src/seclab_taskflow_agent/render_utils.py b/src/seclab_taskflow_agent/render_utils.py index f5fad8a7..7e91144f 100644 --- a/src/seclab_taskflow_agent/render_utils.py +++ b/src/seclab_taskflow_agent/render_utils.py @@ -8,6 +8,8 @@ from .path_utils import log_file_name +__all__ = ["flush_async_output", "render_model_output"] + async_output = {} async_output_lock = asyncio.Lock() diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 1c5e74fc..39281fec 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -10,6 +10,15 @@ from __future__ import annotations +__all__ = [ + "DEFAULT_MAX_TURNS", + "MAX_API_RETRY", + "MAX_RATE_LIMIT_BACKOFF", + "RATE_LIMIT_BACKOFF", + "deploy_task_agents", + "run_main", +] + import asyncio import json import logging diff --git a/src/seclab_taskflow_agent/shell_utils.py b/src/seclab_taskflow_agent/shell_utils.py index caeb7f26..75175eca 100644 --- a/src/seclab_taskflow_agent/shell_utils.py +++ b/src/seclab_taskflow_agent/shell_utils.py @@ -1,12 +1,16 @@ # SPDX-FileCopyrightText: GitHub, Inc. # SPDX-License-Identifier: MIT +"""Shell command execution utilities.""" + import logging import subprocess import tempfile from mcp.types import CallToolResult, TextContent +__all__ = ["shell_command_to_string", "shell_exec_with_temporary_file", "shell_tool_call"] + def shell_command_to_string(cmd: list[str]) -> str: """Execute a shell command and return its stdout. 
diff --git a/src/seclab_taskflow_agent/template_utils.py b/src/seclab_taskflow_agent/template_utils.py index 8de477b3..2f21d4a6 100644 --- a/src/seclab_taskflow_agent/template_utils.py +++ b/src/seclab_taskflow_agent/template_utils.py @@ -8,9 +8,13 @@ import jinja2 +__all__ = ["PromptLoader", "create_jinja_environment", "env_function", "render_template"] + if TYPE_CHECKING: from .available_tools import AvailableTools +from .available_tools import BadToolNameError + class PromptLoader(jinja2.BaseLoader): """Custom Jinja2 loader for reusable prompts.""" @@ -46,7 +50,9 @@ def get_source( source = prompt_data.prompt or "" # Return: (source, filename, uptodate_func) return source, None, lambda: True - except Exception: + except jinja2.TemplateNotFound: + raise + except (BadToolNameError, KeyError, AttributeError, FileNotFoundError): raise jinja2.TemplateNotFound(template) From 85d9e14816204064d7d326c49a2c716c136ddae8 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 15:38:43 -0400 Subject: [PATCH 08/20] fix: lint errors in mcp_servers (missing import, bare except, type comparisons) --- src/seclab_taskflow_agent/mcp_servers/codeql/client.py | 1 + .../mcp_servers/codeql/jsonrpyc/__init__.py | 2 +- .../mcp_servers/memcache/memcache_backend/dictionary_file.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py index 59cd4859..af7d03d5 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py @@ -3,6 +3,7 @@ # a query-server2 codeql client import json +import logging import os import re import subprocess diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py b/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py index 3dfc5932..8d14d96b 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py +++ 
b/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py @@ -112,7 +112,7 @@ def check_code(cls, code: int, /) -> None: """ try: get_error(code) - except: + except Exception: raise TypeError(f"invalid error code, got {code} ({type(code)})") @classmethod diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py index bb1a51ad..72407afb 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py @@ -84,10 +84,10 @@ def add_state(self, key, value): @self.with_memory def _add_state(key: str, value: Any) -> str: existing = self.memcache.get(key) - if type(existing) == type(value) and hasattr(existing, "__add__"): + if type(existing) is type(value) and hasattr(existing, "__add__"): self.memcache[key] = existing + value return f"Updated and added to value in memory for key: `{key}`" - if type(existing) == list: + if isinstance(existing, list): self.memcache[key].append(value) return f"Updated and added to value in memory for key: `{key}`" return f"Error: unsupported types for memcache add `{type(existing)} + {type(value)}` for key `{key}`" From 476967e1b0df1eb1c72ffc71e84a9cce5e165ed3 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 15:46:13 -0400 Subject: [PATCH 09/20] feat: concise error messages, full tracebacks only with --debug Add --debug/-d flag and TASK_AGENT_DEBUG env var. Default: one-line error chain. Debug: full traceback. 
--- src/seclab_taskflow_agent/cli.py | 44 +++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 96d59447..4603557d 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -15,6 +15,7 @@ import asyncio import logging import os +import traceback from typing import Annotated import typer @@ -61,6 +62,24 @@ def _setup_logging() -> None: root.addHandler(console_handler) +def _print_concise_error(exc: BaseException) -> None: + """Print a concise error chain without full tracebacks. + + Walks the exception cause chain and prints each error on a single + line. Use ``--debug`` or ``TASK_AGENT_DEBUG=1`` for full tracebacks. + """ + import typer + + seen: set[int] = set() + current: BaseException | None = exc + while current and id(current) not in seen: + seen.add(id(current)) + label = type(current).__qualname__ + typer.echo(f"Error: [{label}] {current}", err=True) + current = current.__cause__ or current.__context__ + typer.echo("(use --debug for full traceback)", err=True) + + @app.command() def main( personality: Annotated[ @@ -79,12 +98,19 @@ def main( list[str] | None, typer.Option("-g", "--global", help="Global variable as KEY=VALUE. 
Repeatable."), ] = None, + debug: Annotated[ + bool, + typer.Option("-d", "--debug", help="Show full tracebacks on errors."), + ] = False, prompt: Annotated[ list[str] | None, typer.Argument(help="Remaining prompt text."), ] = None, ) -> None: """Run a taskflow or personality-based agent session.""" + # Debug mode from flag or env var + debug = debug or bool(os.getenv("TASK_AGENT_DEBUG")) + # Validate mutual exclusivity specified = sum(bool(x) for x in [personality, taskflow, list_models]) if specified > 1: @@ -118,10 +144,20 @@ def main( from .runner import run_main - asyncio.run( - run_main(available_tools, personality, taskflow, cli_globals, user_prompt), - debug=os.getenv("TASK_AGENT_LOGLEVEL", "").upper() == "DEBUG", - ) + try: + asyncio.run( + run_main(available_tools, personality, taskflow, cli_globals, user_prompt), + debug=debug, + ) + except KeyboardInterrupt: + typer.echo("\nInterrupted.", err=True) + raise typer.Exit(code=130) + except Exception as exc: + if debug: + traceback.print_exc() + else: + _print_concise_error(exc) + raise typer.Exit(code=1) # --------------------------------------------------------------------------- From 30c9694266150ea0c5e16666461b7b81e2931f2a Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 15:58:52 -0400 Subject: [PATCH 10/20] feat: session checkpoint/resume with auto-retry Add TaskflowSession model for task-level checkpointing. Save progress after each task, auto-retry failed tasks 3x. New --resume flag to continue from last checkpoint. Includes 9 tests for session persistence. 
--- src/seclab_taskflow_agent/__init__.py | 1 + src/seclab_taskflow_agent/cli.py | 22 +++- src/seclab_taskflow_agent/path_utils.py | 24 ++++- src/seclab_taskflow_agent/runner.py | 99 ++++++++++++++++- src/seclab_taskflow_agent/session.py | 137 ++++++++++++++++++++++++ tests/test_session.py | 94 ++++++++++++++++ 6 files changed, 367 insertions(+), 10 deletions(-) create mode 100644 src/seclab_taskflow_agent/session.py create mode 100644 tests/test_session.py diff --git a/src/seclab_taskflow_agent/__init__.py b/src/seclab_taskflow_agent/__init__.py index b87f30ce..fe6f50a3 100644 --- a/src/seclab_taskflow_agent/__init__.py +++ b/src/seclab_taskflow_agent/__init__.py @@ -18,6 +18,7 @@ - :mod:`~seclab_taskflow_agent.mcp_utils` — MCP client parameter resolution - :mod:`~seclab_taskflow_agent.mcp_transport` — MCP transport implementations - :mod:`~seclab_taskflow_agent.mcp_prompt` — System prompt construction +- :mod:`~seclab_taskflow_agent.session` — Taskflow checkpoint / resume - :mod:`~seclab_taskflow_agent.template_utils` — Jinja2 template rendering - :mod:`~seclab_taskflow_agent.prompt_parser` — Legacy prompt argument parser """ diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 4603557d..3e11cf75 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -102,6 +102,10 @@ def main( bool, typer.Option("-d", "--debug", help="Show full tracebacks on errors."), ] = False, + resume: Annotated[ + str | None, + typer.Option("--resume", help="Resume a previous session by its ID."), + ] = None, prompt: Annotated[ list[str] | None, typer.Argument(help="Remaining prompt text."), @@ -111,7 +115,11 @@ def main( # Debug mode from flag or env var debug = debug or bool(os.getenv("TASK_AGENT_DEBUG")) - # Validate mutual exclusivity + # Validate mutual exclusivity (resume is standalone) + if resume and (personality or taskflow or list_models): + typer.echo("Error: --resume cannot be combined with -p, -t, or -l.", 
err=True) + raise typer.Exit(code=1) + specified = sum(bool(x) for x in [personality, taskflow, list_models]) if specified > 1: typer.echo("Error: -p, -t, and -l are mutually exclusive.", err=True) @@ -128,8 +136,8 @@ def main( typer.echo(model) raise typer.Exit() - if personality is None and taskflow is None: - typer.echo("Error: one of -p or -t is required.", err=True) + if personality is None and taskflow is None and resume is None: + typer.echo("Error: one of -p, -t, or --resume is required.", err=True) raise typer.Exit(code=1) # Parse global variables @@ -144,9 +152,15 @@ def main( from .runner import run_main + # When resuming, the session carries taskflow_path/globals/prompt + effective_taskflow = taskflow if not resume else None + try: asyncio.run( - run_main(available_tools, personality, taskflow, cli_globals, user_prompt), + run_main( + available_tools, personality, effective_taskflow, + cli_globals, user_prompt, resume_session_id=resume, + ), debug=debug, ) except KeyboardInterrupt: diff --git a/src/seclab_taskflow_agent/path_utils.py b/src/seclab_taskflow_agent/path_utils.py index aa069d23..5d7d8414 100644 --- a/src/seclab_taskflow_agent/path_utils.py +++ b/src/seclab_taskflow_agent/path_utils.py @@ -1,11 +1,31 @@ # SPDX-FileCopyrightText: GitHub, Inc. 
# SPDX-License-Identifier: MIT +"""Platform-aware data and log directory resolution.""" + import os from pathlib import Path import platformdirs +__all__ = [ + "log_dir", + "log_file", + "log_file_name", + "mcp_data_dir", +] + + +def _data_dir() -> Path: + """Return the top-level application data directory (created if needed).""" + return Path( + platformdirs.user_data_dir( + appname="seclab-taskflow-agent", + appauthor="GitHubSecurityLab", + ensure_exists=True, + ) + ) + def mcp_data_dir(packagename: str, mcpname: str, env_override: str | None) -> Path: """ @@ -23,10 +43,8 @@ def mcp_data_dir(packagename: str, mcpname: str, env_override: str | None) -> Pa p = os.getenv(env_override) if p: return Path(p) - # Use [platformdirs](https://pypi.org/project/platformdirs/) to - # choose an appropriate location. - d = platformdirs.user_data_dir(appname="seclab-taskflow-agent", appauthor="GitHubSecurityLab", ensure_exists=True) # Each MCP server gets its own sub-directory + d = _data_dir() p = Path(d).joinpath(packagename).joinpath(mcpname) p.mkdir(parents=True, exist_ok=True) return p diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 39281fec..078cfc1d 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -49,6 +49,8 @@ RATE_LIMIT_BACKOFF = 5 # Initial backoff in seconds after a rate-limit response MAX_RATE_LIMIT_BACKOFF = 120 # Maximum backoff cap in seconds for rate-limit retries MAX_API_RETRY = 5 # Maximum number of consecutive API error retries +TASK_RETRY_LIMIT = 3 # Maximum retry attempts for a failed task +TASK_RETRY_BACKOFF = 10 # Initial backoff in seconds between task retries def _resolve_model_config( @@ -442,6 +444,7 @@ async def run_main( taskflow_path: str | None, cli_globals: dict[str, str], prompt: str | None, + resume_session_id: str | None = None, ) -> None: """Main entry point for taskflow/personality execution. 
@@ -451,7 +454,10 @@ async def run_main( taskflow_path: Taskflow module path, or None. cli_globals: Global variables from CLI. prompt: User prompt text. + resume_session_id: Session ID to resume from a checkpoint. """ + from .session import TaskflowSession + last_mcp_tool_results: list[str] = [] async def on_tool_end_hook(context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool, result: str) -> None: @@ -472,7 +478,22 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo run_hooks=TaskRunHooks(on_tool_end=on_tool_end_hook, on_tool_start=on_tool_start_hook), ) - if taskflow_path: + if taskflow_path or resume_session_id: + # Handle session resume + session: TaskflowSession | None = None + if resume_session_id: + session = TaskflowSession.load(resume_session_id) + if session.finished: + await render_model_output(f"** 🤖✅ Session {resume_session_id} already completed\n") + return + taskflow_path = session.taskflow_path + cli_globals = session.cli_globals + prompt = session.prompt + last_mcp_tool_results = list(session.last_tool_results) + await render_model_output( + f"** 🤖🔄 Resuming session {resume_session_id} from task {session.next_task_index}\n" + ) + taskflow_doc = available_tools.get_taskflow(taskflow_path) await render_model_output(f"** 🤖💪 Running Task Flow: {taskflow_path}\n") @@ -490,7 +511,25 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo if model_config_ref: model_keys, model_dict, models_params, api_type = _resolve_model_config(available_tools, model_config_ref) - for task_wrapper in taskflow_doc.taskflow: + # Create session if this is a new run (not personality mode) + if session is None: + session = TaskflowSession( + taskflow_path=taskflow_path, + cli_globals=cli_globals, + prompt=prompt or "", + total_tasks=len(taskflow_doc.taskflow), + ) + session.save() + await render_model_output(f"** 🤖📋 Session: {session.session_id}\n") + + for task_index, task_wrapper in 
enumerate(taskflow_doc.taskflow): + # Skip already-completed tasks on resume + if task_index < session.next_task_index: + await render_model_output( + f"** 🤖⏭️ Skipping completed task {task_index}\n" + ) + continue + task = task_wrapper.task # Reusable taskflow support: merge parent defaults into current task @@ -610,9 +649,63 @@ async def _deploy(ra: dict, pp: str) -> bool: complete = result and complete return complete - task_complete = await run_prompts(async_task=async_task, max_concurrent_tasks=max_concurrent_tasks) + # Execute the task with auto-retry on failure + task_name = task.name or f"task-{task_index}" + task_complete = False + last_task_error: BaseException | None = None + + for attempt in range(TASK_RETRY_LIMIT): + try: + task_complete = await run_prompts( + async_task=async_task, + max_concurrent_tasks=max_concurrent_tasks, + ) + last_task_error = None + break + except (KeyboardInterrupt, SystemExit): + raise + except Exception as exc: + last_task_error = exc + remaining = TASK_RETRY_LIMIT - attempt - 1 + if remaining > 0: + backoff = TASK_RETRY_BACKOFF * (attempt + 1) + await render_model_output( + f"** 🤖🔄 Task {task_name!r} failed: {exc}\n" + f"** 🤖🔄 Retrying in {backoff}s ({remaining} attempts left)\n" + ) + logging.warning(f"Task {task_name!r} attempt {attempt + 1} failed: {exc}") + await asyncio.sleep(backoff) + else: + logging.error(f"Task {task_name!r} failed after {TASK_RETRY_LIMIT} attempts: {exc}") + + # If all retries exhausted with an exception, save and re-raise + if last_task_error is not None: + session.mark_failed(f"Task {task_name!r}: {last_task_error}") + await render_model_output( + f"** 🤖💾 Session saved: {session.session_id}\n" + f"** 🤖💡 Resume with: --resume {session.session_id}\n" + ) + raise last_task_error + + # Checkpoint after successful task + session.record_task( + index=task_index, + name=task_name, + success=task_complete, + tool_results=list(last_mcp_tool_results), + ) if must_complete and not task_complete: 
logging.critical("Required task not completed ... aborting!") await render_model_output("🤖💥 *Required task not completed ...\n") + session.mark_failed(f"Required task {task_name!r} did not complete") + await render_model_output( + f"** 🤖💾 Session saved: {session.session_id}\n" + f"** 🤖💡 Resume with: --resume {session.session_id}\n" + ) break + + # All tasks completed successfully + if session is not None and not session.error: + session.mark_finished() + await render_model_output(f"** 🤖✅ Session {session.session_id} completed\n") diff --git a/src/seclab_taskflow_agent/session.py b/src/seclab_taskflow_agent/session.py new file mode 100644 index 00000000..cb8d0f8b --- /dev/null +++ b/src/seclab_taskflow_agent/session.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Taskflow session persistence for checkpoint/resume. + +Tracks task-level progress through a taskflow so that execution can be +resumed from the last successful checkpoint after an unrecoverable failure. + +Session files are stored as JSON in the platformdirs data directory. +""" + +from __future__ import annotations + +__all__ = [ + "TaskflowSession", + "session_dir", +] + +import logging +import uuid +from datetime import datetime, timezone +from pathlib import Path + +from pydantic import BaseModel, Field + +from .path_utils import _data_dir + + +def session_dir() -> Path: + """Return (and create) the directory used for session checkpoint files.""" + d = _data_dir() / "sessions" + d.mkdir(parents=True, exist_ok=True) + return d + + +class CompletedTask(BaseModel): + """Record of a single completed task within a session.""" + + index: int + name: str = "" + result: bool = False + tool_results: list[str] = Field(default_factory=list) + + +class TaskflowSession(BaseModel): + """Persistent session state for a taskflow run. 
+ + After each task completes the session is saved to disk so that a + subsequent ``--resume`` invocation can skip already-completed tasks. + """ + + session_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12]) + taskflow_path: str = "" + cli_globals: dict[str, str] = Field(default_factory=dict) + prompt: str = "" + created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + updated_at: str = "" + completed_tasks: list[CompletedTask] = Field(default_factory=list) + total_tasks: int = 0 + finished: bool = False + error: str = "" + + # Accumulated tool results carried across tasks (used by repeat_prompt) + last_tool_results: list[str] = Field(default_factory=list) + + @property + def next_task_index(self) -> int: + """Index of the next task to execute.""" + if not self.completed_tasks: + return 0 + return max(t.index for t in self.completed_tasks) + 1 + + @property + def file_path(self) -> Path: + """Path to this session's checkpoint file.""" + return session_dir() / f"{self.session_id}.json" + + def save(self) -> Path: + """Persist session state to disk, returns the file path.""" + self.updated_at = datetime.now(timezone.utc).isoformat() + path = self.file_path + path.write_text(self.model_dump_json(indent=2)) + logging.debug(f"Session checkpoint saved: {path}") + return path + + def record_task( + self, + index: int, + name: str, + success: bool, + tool_results: list[str] | None = None, + ) -> None: + """Record a completed task and save the checkpoint.""" + self.completed_tasks.append( + CompletedTask( + index=index, + name=name, + result=success, + tool_results=tool_results or [], + ) + ) + if tool_results: + self.last_tool_results = list(tool_results) + self.save() + + def mark_finished(self) -> None: + """Mark the session as fully completed and save.""" + self.finished = True + self.save() + + def mark_failed(self, error: str) -> None: + """Mark the session as failed with an error message and save.""" + self.error = error 
+ self.save() + + @classmethod + def load(cls, session_id: str) -> TaskflowSession: + """Load a session from disk by its ID. + + Raises: + FileNotFoundError: If no checkpoint file exists for the ID. + """ + path = session_dir() / f"{session_id}.json" + if not path.exists(): + raise FileNotFoundError(f"No session checkpoint found: {session_id}") + return cls.model_validate_json(path.read_text()) + + @classmethod + def list_sessions(cls) -> list[TaskflowSession]: + """List all saved sessions, most recent first.""" + sessions: list[TaskflowSession] = [] + for f in sorted(session_dir().glob("*.json"), reverse=True): + try: + sessions.append(cls.model_validate_json(f.read_text())) + except Exception: + logging.warning(f"Skipping corrupt session file: {f}") + return sessions diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 00000000..d363a302 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Tests for the session checkpoint/resume module.""" + +import pytest + +from seclab_taskflow_agent.session import CompletedTask, TaskflowSession, session_dir + + +class TestTaskflowSession: + """Tests for TaskflowSession persistence.""" + + def test_create_session(self): + """A new session gets a unique ID and starts at task 0.""" + s = TaskflowSession(taskflow_path="examples.taskflows.echo") + assert len(s.session_id) == 12 + assert s.next_task_index == 0 + assert s.finished is False + assert s.error == "" + + def test_record_task_advances_index(self): + """Recording a task increments next_task_index.""" + s = TaskflowSession(taskflow_path="test.flow") + s.record_task(index=0, name="task-0", success=True, tool_results=["r1"]) + assert s.next_task_index == 1 + assert s.last_tool_results == ["r1"] + s.record_task(index=1, name="task-1", success=True) + assert s.next_task_index == 2 + + def test_save_and_load(self, tmp_path, monkeypatch): + """Session can round-trip 
through JSON on disk.""" + monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path) + s = TaskflowSession( + taskflow_path="examples.taskflows.echo", + cli_globals={"FOO": "bar"}, + total_tasks=3, + ) + s.record_task(index=0, name="first", success=True) + s.save() + + loaded = TaskflowSession.load(s.session_id) + assert loaded.session_id == s.session_id + assert loaded.taskflow_path == "examples.taskflows.echo" + assert loaded.next_task_index == 1 + assert loaded.cli_globals == {"FOO": "bar"} + + def test_load_missing_raises(self, tmp_path, monkeypatch): + """Loading a non-existent session raises FileNotFoundError.""" + monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path) + with pytest.raises(FileNotFoundError): + TaskflowSession.load("nonexistent") + + def test_mark_finished(self): + """mark_finished sets the finished flag.""" + s = TaskflowSession(taskflow_path="test.flow") + assert s.finished is False + s.mark_finished() + assert s.finished is True + + def test_mark_failed(self): + """mark_failed records the error message.""" + s = TaskflowSession(taskflow_path="test.flow") + s.mark_failed("something broke") + assert s.error == "something broke" + assert s.finished is False + + def test_list_sessions(self, tmp_path, monkeypatch): + """list_sessions returns all saved sessions.""" + monkeypatch.setattr("seclab_taskflow_agent.session.session_dir", lambda: tmp_path) + s1 = TaskflowSession(taskflow_path="flow1") + s2 = TaskflowSession(taskflow_path="flow2") + s1.save() + s2.save() + + sessions = TaskflowSession.list_sessions() + ids = {s.session_id for s in sessions} + assert s1.session_id in ids + assert s2.session_id in ids + + +class TestCompletedTask: + """Tests for CompletedTask model.""" + + def test_defaults(self): + t = CompletedTask(index=0) + assert t.name == "" + assert t.result is False + assert t.tool_results == [] + + def test_with_results(self): + t = CompletedTask(index=2, name="analyze", 
result=True, tool_results=["r1", "r2"]) + assert t.index == 2 + assert t.tool_results == ["r1", "r2"] From 5eaf29cb75bf4f7baed28835719eb85317014ab0 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 16:00:04 -0400 Subject: [PATCH 11/20] docs: add session recovery and error output sections to README --- README.md | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6ada8981..52005882 100644 --- a/README.md +++ b/README.md @@ -24,13 +24,13 @@ You can find a detailed overview of the taskflow grammar [here](doc/GRAMMAR.md) ``` ┌─────────────────────────────────────────────────────┐ │ CLI (cli.py) │ -│ Typer-based entry point: -p, -t, -l, -g KEY=VALUE │ +│ Typer-based entry point: -p, -t, -l, -g, --resume │ └─────────────────────┬───────────────────────────────┘ │ ┌─────────────────────▼───────────────────────────────┐ │ Runner (runner.py) │ │ Taskflow execution loop, model resolution, │ -│ template rendering, repeat-prompt iteration │ +│ template rendering, session checkpointing │ └─────────────────────┬───────────────────────────────┘ │ ┌─────────────────────▼───────────────────────────────┐ @@ -45,6 +45,7 @@ You can find a detailed overview of the taskflow grammar [here](doc/GRAMMAR.md) Supporting modules: models.py — Pydantic v2 grammar models (validation) + session.py — Task-level checkpoint / resume available_tools.py — YAML resource loader with caching template_utils.py — Jinja2 template environment mcp_utils.py — MCP client parameter resolution @@ -80,6 +81,40 @@ Per-model `model_settings` can include: - **`endpoint`** — API base URL override for this model - **`token`** — name of an environment variable containing the API key +### Session Recovery + +Taskflow runs are automatically checkpointed at the task level. 
If a task +fails after exhausting retries, the session is saved and can be resumed: + +``` +** 🤖💾 Session saved: abc123def456 +** 🤖💡 Resume with: --resume abc123def456 +``` + +Resume from the last successful checkpoint: + +```bash +python -m seclab_taskflow_agent --resume abc123def456 +``` + +Failed tasks are automatically retried up to 3 times with increasing backoff +before the session is saved. Session checkpoints are stored in the +platform-specific application data directory. + +### Error Output + +By default, errors are shown as concise one-line messages. Use `--debug` (or +set `TASK_AGENT_DEBUG=1`) for full tracebacks: + +```bash +# Concise (default) +Error: [BadRequestError] model 'foo' not found +(use --debug for full traceback) + +# Full traceback +python -m seclab_taskflow_agent --debug -t examples.taskflows.echo +``` + ## Use Cases and Examples The Seclab Taskflow Agent framework was primarily designed to fit the iterative feedback loop driven work involved in Agentic security research workflows and vulnerability triage tasks. From 2dd674b399982873f808961217bfc42a998a0dd0 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 17:11:12 -0400 Subject: [PATCH 12/20] test: add comprehensive taskflow exercising all grammar features Covers: shell tasks, repeat_prompt, async iteration, model_config, globals, inputs, env, MCP toolboxes, headless, blocked_tools, reusable tasks (uses), reusable prompts (include), and handoffs. --- examples/taskflows/comprehensive_test.yaml | 125 +++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 examples/taskflows/comprehensive_test.yaml diff --git a/examples/taskflows/comprehensive_test.yaml b/examples/taskflows/comprehensive_test.yaml new file mode 100644 index 00000000..fd251de8 --- /dev/null +++ b/examples/taskflows/comprehensive_test.yaml @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: GitHub, Inc. 
+# SPDX-License-Identifier: MIT + +# Comprehensive test taskflow that exercises every grammar feature: +# - model_config reference with model aliases +# - globals (with CLI override via -g) +# - inputs (task-level template variables) +# - env (task-scoped environment variables) +# - must_complete +# - exclude_from_context +# - max_steps +# - MCP toolboxes (echo) +# - shell task (run) +# - repeat_prompt + async iteration +# - reusable tasks (uses) +# - reusable prompts ({% include %}) +# - agent handoffs (multi-agent) +# - headless mode +# - blocked_tools + +seclab-taskflow-agent: + version: "1.0" + filetype: taskflow + +model_config: examples.model_configs.model_config + +globals: + topic: fruit + detail_level: brief + +taskflow: + # --------------------------------------------------------------- + # Task 1: Shell task — produces a JSON array for repeat_prompt + # Features: run, must_complete + # --------------------------------------------------------------- + - task: + name: generate-items + must_complete: true + run: | + echo '[{"name": "apple", "color": "red"}, {"name": "banana", "color": "yellow"}, {"name": "orange", "color": "orange"}]' + + # --------------------------------------------------------------- + # Task 2: Repeat prompt over shell output, async iteration + # Features: repeat_prompt, async, async_limit, exclude_from_context, + # model (alias), inputs, globals, env, max_steps + # --------------------------------------------------------------- + - task: + name: describe-items + repeat_prompt: true + async: true + async_limit: 3 + exclude_from_context: true + must_complete: true + model: gpt_default + max_steps: 10 + agents: + - examples.personalities.fruit_expert + inputs: + format: one-sentence + env: + FRUIT_MODE: "analysis" + user_prompt: | + The topic is {{ globals.topic }} at {{ globals.detail_level }} detail level. + Describe the {{ result.name }} (which is {{ result.color }}) in {{ inputs.format }} format. 
+ + # --------------------------------------------------------------- + # Task 3: MCP tool call with echo server + # Features: toolboxes, headless, blocked_tools + # --------------------------------------------------------------- + - task: + name: echo-test + must_complete: true + headless: true + agents: + - examples.personalities.echo + user_prompt: | + Echo the following message: "All {{ globals.topic }} items processed successfully" + blocked_tools: + - nonexistent_tool_to_test_filtering + + # --------------------------------------------------------------- + # Task 4: Reusable task via `uses` + # Features: uses (inherits from single_step_taskflow) + # --------------------------------------------------------------- + - task: + name: reusable-task + uses: examples.taskflows.single_step_taskflow + model: gpt_default + + # --------------------------------------------------------------- + # Task 5: Reusable prompt via {% include %} + # Features: Jinja2 include directive, reusable prompts + # --------------------------------------------------------------- + - task: + name: include-prompt + agents: + - examples.personalities.fruit_expert + model: gpt_default + max_steps: 5 + user_prompt: | + Tell me about apples. + + {% include 'examples.prompts.example_prompt' %} + + Keep your answer to two sentences per fruit. + + # --------------------------------------------------------------- + # Task 6: Agent handoffs (multi-agent) + # Features: multiple agents (first=primary, rest=handoff targets) + # --------------------------------------------------------------- + - task: + name: handoff-test + model: gpt_default + max_steps: 15 + agents: + - examples.personalities.fruit_expert + - examples.personalities.apple_expert + - examples.personalities.banana_expert + - examples.personalities.orange_expert + user_prompt: | + You are a fruit coordinator. I need specific expert advice on each fruit. 
+ Please hand off to the apple expert for a one-sentence fact about apples, + then to the banana expert for a one-sentence fact about bananas, + then to the orange expert for a one-sentence fact about oranges. + Each expert should provide exactly one interesting fact. From 336f83d656e53ca5fd1a978ce4e79030a64c842e Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 17:37:46 -0400 Subject: [PATCH 13/20] fix: address code scanning findings from PR review - fix uninitialized args and inconsistent return shape in prompt_parser - add explicit returns after exhaustive match statements in capi - narrow BaseException to Exception in mcp_transport thread - add comment for empty except in cleanup shutdown path - iterate over copy in mcp_lifecycle cleanup to avoid mutation during iteration - remove unused _ENGINE_SETTING_KEYS constant from runner - remove unused import in test_session - default TaskDefinition.name/description to empty string for per-index naming - use explicit re-export pattern (as X) in __main__ and cli --- src/seclab_taskflow_agent/__main__.py | 5 +++-- src/seclab_taskflow_agent/capi.py | 24 ++++++++-------------- src/seclab_taskflow_agent/cli.py | 2 +- src/seclab_taskflow_agent/mcp_lifecycle.py | 4 +--- src/seclab_taskflow_agent/mcp_transport.py | 4 ++-- src/seclab_taskflow_agent/models.py | 4 ++-- src/seclab_taskflow_agent/prompt_parser.py | 4 ++-- src/seclab_taskflow_agent/runner.py | 4 ---- tests/test_session.py | 2 +- 9 files changed, 20 insertions(+), 33 deletions(-) diff --git a/src/seclab_taskflow_agent/__main__.py b/src/seclab_taskflow_agent/__main__.py index 1c1147ed..89b6d422 100644 --- a/src/seclab_taskflow_agent/__main__.py +++ b/src/seclab_taskflow_agent/__main__.py @@ -18,8 +18,9 @@ load_dotenv(find_dotenv(usecwd=True)) # Re-export for backwards compatibility — some tests import from __main__ -from .prompt_parser import parse_prompt_args # noqa: E402, F401 -from .runner import deploy_task_agents, run_main # noqa: E402, F401 +from 
.prompt_parser import parse_prompt_args as parse_prompt_args # noqa: E402 +from .runner import deploy_task_agents as deploy_task_agents # noqa: E402 +from .runner import run_main as run_main # noqa: E402 if __name__ == "__main__": from .cli import app diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 0461a748..976c91d7 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -37,8 +37,7 @@ def to_url(self) -> str: return f"https://{self}/inference" case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: return f"https://{self}/v1" - case _: - raise ValueError(f"Unsupported endpoint: {self}") + raise ValueError(f"Unsupported endpoint: {self}") COPILOT_INTEGRATION_ID = "vscode-chat" @@ -114,25 +113,18 @@ def list_capi_models(token: str) -> dict[str, dict]: def supports_tool_calls(model: str, models: dict[str, dict]) -> bool: """Check whether the given model supports tool calls.""" api_endpoint = get_AI_endpoint() - match urlparse(api_endpoint).netloc: + netloc = urlparse(api_endpoint).netloc + match netloc: case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT: return models.get(model, {}).get("capabilities", {}).get("supports", {}).get("tool_calls", False) case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: return "tool-calling" in models.get(model, {}).get("capabilities", []) case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: - # OpenAI doesn't expose capabilities in the models list - # Check if model name indicates function calling support - model_lower = model.lower() - return any( - [ - "gpt-" in model_lower, - ] - ) - case _: - raise ValueError( - f"Unsupported Model Endpoint: {api_endpoint}\n" - f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" - ) + return "gpt-" in model.lower() + raise ValueError( + f"Unsupported Model Endpoint: {api_endpoint}\n" + f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" + ) def list_tool_call_models(token: str) -> dict[str, dict]: diff --git 
a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 3e11cf75..f0bb4584 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -178,4 +178,4 @@ def main( # Legacy compatibility shim — implementation moved to prompt_parser.py # --------------------------------------------------------------------------- -from .prompt_parser import parse_prompt_args # noqa: F401, E402 +from .prompt_parser import parse_prompt_args as parse_prompt_args # noqa: E402 diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index bee2f9c9..7d4edf04 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -145,7 +145,7 @@ async def mcp_session_task( connected.set() await cleanup.wait() - for entry in reversed(entries): + for entry in list(reversed(entries)): try: logging.debug(f"Starting cleanup for mcp server: {entry.server._name}") await entry.server.cleanup() @@ -158,8 +158,6 @@ async def mcp_session_task( logging.warning(f"Streamable mcp server process exception: {e}") except asyncio.CancelledError: logging.exception(f"Timeout on cleanup for mcp server: {entry.server._name}") - finally: - entries.remove(entry) except RuntimeError: logging.exception("RuntimeError in mcp session task") except asyncio.CancelledError: diff --git a/src/seclab_taskflow_agent/mcp_transport.py b/src/seclab_taskflow_agent/mcp_transport.py index f2ce5165..a6c0166e 100644 --- a/src/seclab_taskflow_agent/mcp_transport.py +++ b/src/seclab_taskflow_agent/mcp_transport.py @@ -166,7 +166,7 @@ def run(self) -> None: if self.exit_code not in _EXPECTED_EXIT_CODES: self.exception = subprocess.CalledProcessError(self.exit_code, self.cmd) - except BaseException as e: + except Exception as e: self.exception = e def _read_stream( @@ -248,7 +248,7 @@ async def cleanup(self, *args: Any, **kwargs: Any) -> None: try: asyncio.run_coroutine_threadsafe(super().cleanup(*args, **kwargs), 
self.t.loop).result() except asyncio.CancelledError: - pass + pass # Swallow cancellation during cleanup shutdown finally: self.t.loop.stop() self.t.join() diff --git a/src/seclab_taskflow_agent/models.py b/src/seclab_taskflow_agent/models.py index 39affc73..6445b647 100644 --- a/src/seclab_taskflow_agent/models.py +++ b/src/seclab_taskflow_agent/models.py @@ -81,8 +81,8 @@ class TaskDefinition(BaseModel): model_config = ConfigDict(extra="allow") - name: str = "taskflow" - description: str = "taskflow" + name: str = "" + description: str = "" agents: list[str] = Field(default_factory=list) user_prompt: str = "" run: str = "" diff --git a/src/seclab_taskflow_agent/prompt_parser.py b/src/seclab_taskflow_agent/prompt_parser.py index 332c93c0..397af5da 100644 --- a/src/seclab_taskflow_agent/prompt_parser.py +++ b/src/seclab_taskflow_agent/prompt_parser.py @@ -22,7 +22,7 @@ def parse_prompt_args( available_tools: AvailableTools, user_prompt: str | None = None -) -> tuple[str | None, str | None, bool, dict[str, str], str, str] | tuple[None, None, None, None, str]: +) -> tuple[str | None, str | None, bool, dict[str, str], str, str] | tuple[None, None, None, None, str, str]: """Legacy CLI parser kept for backwards compatibility with tests. Returns: @@ -53,7 +53,7 @@ def parse_prompt_args( except SystemExit as e: if e.code == 2: logging.exception(f"User provided incomplete prompt: {user_prompt}") - return None, None, None, None, help_msg + return None, None, None, None, "", help_msg p = args[0].p.strip() if args[0].p else None t = args[0].t.strip() if args[0].t else None list_models = args[0].l diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 078cfc1d..064f101b 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -119,10 +119,6 @@ def _merge_reusable_task( return TaskDefinition.model_validate(merged) -# Keys in model_settings that are handled by the engine, not ModelSettings. 
-_ENGINE_SETTING_KEYS = {"api_type", "endpoint", "token"} - - def _resolve_task_model( task: TaskDefinition, model_keys: list[str], diff --git a/tests/test_session.py b/tests/test_session.py index d363a302..f8563f95 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -5,7 +5,7 @@ import pytest -from seclab_taskflow_agent.session import CompletedTask, TaskflowSession, session_dir +from seclab_taskflow_agent.session import CompletedTask, TaskflowSession class TestTaskflowSession: From 8efcc25a9c2870cd3fe016e1abdd728cb003092d Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 17:57:47 -0400 Subject: [PATCH 14/20] fix: restore ruff baseline ignores for hatch fmt CI compatibility --- pyproject.toml | 97 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 75 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb75193f..8e1d2dd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,37 +165,90 @@ exclude_lines = [ target-version = "py310" [tool.ruff.lint] -# Project-wide style choices and pragmatic suppressions. -# Rules removed from the original baseline are now enforced. +# Baseline suppressions carried from the original codebase. +# Many of these fire in pre-existing MCP server code or reflect +# deliberate style choices. Rules can be tightened incrementally. 
ignore = [ - # Style choices — these are deliberate project conventions - "EM101", # Exception string literals (pragmatic for this codebase) - "EM102", # Exception f-strings (pragmatic for this codebase) - "G004", # Logging f-strings (clearer than % formatting) - "T201", # print() used intentionally for user output - "TRY003", # Raise with inline message strings (pragmatic) + # Style choices — deliberate project conventions + "EM101", # Exception string literals + "EM102", # Exception f-strings + "G004", # Logging f-strings + "T201", # print() used for user output + "TRY003", # Raise with inline message strings # Backwards-compatibility suppressions for existing code - "A001", # Variable shadows built-in (existing API names) - "A002", # Argument shadows built-in (existing API signatures) - "FBT001", # Boolean positional arg (existing API) - "FBT002", # Boolean default value (existing API) - "N802", # Function name casing (existing API: get_AI_endpoint etc.) - "N806", # Variable casing (existing code conventions) - "SLF001", # Private member access (needed for MCP wrapper internals) + "A001", # Variable shadows built-in + "A002", # Argument shadows built-in + "A004", # Import shadows built-in + "FBT001", # Boolean positional arg + "FBT002", # Boolean default value + "N801", # Class name casing + "N802", # Function name casing + "N806", # Variable casing + "N818", # Exception name suffix + "SLF001", # Private member access (MCP internals) # Framework / ecosystem constraints - "ARG001", # Unused function argument (required by hook/callback signatures) - "B023", # Function uses loop variable (async closures in runner) - "INP001", # Implicit namespace package (project uses src layout) - "PLW2901", # Outer loop variable overwritten (iteration patterns) - "S701", # Jinja2 autoescape=False (YAML context, not HTML) + "ARG001", # Unused function argument (hook/callback signatures) + "B023", # Function uses loop variable (async closures) + "INP001", # Implicit namespace 
package (src layout) + "PLW2901", # Outer loop variable overwritten + "S701", # Jinja2 autoescape=False (YAML, not HTML) + "TID252", # Relative imports from parent modules - # Low-signal rules for this project + # Import organisation (deferred imports used for cycle-breaking) + "I001", # Import block unsorted + "PLC0414", # Useless import alias (intentional re-exports) + "PLC0415", # Import not at top-level (deferred imports) + + # Exception handling patterns + "B904", # raise without from in except + "BLE001", # Blind except (Exception) + "TRY004", # Prefer TypeError + "TRY300", # try-except-return + "TRY301", # Abstract raise to inner function + "TRY400", # logging.exception vs logging.error + + # Typing / annotation style + "FA100", # Missing from __future__ import annotations + "FA102", # Missing from __future__ import annotations (PEP 585/604 syntax in use) + "PYI036", # __exit__ signature + "PYI041", # Use float instead of int | float + "TC001", # Move import into TYPE_CHECKING + "TC002", # Move import into TYPE_CHECKING + "TC003", # Move import into TYPE_CHECKING + "UP006", # Use builtin generics (list instead of List) + "UP007", # Use X | Y union type + "UP035", # Import from collections.abc + "UP045", # Use X | None + + # Low-signal or noisy rules + "B006", # Mutable default argument (pre-existing in MCP servers) + "B007", # Unused loop control variable + "B008", # Function call in argument defaults + "C416", # Unnecessary comprehension + "LOG015", # Root logger usage + "PERF102", # Use keys()/values() instead of items() "PLR2004", # Magic value comparisons + "PLW0602", # Global variable not assigned + "PLW0603", # Using global statement + "PLW1508", # Invalid envvar default + "PT011", # pytest.raises too broad "RET503", # Missing explicit return + "RET504", # Unnecessary assignment before return "RET505", # Unnecessary else after return - "SIM102", # Collapsible if (readability preference) + "RSE102", # Unnecessary parentheses on raised exception + "RUF005", # Unpack instead of concatenation + "RUF022", # 
__all__ not sorted + "RUF023", # __slots__ not sorted + "RUF059", # Unused unpacked variable + "RUF100", # Unused noqa directive + "S108", # Hardcoded temp file + "S607", # Partial path to executable + "SIM102", # Collapsible if + "SIM115", # Use context manager for opening files + "SIM210", # Use bool(...) instead of True if ... else False + "UP015", # Unnecessary open mode ] [tool.ruff.lint.per-file-ignores] From f03443c250b4e50ceb09218f2b306a585d1d31f0 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Wed, 11 Mar 2026 18:28:41 -0400 Subject: [PATCH 15/20] feat: TASKFLOW_ENV_DENYLIST to filter env vars from MCP subprocesses --- README.md | 13 +++++++ src/seclab_taskflow_agent/mcp_transport.py | 18 ++++++++- tests/test_mcp_transport.py | 45 ++++++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 tests/test_mcp_transport.py diff --git a/README.md b/README.md index 52005882..2600fcc1 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,19 @@ Error: [BadRequestError] model 'foo' not found python -m seclab_taskflow_agent --debug -t examples.taskflows.echo ``` +### MCP Environment Denylist + +By default, MCP server subprocesses inherit the parent environment. To prevent +specific variables from leaking to MCP servers, set `TASKFLOW_ENV_DENYLIST` to +a comma-separated list of variable names: + +```bash +export TASKFLOW_ENV_DENYLIST="MY_SECRET_TOKEN,PRIVATE_KEY,OTHER_CREDENTIAL" +``` + +Toolbox-level `env:` declarations in YAML still inject exactly what each server +needs, so explicitly configured variables are unaffected. + ## Use Cases and Examples The Seclab Taskflow Agent framework was primarily designed to fit the iterative feedback loop driven work involved in Agentic security research workflows and vulnerability triage tasks. 
diff --git a/src/seclab_taskflow_agent/mcp_transport.py b/src/seclab_taskflow_agent/mcp_transport.py index a6c0166e..8632fd8d 100644 --- a/src/seclab_taskflow_agent/mcp_transport.py +++ b/src/seclab_taskflow_agent/mcp_transport.py @@ -18,6 +18,7 @@ "AsyncDebugMCPServerStdio", "ReconnectingMCPServerStdio", "StreamableMCPThread", + "_filtered_env", ] import asyncio @@ -38,6 +39,21 @@ _EXPECTED_EXIT_CODES: frozenset[int] = frozenset({0, -signal.SIGTERM}) +def _filtered_env() -> dict[str, str]: + """Return a copy of ``os.environ`` with denied variables removed. + + Set ``TASKFLOW_ENV_DENYLIST`` to a comma-separated list of variable + names that should not be forwarded to MCP server subprocesses. + Toolbox-level ``env:`` declarations in YAML still inject what each + server explicitly needs. + """ + denylist_raw = os.environ.get("TASKFLOW_ENV_DENYLIST", "") + if not denylist_raw: + return os.environ.copy() + denied = {k.strip() for k in denylist_raw.split(",") if k.strip()} + return {k: v for k, v in os.environ.items() if k not in denied} + + class StreamableMCPThread(Thread): """Thread that manages a local streamable MCP server subprocess. @@ -68,7 +84,7 @@ def __init__( self.on_output: Callable[[str], None] | None = on_output self.on_error: Callable[[str], None] | None = on_error self.poll_interval: float = poll_interval - self.env: dict[str, str] = os.environ.copy() # XXX: potential for environment leak to MCP + self.env: dict[str, str] = _filtered_env() if env: self.env.update(env) self._stop_event: Event = Event() diff --git a/tests/test_mcp_transport.py b/tests/test_mcp_transport.py new file mode 100644 index 00000000..e9bfda7b --- /dev/null +++ b/tests/test_mcp_transport.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: GitHub, Inc. 
+# SPDX-License-Identifier: MIT + +"""Tests for MCP transport utilities.""" + +import os + +from seclab_taskflow_agent.mcp_transport import _filtered_env + + +class TestFilteredEnv: + """Tests for _filtered_env environment denylist.""" + + def test_no_denylist_copies_env(self, monkeypatch): + """Without TASKFLOW_ENV_DENYLIST, returns full env copy.""" + monkeypatch.delenv("TASKFLOW_ENV_DENYLIST", raising=False) + monkeypatch.setenv("TEST_VAR_A", "hello") + result = _filtered_env() + assert result["TEST_VAR_A"] == "hello" + assert result is not os.environ + + def test_denylist_strips_variables(self, monkeypatch): + """Comma-separated denylist removes matching variables.""" + monkeypatch.setenv("SECRET_TOKEN", "s3cret") + monkeypatch.setenv("MY_API_KEY", "key123") + monkeypatch.setenv("SAFE_VAR", "keep") + monkeypatch.setenv("TASKFLOW_ENV_DENYLIST", "SECRET_TOKEN,MY_API_KEY") + result = _filtered_env() + assert "SECRET_TOKEN" not in result + assert "MY_API_KEY" not in result + assert result["SAFE_VAR"] == "keep" + + def test_denylist_handles_whitespace(self, monkeypatch): + """Whitespace around denylist entries is trimmed.""" + monkeypatch.setenv("FOO", "bar") + monkeypatch.setenv("TASKFLOW_ENV_DENYLIST", " FOO , ") + result = _filtered_env() + assert "FOO" not in result + + def test_empty_denylist_copies_env(self, monkeypatch): + """Empty TASKFLOW_ENV_DENYLIST behaves like unset.""" + monkeypatch.setenv("TASKFLOW_ENV_DENYLIST", "") + monkeypatch.setenv("KEEP_ME", "yes") + result = _filtered_env() + assert result["KEEP_ME"] == "yes" From 3765891967da89f05cfd4aa78b39af3e6231ca35 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Thu, 12 Mar 2026 13:35:40 -0400 Subject: [PATCH 16/20] fix: address PR review findings - remove global OpenAI client/api mutations (race condition with concurrent agents) - add empty resolved_agents validation before deploy_task_agents - store toolbox name on MCPServerEntry instead of accessing private _name - update last_tool_results 
unconditionally in session.record_task - move raises into match default cases to fix mixed return warnings - catch non-SystemExit exceptions in prompt parser --- src/seclab_taskflow_agent/agent.py | 4 ---- src/seclab_taskflow_agent/capi.py | 12 +++++++----- src/seclab_taskflow_agent/mcp_lifecycle.py | 15 ++++++++------- src/seclab_taskflow_agent/prompt_parser.py | 3 +++ src/seclab_taskflow_agent/runner.py | 6 ++++++ src/seclab_taskflow_agent/session.py | 3 +-- 6 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index d38e7cf2..dd104879 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -19,8 +19,6 @@ TContext, Tool, result, - set_default_openai_api, - set_default_openai_client, set_tracing_disabled, ) from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult @@ -193,8 +191,6 @@ def __init__( api_key=resolved_token, default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID}, ) - set_default_openai_client(client) - set_default_openai_api(api_type) set_tracing_disabled(True) self.run_hooks = run_hooks or TaskRunHooks() diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 976c91d7..a0510041 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -37,7 +37,8 @@ def to_url(self) -> str: return f"https://{self}/inference" case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: return f"https://{self}/v1" - raise ValueError(f"Unsupported endpoint: {self}") + case _: + raise ValueError(f"Unsupported endpoint: {self}") COPILOT_INTEGRATION_ID = "vscode-chat" @@ -121,10 +122,11 @@ def supports_tool_calls(model: str, models: dict[str, dict]) -> bool: return "tool-calling" in models.get(model, {}).get("capabilities", []) case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: return "gpt-" in model.lower() - raise ValueError( - f"Unsupported Model Endpoint: {api_endpoint}\n" - 
f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" - ) + case _: + raise ValueError( + f"Unsupported Model Endpoint: {api_endpoint}\n" + f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" + ) def list_tool_call_models(token: str) -> dict[str, dict]: diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index 7d4edf04..117f52a8 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -33,11 +33,12 @@ class MCPServerEntry: """A paired MCP server wrapper and optional local process.""" - __slots__ = ("server", "process") + __slots__ = ("server", "process", "name") - def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None = None): + def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None = None, name: str = ""): self.server = server self.process = process + self.name = name def build_mcp_servers( @@ -117,7 +118,7 @@ def _print_err(line: str) -> None: case _: raise ValueError(f"Unsupported MCP transport: {params['kind']}") - entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc)) + entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb)) return entries @@ -136,7 +137,7 @@ async def mcp_session_task( """ try: for entry in entries: - logging.debug(f"Connecting mcp server: {entry.server._name}") + logging.debug(f"Connecting mcp server: {entry.name}") if entry.process is not None: entry.process.start() await entry.process.async_wait_for_connection(poll_interval=0.1) @@ -147,9 +148,9 @@ async def mcp_session_task( for entry in list(reversed(entries)): try: - logging.debug(f"Starting cleanup for mcp server: {entry.server._name}") + logging.debug(f"Starting cleanup for mcp server: {entry.name}") await entry.server.cleanup() - logging.debug(f"Cleaned up mcp server: {entry.server._name}") + logging.debug(f"Cleaned up mcp server: 
{entry.name}") if entry.process is not None: entry.process.stop() try: @@ -157,7 +158,7 @@ async def mcp_session_task( except Exception as e: logging.warning(f"Streamable mcp server process exception: {e}") except asyncio.CancelledError: - logging.exception(f"Timeout on cleanup for mcp server: {entry.server._name}") + logging.exception(f"Timeout on cleanup for mcp server: {entry.name}") except RuntimeError: logging.exception("RuntimeError in mcp session task") except asyncio.CancelledError: diff --git a/src/seclab_taskflow_agent/prompt_parser.py b/src/seclab_taskflow_agent/prompt_parser.py index 397af5da..c5ff4153 100644 --- a/src/seclab_taskflow_agent/prompt_parser.py +++ b/src/seclab_taskflow_agent/prompt_parser.py @@ -54,6 +54,9 @@ def parse_prompt_args( if e.code == 2: logging.exception(f"User provided incomplete prompt: {user_prompt}") return None, None, None, None, "", help_msg + except Exception: + logging.exception(f"Failed to parse prompt: {user_prompt}") + return None, None, None, None, "", help_msg p = args[0].p.strip() if args[0].p else None t = args[0].t.strip() if args[0].t else None list_models = args[0].l diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 064f101b..ed0bb1ef 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -603,6 +603,12 @@ async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) - raise ValueError(f"No such personality: {agent_name}") resolved_agents[agent_name] = personality + if not resolved_agents: + raise ValueError( + "No agents resolved for this task. " + "Specify a personality with -p or provide an agents list." 
+ ) + async def _deploy(ra: dict, pp: str) -> bool: async with semaphore: return await deploy_task_agents( diff --git a/src/seclab_taskflow_agent/session.py b/src/seclab_taskflow_agent/session.py index cb8d0f8b..f02d30cc 100644 --- a/src/seclab_taskflow_agent/session.py +++ b/src/seclab_taskflow_agent/session.py @@ -99,8 +99,7 @@ def record_task( tool_results=tool_results or [], ) ) - if tool_results: - self.last_tool_results = list(tool_results) + self.last_tool_results = list(tool_results or []) self.save() def mark_finished(self) -> None: From 468f97b79fcad77ebd676899b6cfefd925b1b18d Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Thu, 12 Mar 2026 13:41:52 -0400 Subject: [PATCH 17/20] fix: session resume, resource loading, and error path consistency - don't advance resume cursor past failed must_complete tasks - peek at last tool result before consuming (safe for retry) - use Traversable.open() for zip/wheel compatibility - return empty string (not None) for prompt on invalid globals --- src/seclab_taskflow_agent/available_tools.py | 2 +- src/seclab_taskflow_agent/prompt_parser.py | 2 +- src/seclab_taskflow_agent/runner.py | 22 ++++++++++++-------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/seclab_taskflow_agent/available_tools.py b/src/seclab_taskflow_agent/available_tools.py index 812599a6..577ae067 100644 --- a/src/seclab_taskflow_agent/available_tools.py +++ b/src/seclab_taskflow_agent/available_tools.py @@ -121,7 +121,7 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: f"Cannot load {toolname} because {pkg_dir} is not a valid directory." 
) filepath = pkg_dir.joinpath(filename + ".yaml") - with open(filepath) as fh: + with filepath.open() as fh: raw = yaml.safe_load(fh) # Validate header before full parse diff --git a/src/seclab_taskflow_agent/prompt_parser.py b/src/seclab_taskflow_agent/prompt_parser.py index c5ff4153..10baccd0 100644 --- a/src/seclab_taskflow_agent/prompt_parser.py +++ b/src/seclab_taskflow_agent/prompt_parser.py @@ -66,7 +66,7 @@ def parse_prompt_args( for g in args[0].globals: if "=" not in g: logging.error(f"Invalid global variable format: {g}. Expected KEY=VALUE") - return None, None, None, None, None, help_msg + return None, None, None, None, "", help_msg key, value = g.split("=", 1) cli_globals[key.strip()] = value.strip() diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index ed0bb1ef..5099008c 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -198,7 +198,7 @@ async def _build_prompts_to_run( if "result" not in task_prompt.lower(): logging.warning("repeat_prompt enabled but no {{ result }} in prompt") try: - last_result = json.loads(last_mcp_tool_results.pop()) + last_result = json.loads(last_mcp_tool_results[-1]) text = last_result.get("text", "") try: iterable_result = json.loads(text) @@ -214,6 +214,9 @@ async def _build_prompts_to_run( logging.critical("No last MCP tool result available") raise + # Consume only after successful parse + last_mcp_tool_results.pop() + if not iterable_result: await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") else: @@ -689,14 +692,6 @@ async def _deploy(ra: dict, pp: str) -> bool: ) raise last_task_error - # Checkpoint after successful task - session.record_task( - index=task_index, - name=task_name, - success=task_complete, - tool_results=list(last_mcp_tool_results), - ) - if must_complete and not task_complete: logging.critical("Required task not completed ... 
aborting!") await render_model_output("🤖💥 *Required task not completed ...\n") @@ -707,6 +702,15 @@ async def _deploy(ra: dict, pp: str) -> bool: ) break + # Checkpoint after task (must_complete failures break above + # without advancing the resume cursor) + session.record_task( + index=task_index, + name=task_name, + success=task_complete, + tool_results=list(last_mcp_tool_results), + ) + # All tasks completed successfully if session is not None and not session.error: session.mark_finished() From 69fc1e3ca76d826f056524a26597e275bd03aa64 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Fri, 20 Mar 2026 13:18:09 -0400 Subject: [PATCH 18/20] fix: address remaining PR feedback, expand test coverage - cli: TASK_AGENT_DEBUG="0"/"false" no longer enables debug mode - capi: allow arbitrary API endpoints with graceful fallback - runner: defer tool result pop until after template rendering - test: 72 new unit tests for runner, cli, session, prompt parser, capi - examples: add edge_case_test.yaml for nested JSON repeat_prompt --- examples/taskflows/edge_case_test.yaml | 69 ++++ src/seclab_taskflow_agent/capi.py | 17 +- src/seclab_taskflow_agent/cli.py | 2 +- src/seclab_taskflow_agent/runner.py | 7 +- tests/test_api_endpoint_config.py | 13 +- tests/test_capi_extended.py | 97 ++++++ tests/test_cli.py | 84 +++++ tests/test_prompt_parser_edge.py | 123 +++++++ tests/test_runner.py | 447 +++++++++++++++++++++++++ tests/test_session_edge.py | 107 ++++++ 10 files changed, 947 insertions(+), 19 deletions(-) create mode 100644 examples/taskflows/edge_case_test.yaml create mode 100644 tests/test_capi_extended.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_prompt_parser_edge.py create mode 100644 tests/test_runner.py create mode 100644 tests/test_session_edge.py diff --git a/examples/taskflows/edge_case_test.yaml b/examples/taskflows/edge_case_test.yaml new file mode 100644 index 00000000..046e9775 --- /dev/null +++ b/examples/taskflows/edge_case_test.yaml @@ -0,0 
+1,69 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +# Edge-case test taskflow targeting less-exercised code paths: +# - shell task producing nested JSON for repeat_prompt +# - repeat_prompt over dictionary items (not just arrays) +# - env variable scoping (task-level env) +# - globals CLI override combined with file defaults +# - max_steps constraint +# - must_complete on a non-tool task +# - empty taskflow section handling + +seclab-taskflow-agent: + version: "1.0" + filetype: taskflow + +model_config: examples.model_configs.model_config + +globals: + category: edge-cases + default_value: from-file + +taskflow: + # --------------------------------------------------------------- + # Task 1: Shell task with nested JSON structure + # Tests: run, must_complete, complex JSON output + # --------------------------------------------------------------- + - task: + name: nested-json-shell + must_complete: true + run: | + echo '[{"id": 1, "data": {"label": "alpha", "score": 0.95}}, {"id": 2, "data": {"label": "beta", "score": 0.87}}]' + + # --------------------------------------------------------------- + # Task 2: Repeat over nested structure, sequential (not async) + # Tests: repeat_prompt (sequential), nested result access, + # globals reference, inputs, env scoping, max_steps + # --------------------------------------------------------------- + - task: + name: sequential-repeat + repeat_prompt: true + must_complete: true + model: gpt_default + max_steps: 5 + agents: + - examples.personalities.fruit_expert + inputs: + output_format: json + env: + EDGE_TEST_MODE: "sequential" + user_prompt: | + Category: {{ globals.category }}, default: {{ globals.default_value }}. + Item ID {{ result.id }}: label={{ result.data.label }}, score={{ result.data.score }}. + Respond with exactly one sentence summarizing this item in {{ inputs.output_format }} awareness. 
+ + # --------------------------------------------------------------- + # Task 3: Simple prompt with no tools (headless, no toolboxes) + # Tests: pure LLM task, exclude_from_context, model alias + # --------------------------------------------------------------- + - task: + name: pure-llm-task + model: gpt_default + exclude_from_context: true + agents: + - examples.personalities.fruit_expert + max_steps: 3 + user_prompt: | + The category is {{ globals.category }}. + Say "edge case test passed" and nothing else. diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index a0510041..504554f3 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -79,10 +79,8 @@ def list_capi_models(token: str) -> dict[str, dict]: case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: models_catalog = "models" case _: - raise ValueError( - f"Unsupported Model Endpoint: {api_endpoint}\n" - f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" - ) + # Unknown endpoint — try the OpenAI-style models catalog + models_catalog = "models" r = httpx.get( httpx.URL(api_endpoint).join(models_catalog), headers={ @@ -100,6 +98,10 @@ def list_capi_models(token: str) -> dict[str, dict]: models_list = r.json() case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: models_list = r.json().get("data", []) + case _: + # Unknown endpoint — try OpenAI-style {"data": [...]} + body = r.json() + models_list = body.get("data", body) if isinstance(body, dict) else body for model in models_list: models[model.get("id")] = dict(model) except httpx.RequestError: @@ -123,10 +125,9 @@ def supports_tool_calls(model: str, models: dict[str, dict]) -> bool: case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: return "gpt-" in model.lower() case _: - raise ValueError( - f"Unsupported Model Endpoint: {api_endpoint}\n" - f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" - ) + # Unknown endpoint — optimistically assume tool-call support + # if the model is present in 
the catalog. + return model in models def list_tool_call_models(token: str) -> dict[str, dict]: diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index f0bb4584..bd9fc139 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -113,7 +113,7 @@ def main( ) -> None: """Run a taskflow or personality-based agent session.""" # Debug mode from flag or env var - debug = debug or bool(os.getenv("TASK_AGENT_DEBUG")) + debug = debug or os.getenv("TASK_AGENT_DEBUG", "").strip().lower() in ("1", "true", "yes") # Validate mutual exclusivity (resume is standalone) if resume and (personality or taskflow or list_models): diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 5099008c..04bc6d75 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -214,9 +214,6 @@ async def _build_prompts_to_run( logging.critical("No last MCP tool result available") raise - # Consume only after successful parse - last_mcp_tool_results.pop() - if not iterable_result: await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") else: @@ -234,6 +231,10 @@ async def _build_prompts_to_run( except jinja2.TemplateError as e: logging.error(f"Error rendering template for result {value}: {e}") raise ValueError(f"Template rendering failed: {e}") + + # Consume only after all prompts rendered successfully so that + # the result remains available for retry/resume on failure. 
+ last_mcp_tool_results.pop() else: prompts_to_run.append(task_prompt) return prompts_to_run diff --git a/tests/test_api_endpoint_config.py b/tests/test_api_endpoint_config.py index 4c3d6b0c..3912496b 100644 --- a/tests/test_api_endpoint_config.py +++ b/tests/test_api_endpoint_config.py @@ -62,15 +62,14 @@ def test_to_url_openai(self): assert endpoint.to_url() == "https://api.openai.com/v1" def test_unsupported_endpoint(self, monkeypatch): - """Test that unsupported API endpoint raises ValueError.""" + """Test that unsupported API endpoint falls back gracefully.""" api_endpoint = "https://unsupported.example.com" monkeypatch.setenv("AI_API_ENDPOINT", api_endpoint) - with pytest.raises(ValueError) as excinfo: - list_capi_models("abc") - msg = str(excinfo.value) - assert "Unsupported Model Endpoint" in msg - assert "https://models.github.ai/inference" in msg - assert "https://api.githubcopilot.com" in msg + # Unknown endpoints should not raise; they try OpenAI-style catalog + # and return an empty dict on connection failure. + result = list_capi_models("abc") + assert isinstance(result, dict) + assert result == {} if __name__ == "__main__": diff --git a/tests/test_capi_extended.py b/tests/test_capi_extended.py new file mode 100644 index 00000000..297f6ae7 --- /dev/null +++ b/tests/test_capi_extended.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: GitHub, Inc. 
+# SPDX-License-Identifier: MIT + +"""Extended tests for capi module.""" + +from __future__ import annotations + +from seclab_taskflow_agent.capi import AI_API_ENDPOINT_ENUM, supports_tool_calls + + +class TestSupportsToolCalls: + """Tests for supports_tool_calls with unknown endpoints.""" + + def test_unknown_endpoint_known_model(self, monkeypatch): + """Unknown endpoint returns True when model is in the catalog.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://custom.api.example.com/v1") + models = {"my-model": {"id": "my-model"}} + assert supports_tool_calls("my-model", models) is True + + def test_unknown_endpoint_unknown_model(self, monkeypatch): + """Unknown endpoint returns False when model is NOT in the catalog.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://custom.api.example.com/v1") + models = {"other-model": {"id": "other-model"}} + assert supports_tool_calls("missing-model", models) is False + + def test_copilot_endpoint_with_capabilities(self, monkeypatch): + """Copilot endpoint checks capabilities.supports.tool_calls.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://api.githubcopilot.com") + models = { + "gpt-4o": { + "id": "gpt-4o", + "capabilities": {"supports": {"tool_calls": True}}, + } + } + assert supports_tool_calls("gpt-4o", models) is True + + def test_copilot_endpoint_without_capabilities(self, monkeypatch): + """Copilot endpoint returns False when tool_calls not in capabilities.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://api.githubcopilot.com") + models = { + "text-only": { + "id": "text-only", + "capabilities": {"supports": {}}, + } + } + assert supports_tool_calls("text-only", models) is False + + def test_models_github_endpoint(self, monkeypatch): + """models.github.ai checks for 'tool-calling' in capabilities list.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://models.github.ai/inference") + models = { + "openai/gpt-4o": { + "id": "openai/gpt-4o", + "capabilities": ["tool-calling", "chat"], + } + } + assert 
supports_tool_calls("openai/gpt-4o", models) is True + + def test_models_github_endpoint_no_tool_calling(self, monkeypatch): + """models.github.ai returns False when 'tool-calling' not in list.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://models.github.ai/inference") + models = { + "some-model": { + "id": "some-model", + "capabilities": ["chat"], + } + } + assert supports_tool_calls("some-model", models) is False + + def test_openai_endpoint_gpt_model(self, monkeypatch): + """OpenAI endpoint returns True for models containing 'gpt-'.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1") + assert supports_tool_calls("gpt-4o", {}) is True + + def test_openai_endpoint_non_gpt_model(self, monkeypatch): + """OpenAI endpoint returns False for non-GPT models.""" + monkeypatch.setenv("AI_API_ENDPOINT", "https://api.openai.com/v1") + assert supports_tool_calls("claude-3-opus", {}) is False + + +class TestAIAPIEndpointEnum: + """Tests for the AI_API_ENDPOINT_ENUM StrEnum.""" + + def test_enum_values(self): + """All expected endpoint values exist.""" + assert AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB == "models.github.ai" + assert AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT == "api.githubcopilot.com" + assert AI_API_ENDPOINT_ENUM.AI_API_OPENAI == "api.openai.com" + + def test_to_url_models_github(self): + assert AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB.to_url() == "https://models.github.ai/inference" + + def test_to_url_copilot(self): + assert AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT.to_url() == "https://api.githubcopilot.com" + + def test_to_url_openai(self): + assert AI_API_ENDPOINT_ENUM.AI_API_OPENAI.to_url() == "https://api.openai.com/v1" diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..068bf3a8 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: GitHub, Inc. 
+# SPDX-License-Identifier: MIT + +"""Unit tests for the Typer CLI module.""" + +from __future__ import annotations + +import pytest +import typer + +from seclab_taskflow_agent.cli import _parse_global + + +class TestParseGlobal: + """Tests for _parse_global KEY=VALUE parsing.""" + + def test_valid_key_value(self): + """Standard KEY=VALUE is parsed correctly.""" + assert _parse_global("fruit=apple") == ("fruit", "apple") + + def test_missing_equals_raises(self): + """A string without '=' raises BadParameter.""" + with pytest.raises(typer.BadParameter, match="Expected KEY=VALUE"): + _parse_global("no_equals_here") + + def test_value_with_equals_sign(self): + """Only the first '=' is used as the delimiter.""" + key, val = _parse_global("url=https://example.com?foo=bar") + assert key == "url" + assert val == "https://example.com?foo=bar" + + def test_whitespace_stripped(self): + """Leading/trailing whitespace in key and value is stripped.""" + key, val = _parse_global(" key = value ") + assert key == "key" + assert val == "value" + + def test_empty_value(self): + """An empty value after '=' is allowed.""" + key, val = _parse_global("key=") + assert key == "key" + assert val == "" + + def test_empty_key(self): + """An empty key before '=' is technically allowed by the parser.""" + key, val = _parse_global("=value") + assert key == "" + assert val == "value" + + +class TestDebugEnvParsing: + """Tests for the TASK_AGENT_DEBUG environment variable expression.""" + + @staticmethod + def _is_debug(env_value: str) -> bool: + """Reproduce the debug expression from cli.py.""" + return env_value.strip().lower() in ("1", "true", "yes") + + def test_zero_is_false(self): + assert self._is_debug("0") is False + + def test_one_is_true(self): + assert self._is_debug("1") is True + + def test_true_string_is_true(self): + assert self._is_debug("true") is True + + def test_TRUE_string_is_true(self): + assert self._is_debug("TRUE") is True + + def test_yes_string_is_true(self): + assert 
self._is_debug("yes") is True + + def test_empty_string_is_false(self): + assert self._is_debug("") is False + + def test_false_string_is_false(self): + assert self._is_debug("false") is False + + def test_whitespace_trimmed(self): + assert self._is_debug(" 1 ") is True + + def test_random_text_is_false(self): + assert self._is_debug("enabled") is False diff --git a/tests/test_prompt_parser_edge.py b/tests/test_prompt_parser_edge.py new file mode 100644 index 00000000..8dfb366d --- /dev/null +++ b/tests/test_prompt_parser_edge.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Edge-case tests for the legacy prompt parser.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from seclab_taskflow_agent.available_tools import AvailableTools +from seclab_taskflow_agent.prompt_parser import parse_prompt_args + + +def _tools() -> AvailableTools: + return MagicMock(spec=AvailableTools) + + +class TestParsePromptArgs: + """Tests for parse_prompt_args edge cases.""" + + def test_none_prompt_returns_defaults(self): + """None prompt causes argparse to read sys.argv; result still has 6 elements.""" + result = parse_prompt_args(_tools(), None) + assert len(result) == 6 + # When None is passed, argparse parses sys.argv[1:] (pytest args), + # so personality and taskflow remain None + p, t, list_flag, g, prompt, help_msg = result + assert p is None + assert t is None + + def test_empty_string_returns_defaults(self): + """Empty string prompt returns default values.""" + result = parse_prompt_args(_tools(), "") + assert len(result) == 6 + p, t, list_flag, g, prompt, help_msg = result + assert p is None + assert t is None + + def test_personality_flag(self): + """-p flag sets the personality.""" + p, t, list_flag, g, prompt, _ = parse_prompt_args( + _tools(), "-p my.personality hello world" + ) + assert p == "my.personality" + assert t is None + assert prompt == "hello world" + + def 
test_taskflow_flag(self): + """-t flag sets the taskflow.""" + p, t, list_flag, g, prompt, _ = parse_prompt_args( + _tools(), "-t my.taskflow do stuff" + ) + assert t == "my.taskflow" + assert p is None + assert prompt == "do stuff" + + def test_list_models_flag(self): + """-l flag sets list_models to True.""" + p, t, list_flag, g, prompt, _ = parse_prompt_args(_tools(), "-l") + assert list_flag is True + assert p is None + assert t is None + + def test_invalid_global_format_returns_none_tuple(self): + """-g with no = returns the None/error tuple.""" + result = parse_prompt_args(_tools(), "-g badformat") + p, t, list_flag, g, prompt, help_msg = result + assert p is None + assert t is None + assert list_flag is None + assert g is None + + def test_mutual_exclusivity_p_and_t(self): + """-p and -t together triggers SystemExit → None tuple.""" + result = parse_prompt_args(_tools(), "-p foo -t bar") + p, t, list_flag, g, prompt, help_msg = result + # argparse raises SystemExit(2) which is caught → None tuple + assert p is None + assert t is None + assert list_flag is None + assert g is None + + def test_prompt_remainder_collected(self): + """Remaining text after flags is collected as prompt.""" + _, _, _, _, prompt, _ = parse_prompt_args( + _tools(), "-p my.personality tell me a joke" + ) + assert prompt == "tell me a joke" + + def test_return_tuple_always_has_six_elements(self): + """Return value always has exactly 6 elements.""" + # Success case + result = parse_prompt_args(_tools(), "-p my.personality hello") + assert len(result) == 6 + + # Error case + result = parse_prompt_args(_tools(), "-p foo -t bar") + assert len(result) == 6 + + # None case + result = parse_prompt_args(_tools(), None) + assert len(result) == 6 + + def test_global_variable_valid(self): + """-g KEY=VALUE correctly parses global variables.""" + _, _, _, g, _, _ = parse_prompt_args( + _tools(), "-p my.personality -g fruit=apple hello" + ) + assert g == {"fruit": "apple"} + + def 
test_multiple_globals(self): + """Multiple -g flags are all captured.""" + _, _, _, g, _, _ = parse_prompt_args( + _tools(), "-p my.personality -g fruit=apple -g color=red" + ) + assert g == {"fruit": "apple", "color": "red"} + + def test_global_value_with_equals(self): + """-g with value containing = uses only first = as delimiter.""" + _, _, _, g, _, _ = parse_prompt_args( + _tools(), "-p my.personality -g url=https://x.com?a=1" + ) + assert g == {"url": "https://x.com?a=1"} diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 00000000..401737e7 --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Unit tests for runner helper functions (no API calls).""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from seclab_taskflow_agent.models import ( + ModelConfigDocument, + TaskDefinition, + TaskflowDocument, + TaskflowHeader, + TaskWrapper, +) +from seclab_taskflow_agent.runner import ( + _build_prompts_to_run, + _merge_reusable_task, + _resolve_model_config, + _resolve_task_model, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_header() -> TaskflowHeader: + return TaskflowHeader(version="1.0", filetype="taskflow") + + +def _make_model_config_header() -> TaskflowHeader: + return TaskflowHeader(version="1.0", filetype="model_config") + + +def _make_model_config( + models: dict[str, str] | None = None, + model_settings: dict[str, dict[str, Any]] | None = None, + api_type: str = "chat_completions", +) -> ModelConfigDocument: + return ModelConfigDocument( + **{ + "seclab-taskflow-agent": _make_model_config_header(), + "api_type": api_type, + "models": models or {}, + "model_settings": model_settings or 
{}, + } + ) + + +def _make_taskflow_doc(tasks: list[TaskDefinition]) -> TaskflowDocument: + return TaskflowDocument( + **{ + "seclab-taskflow-agent": _make_header(), + "taskflow": [TaskWrapper(task=t) for t in tasks], + } + ) + + +def _mock_available_tools() -> MagicMock: + return MagicMock() + + +# =================================================================== +# _resolve_model_config +# =================================================================== + +class TestResolveModelConfig: + """Tests for _resolve_model_config.""" + + def test_basic_model_resolution(self): + """Model keys and dict are extracted from config.""" + at = _mock_available_tools() + at.get_model_config.return_value = _make_model_config( + models={"fast": "gpt-4o-mini", "smart": "gpt-4o"}, + ) + keys, mdict, params, api_type = _resolve_model_config(at, "ref") + assert set(keys) == {"fast", "smart"} + assert mdict == {"fast": "gpt-4o-mini", "smart": "gpt-4o"} + assert params == {} + assert api_type == "chat_completions" + + def test_api_type_flows_through(self): + """api_type from the config document is returned.""" + at = _mock_available_tools() + at.get_model_config.return_value = _make_model_config( + models={"m1": "provider-model"}, + api_type="responses", + ) + _, _, _, api_type = _resolve_model_config(at, "ref") + assert api_type == "responses" + + def test_model_settings_extraction(self): + """Per-model settings are returned and keyed by logical name.""" + at = _mock_available_tools() + at.get_model_config.return_value = _make_model_config( + models={"m1": "provider-m1"}, + model_settings={"m1": {"temperature": 0.5}}, + ) + _, _, params, _ = _resolve_model_config(at, "ref") + assert params == {"m1": {"temperature": 0.5}} + + def test_validation_error_on_non_dict_settings(self): + """Non-dict model_settings raises ValueError.""" + at = _mock_available_tools() + cfg = _make_model_config(models={"m1": "p-m1"}) + # Manually set model_settings to a non-dict after construction + 
object.__setattr__(cfg, "model_settings", "not-a-dict") + at.get_model_config.return_value = cfg + with pytest.raises(ValueError, match="must be a dictionary"): + _resolve_model_config(at, "ref") + + +# =================================================================== +# _merge_reusable_task +# =================================================================== + +class TestMergeReusableTask: + """Tests for _merge_reusable_task.""" + + def test_current_fields_override_parent(self): + """Fields explicitly set on the current task override the parent.""" + parent = TaskDefinition(name="parent", user_prompt="parent prompt", model="slow") + doc = _make_taskflow_doc([parent]) + + at = _mock_available_tools() + at.get_taskflow.return_value = doc + + current = TaskDefinition(uses="pkg.reusable", name="child", model="fast") + merged = _merge_reusable_task(at, current) + assert merged.name == "child" + assert merged.model == "fast" + # Parent's prompt should fill in where child uses the default + assert merged.user_prompt == "parent prompt" + + def test_parent_defaults_fill_in(self): + """Parent defaults are used when the current task does not set a field.""" + parent = TaskDefinition( + name="parent", + user_prompt="do something", + headless=True, + must_complete=True, + ) + doc = _make_taskflow_doc([parent]) + + at = _mock_available_tools() + at.get_taskflow.return_value = doc + + current = TaskDefinition(uses="pkg.reusable", name="override-name") + merged = _merge_reusable_task(at, current) + assert merged.name == "override-name" + assert merged.headless is True + assert merged.must_complete is True + + def test_raises_if_reusable_has_multiple_tasks(self): + """ValueError raised when reusable taskflow has more than 1 task.""" + t1 = TaskDefinition(name="t1") + t2 = TaskDefinition(name="t2") + doc = _make_taskflow_doc([t1, t2]) + + at = _mock_available_tools() + at.get_taskflow.return_value = doc + + current = TaskDefinition(uses="pkg.multi") + with 
pytest.raises(ValueError, match="only contain 1 task"): + _merge_reusable_task(at, current) + + def test_raises_if_reusable_not_found(self): + """ValueError raised when the reusable taskflow does not exist.""" + at = _mock_available_tools() + at.get_taskflow.return_value = None + + current = TaskDefinition(uses="pkg.missing") + with pytest.raises(ValueError, match="No such reusable taskflow"): + _merge_reusable_task(at, current) + + +# =================================================================== +# _resolve_task_model +# =================================================================== + +class TestResolveTaskModel: + """Tests for _resolve_task_model (pure function).""" + + def test_logical_name_mapped_to_provider_id(self): + """A logical model name is resolved to the provider model ID.""" + model_id, _, _, _, _ = _resolve_task_model( + TaskDefinition(model="fast"), + model_keys=["fast"], + model_dict={"fast": "gpt-4o-mini"}, + models_params={}, + ) + assert model_id == "gpt-4o-mini" + + def test_model_settings_from_config(self): + """Settings from models_params are included in the result.""" + _, settings, _, _, _ = _resolve_task_model( + TaskDefinition(model="fast"), + model_keys=["fast"], + model_dict={"fast": "gpt-4o-mini"}, + models_params={"fast": {"temperature": 0.7, "max_tokens": 100}}, + ) + assert settings["temperature"] == 0.7 + assert settings["max_tokens"] == 100 + + def test_task_level_settings_override_config(self): + """Task-level model_settings override config-level settings.""" + _, settings, _, _, _ = _resolve_task_model( + TaskDefinition(model="fast", model_settings={"temperature": 0.2}), + model_keys=["fast"], + model_dict={"fast": "gpt-4o-mini"}, + models_params={"fast": {"temperature": 0.7, "max_tokens": 100}}, + ) + assert settings["temperature"] == 0.2 + assert settings["max_tokens"] == 100 + + def test_engine_keys_extracted(self): + """Engine keys (api_type, endpoint, token) are popped from settings.""" + _, settings, api_type, 
endpoint, token = _resolve_task_model( + TaskDefinition(model="fast"), + model_keys=["fast"], + model_dict={"fast": "gpt-4o-mini"}, + models_params={ + "fast": { + "api_type": "responses", + "endpoint": "https://custom.api", + "token": "secret", + "temperature": 0.5, + } + }, + ) + assert api_type == "responses" + assert endpoint == "https://custom.api" + assert token == "secret" + assert "api_type" not in settings + assert "endpoint" not in settings + assert "token" not in settings + assert settings["temperature"] == 0.5 + + def test_default_model_when_empty(self): + """Empty model string falls back to DEFAULT_MODEL.""" + from seclab_taskflow_agent.agent import DEFAULT_MODEL + + model_id, _, _, _, _ = _resolve_task_model( + TaskDefinition(model=""), + model_keys=[], + model_dict={}, + models_params={}, + ) + assert model_id == DEFAULT_MODEL + + def test_model_not_in_keys_passes_through(self): + """A model name not in model_keys passes through as-is.""" + model_id, _, _, _, _ = _resolve_task_model( + TaskDefinition(model="claude-3-opus"), + model_keys=["fast", "smart"], + model_dict={"fast": "gpt-4o-mini", "smart": "gpt-4o"}, + models_params={}, + ) + assert model_id == "claude-3-opus" + + def test_task_engine_keys_override_config(self): + """Task-level model_settings can override engine keys from config.""" + _, _, api_type, endpoint, token = _resolve_task_model( + TaskDefinition( + model="fast", + model_settings={"api_type": "responses", "endpoint": "https://task.api"}, + ), + model_keys=["fast"], + model_dict={"fast": "gpt-4o-mini"}, + models_params={"fast": {"api_type": "chat_completions"}}, + ) + assert api_type == "responses" + assert endpoint == "https://task.api" + + +# =================================================================== +# _build_prompts_to_run +# =================================================================== + +class TestBuildPromptsToRun: + """Tests for _build_prompts_to_run (async, run via asyncio.run).""" + + @staticmethod + def 
_result_entry(data: Any) -> str: + """Build a JSON string mimicking an MCP tool result.""" + return json.dumps({"text": json.dumps(data)}) + + @staticmethod + def _run(coro): + """Run an async coroutine with render_model_output mocked out.""" + with patch("seclab_taskflow_agent.runner.render_model_output", new_callable=AsyncMock): + return asyncio.run(coro) + + def test_non_repeat_returns_single_prompt(self): + """Without repeat_prompt, the original prompt is returned as-is.""" + result = self._run( + _build_prompts_to_run( + task_prompt="hello world", + repeat_prompt=False, + last_mcp_tool_results=[], + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + assert result == ["hello world"] + + def test_repeat_with_json_array(self): + """repeat_prompt with a JSON array generates one prompt per element.""" + items = [{"name": "apple"}, {"name": "banana"}] + results = [self._result_entry(items)] + prompts = self._run( + _build_prompts_to_run( + task_prompt="Process {{ result.name }}", + repeat_prompt=True, + last_mcp_tool_results=results, + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + assert len(prompts) == 2 + assert "apple" in prompts[0] + assert "banana" in prompts[1] + + def test_repeat_with_dict_items(self): + """repeat_prompt iterates over dict keys when result is a dict.""" + data = {"a": 1, "b": 2} + results = [self._result_entry(data)] + prompts = self._run( + _build_prompts_to_run( + task_prompt="Key: {{ result }}", + repeat_prompt=True, + last_mcp_tool_results=results, + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + assert len(prompts) == 2 + + def test_repeat_with_empty_iterable(self): + """repeat_prompt with an empty list renders no prompts.""" + results = [self._result_entry([])] + prompts = self._run( + _build_prompts_to_run( + task_prompt="Process {{ result }}", + repeat_prompt=True, + last_mcp_tool_results=results, + 
available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + assert prompts == [] + + def test_raises_index_error_when_no_last_result(self): + """IndexError when last_mcp_tool_results is empty.""" + with pytest.raises(IndexError): + self._run( + _build_prompts_to_run( + task_prompt="Process {{ result }}", + repeat_prompt=True, + last_mcp_tool_results=[], + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + + def test_raises_value_error_on_non_json_result(self): + """ValueError when MCP result text is not valid JSON.""" + results = [json.dumps({"text": "not json!!"})] + with pytest.raises(ValueError, match="not valid JSON"): + self._run( + _build_prompts_to_run( + task_prompt="Process {{ result }}", + repeat_prompt=True, + last_mcp_tool_results=results, + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + + def test_pop_happens_after_successful_render(self): + """The last result is only consumed after all prompts render.""" + items = [{"name": "x"}] + results = [self._result_entry(items)] + original_len = len(results) + + self._run( + _build_prompts_to_run( + task_prompt="Process {{ result.name }}", + repeat_prompt=True, + last_mcp_tool_results=results, + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) + # After success, the entry should be consumed + assert len(results) == original_len - 1 + + def test_pop_does_not_happen_on_render_failure(self): + """On template error the result is NOT consumed (available for retry).""" + items = [{"name": "x"}] + results = [self._result_entry(items)] + + with patch( + "seclab_taskflow_agent.runner.render_template", + side_effect=Exception("template boom"), + ), pytest.raises(Exception, match="template boom"): + self._run( + _build_prompts_to_run( + task_prompt="Process {{ result.name }}", + repeat_prompt=True, + last_mcp_tool_results=results, + available_tools=_mock_available_tools(), + 
global_variables={}, + inputs={}, + ) + ) + # Result should still be there for retry + assert len(results) == 1 + + def test_raises_type_error_on_non_iterable_result(self): + """TypeError when MCP result parses to a non-iterable (e.g. int).""" + results = [self._result_entry(42)] + with pytest.raises(TypeError): + self._run( + _build_prompts_to_run( + task_prompt="Process {{ result }}", + repeat_prompt=True, + last_mcp_tool_results=results, + available_tools=_mock_available_tools(), + global_variables={}, + inputs={}, + ) + ) diff --git a/tests/test_session_edge.py b/tests/test_session_edge.py new file mode 100644 index 00000000..02cc5714 --- /dev/null +++ b/tests/test_session_edge.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Edge-case tests for session checkpoint/resume module.""" + +from __future__ import annotations + +import pytest + +from seclab_taskflow_agent.session import TaskflowSession, session_dir + + +class TestSessionEdgeCases: + """Edge-case tests for TaskflowSession.""" + + def test_record_task_empty_tool_results(self, tmp_path, monkeypatch): + """record_task with empty list sets last_tool_results to [].""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.record_task(index=0, name="t0", success=True, tool_results=[]) + assert s.last_tool_results == [] + + def test_record_task_none_tool_results(self, tmp_path, monkeypatch): + """record_task with None tool_results defaults to [].""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.record_task(index=0, name="t0", success=True, tool_results=None) + assert s.last_tool_results == [] + + def test_next_task_index_non_sequential(self, tmp_path, monkeypatch): + """next_task_index uses max(indices) + 1, even if non-sequential.""" + 
monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.record_task(index=0, name="t0", success=True) + s.record_task(index=5, name="t5", success=True) + assert s.next_task_index == 6 + + def test_save_load_roundtrip_preserves_tool_results(self, tmp_path, monkeypatch): + """save + load roundtrip preserves last_tool_results.""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.record_task(index=0, name="t0", success=True, tool_results=["res1", "res2"]) + sid = s.session_id + + loaded = TaskflowSession.load(sid) + assert loaded.last_tool_results == ["res1", "res2"] + assert loaded.taskflow_path == "test.flow" + + def test_list_sessions_skips_corrupt_files(self, tmp_path, monkeypatch): + """list_sessions gracefully skips files with invalid JSON.""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + + # Create a valid session + s = TaskflowSession(taskflow_path="valid.flow") + s.save() + + # Write a corrupt file + sdir = session_dir() + corrupt_path = sdir / "corrupt.json" + corrupt_path.write_text("{invalid json!!") + + sessions = TaskflowSession.list_sessions() + # Only the valid session should be returned + assert len(sessions) == 1 + assert sessions[0].taskflow_path == "valid.flow" + + def test_multiple_record_task_accumulate(self, tmp_path, monkeypatch): + """Multiple record_task calls accumulate completed_tasks.""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.record_task(index=0, name="t0", success=True, tool_results=["r0"]) + s.record_task(index=1, name="t1", success=False, tool_results=["r1"]) + s.record_task(index=2, name="t2", success=True, tool_results=["r2"]) + + assert len(s.completed_tasks) == 3 + assert s.completed_tasks[0].name == "t0" + assert 
s.completed_tasks[0].result is True + assert s.completed_tasks[1].result is False + assert s.completed_tasks[2].name == "t2" + # last_tool_results reflects the last call + assert s.last_tool_results == ["r2"] + + def test_mark_failed_then_save_preserves_error(self, tmp_path, monkeypatch): + """mark_failed persists the error through save/load.""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.mark_failed("something went wrong") + assert s.error == "something went wrong" + + loaded = TaskflowSession.load(s.session_id) + assert loaded.error == "something went wrong" + assert loaded.finished is False + + def test_mark_finished_then_load(self, tmp_path, monkeypatch): + """mark_finished flag persists through save/load.""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + s = TaskflowSession(taskflow_path="test.flow") + s.mark_finished() + + loaded = TaskflowSession.load(s.session_id) + assert loaded.finished is True + + def test_load_nonexistent_raises(self, tmp_path, monkeypatch): + """Loading a non-existent session raises FileNotFoundError.""" + monkeypatch.setattr("seclab_taskflow_agent.session._data_dir", lambda: tmp_path) + with pytest.raises(FileNotFoundError, match="No session checkpoint found"): + TaskflowSession.load("nonexistent-id") From a147e9e8161c6295013b36c3ef9ef96d1f218aa6 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Fri, 20 Mar 2026 14:53:19 -0400 Subject: [PATCH 19/20] Address second round of review feedback - capi: unknown endpoint fallback now type-checks response before iterating, avoids AttributeError on dict responses without data key - render_utils: flush_async_output is a no-op when no buffer exists, prevents masking the real error when an async agent fails early - session: list_sessions sorts by file mtime instead of filename so most-recent-first ordering is actually correct - runner: replace confusing issubset(set()) 
idiom with straightforward set difference check for unknown model settings keys - runner: catch JSONDecodeError on outer tool result parse separately from the inner text parse so malformed results get a clear error - tests: suppress S105 false positive on test token assertion --- src/seclab_taskflow_agent/capi.py | 9 +++++-- src/seclab_taskflow_agent/render_utils.py | 6 ++--- src/seclab_taskflow_agent/runner.py | 31 +++++++++++++---------- src/seclab_taskflow_agent/session.py | 2 +- tests/test_runner.py | 2 +- 5 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 504554f3..4ed8dcd9 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -99,9 +99,14 @@ def list_capi_models(token: str) -> dict[str, dict]: case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: models_list = r.json().get("data", []) case _: - # Unknown endpoint — try OpenAI-style {"data": [...]} + # Unknown endpoint — try common response shapes body = r.json() - models_list = body.get("data", body) if isinstance(body, dict) else body + if isinstance(body, dict): + models_list = body.get("data", []) + elif isinstance(body, list): + models_list = body + else: + models_list = [] for model in models_list: models[model.get("id")] = dict(model) except httpx.RequestError: diff --git a/src/seclab_taskflow_agent/render_utils.py b/src/seclab_taskflow_agent/render_utils.py index 7e91144f..7a018506 100644 --- a/src/seclab_taskflow_agent/render_utils.py +++ b/src/seclab_taskflow_agent/render_utils.py @@ -24,9 +24,9 @@ async def flush_async_output(task_id: str) -> None: """Flush buffered async output for *task_id* to the console.""" async with async_output_lock: if task_id not in async_output: - raise ValueError(f"No async output for task: {task_id}") - data = async_output[task_id] - del async_output[task_id] + # No buffered output (agent may have failed before producing any). 
+ return + data = async_output.pop(task_id) await render_model_output(f"** 🤖✏️ Output for async task: {task_id}\n\n") await render_model_output(data) diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 04bc6d75..73d37a34 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -80,9 +80,10 @@ def _resolve_model_config( models_params: dict[str, dict[str, Any]] = m_config.model_settings or {} if models_params and not isinstance(models_params, dict): raise ValueError(f"Settings section of model_config file {model_config_ref} must be a dictionary") - if not set(models_params.keys()).difference(model_keys).issubset(set()): + unknown = set(models_params) - set(model_keys) + if unknown: raise ValueError( - f"Settings section of model_config file {model_config_ref} contains models not in the model section" + f"Settings section of model_config file {model_config_ref} contains models not in the model section: {unknown}" ) for k, v in models_params.items(): if not isinstance(v, dict): @@ -199,20 +200,24 @@ async def _build_prompts_to_run( logging.warning("repeat_prompt enabled but no {{ result }} in prompt") try: last_result = json.loads(last_mcp_tool_results[-1]) - text = last_result.get("text", "") - try: - iterable_result = json.loads(text) - except json.JSONDecodeError as exc: - logging.critical(f"Could not parse result text: {text}") - raise ValueError("Result text is not valid JSON") from exc - try: - iter(iterable_result) - except TypeError: - logging.critical("Last MCP tool result is not iterable") - raise except IndexError: logging.critical("No last MCP tool result available") raise + except json.JSONDecodeError as exc: + logging.critical(f"Could not parse tool result as JSON: {last_mcp_tool_results[-1][:200]}") + raise ValueError("Tool result is not valid JSON") from exc + + text = last_result.get("text", "") + try: + iterable_result = json.loads(text) + except json.JSONDecodeError as exc: 
+ logging.critical(f"Could not parse result text: {text}") + raise ValueError("Result text is not valid JSON") from exc + try: + iter(iterable_result) + except TypeError: + logging.critical("Last MCP tool result is not iterable") + raise if not iterable_result: await render_model_output("** 🤖❗MCP tool result iterable is empty!\n") diff --git a/src/seclab_taskflow_agent/session.py b/src/seclab_taskflow_agent/session.py index f02d30cc..9b771511 100644 --- a/src/seclab_taskflow_agent/session.py +++ b/src/seclab_taskflow_agent/session.py @@ -128,7 +128,7 @@ def load(cls, session_id: str) -> TaskflowSession: def list_sessions(cls) -> list[TaskflowSession]: """List all saved sessions, most recent first.""" sessions: list[TaskflowSession] = [] - for f in sorted(session_dir().glob("*.json"), reverse=True): + for f in sorted(session_dir().glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True): try: sessions.append(cls.model_validate_json(f.read_text())) except Exception: diff --git a/tests/test_runner.py b/tests/test_runner.py index 401737e7..ca509281 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -237,7 +237,7 @@ def test_engine_keys_extracted(self): ) assert api_type == "responses" assert endpoint == "https://custom.api" - assert token == "secret" + assert token == "secret" # noqa: S105 assert "api_type" not in settings assert "endpoint" not in settings assert "token" not in settings From 8b4899191723ab0a63763cd24f535f55cab5f938 Mon Sep 17 00:00:00 2001 From: Bas Alberts Date: Sat, 28 Mar 2026 14:30:36 -0400 Subject: [PATCH 20/20] fix: address human review feedback on PR #166 - Remove redundant isinstance checks in _resolve_model_config; Pydantic ModelConfigDocument validates dict[str, dict[str, Any]] at parse time, making runtime type checks dead code. Update test to verify Pydantic validation instead of testing the removed checks. - Remove shadowing local imports of os and typer in cli.py that duplicate top-level imports. 
- Don't inject default temperature (0.0) when api_type is 'responses'; the responses API rejects unsupported parameters. MODEL_TEMP env var and YAML model_settings still override for both API types. Preserves backward-compatible 0.0 default for chat_completions. - Narrow task-level retry to transient network exceptions only (APIConnectionError, APITimeoutError, ConnectionError, TimeoutError). Non-retriable errors (ValueError, RuntimeError, etc.) now break immediately instead of burning through retry attempts. Prevents blind retries of deterministic failures and side-effectful repeat_prompt re-execution on non-transient errors. --- src/seclab_taskflow_agent/cli.py | 3 --- src/seclab_taskflow_agent/runner.py | 30 ++++++++++++++++++----------- tests/test_runner.py | 12 ++++-------- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index bd9fc139..75694319 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -43,7 +43,6 @@ def _parse_global(value: str) -> tuple[str, str]: def _setup_logging() -> None: """Configure root logger: file (DEBUG) + console (ERROR).""" - import os from logging.handlers import RotatingFileHandler root = logging.getLogger("") @@ -68,8 +67,6 @@ def _print_concise_error(exc: BaseException) -> None: Walks the exception cause chain and prints each error on a single line. Use ``--debug`` or ``TASK_AGENT_DEBUG=1`` for full tracebacks. 
""" - import typer - seen: set[int] = set() current: BaseException | None = exc while current and id(current) not in seen: diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index 73d37a34..26345b75 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -31,7 +31,7 @@ from agents.agent import ModelSettings from agents.exceptions import AgentsException, MaxTurnsExceeded from agents.extensions.handoff_prompt import prompt_with_handoff_instructions -from openai import APITimeoutError, BadRequestError, RateLimitError +from openai import APIConnectionError, APITimeoutError, BadRequestError, RateLimitError from openai.types.responses import ResponseTextDeltaEvent from .agent import DEFAULT_MODEL, TaskAgent, TaskAgentHooks, TaskRunHooks @@ -74,20 +74,13 @@ def _resolve_model_config( """ m_config: ModelConfigDocument = available_tools.get_model_config(model_config_ref) model_dict: dict[str, str] = m_config.models or {} - if model_dict and not isinstance(model_dict, dict): - raise ValueError(f"Models section of the model_config file {model_config_ref} must be a dictionary") model_keys: list[str] = list(model_dict.keys()) models_params: dict[str, dict[str, Any]] = m_config.model_settings or {} - if models_params and not isinstance(models_params, dict): - raise ValueError(f"Settings section of model_config file {model_config_ref} must be a dictionary") unknown = set(models_params) - set(model_keys) if unknown: raise ValueError( f"Settings section of model_config file {model_config_ref} contains models not in the model section: {unknown}" ) - for k, v in models_params.items(): - if not isinstance(v, dict): - raise ValueError(f"Settings for model {k} in model_config file {model_config_ref} is not a dictionary") return model_keys, model_dict, models_params, m_config.api_type @@ -299,10 +292,17 @@ async def deploy_task_agents( # Model settings parallel_tool_calls = 
bool(os.getenv("MODEL_PARALLEL_TOOL_CALLS")) model_params: dict[str, Any] = { - "temperature": os.getenv("MODEL_TEMP", default=0.0), "tool_choice": "auto" if toolboxes else None, "parallel_tool_calls": parallel_tool_calls if toolboxes else None, } + # Only inject a default temperature for chat_completions; the responses + # API rejects unsupported parameters. MODEL_TEMP env override applies + # to both API types. + model_temp = os.getenv("MODEL_TEMP") + if model_temp is not None: + model_params["temperature"] = model_temp + elif api_type != "responses": + model_params["temperature"] = 0.0 model_params.update(model_par) model_settings = ModelSettings(**model_params) @@ -660,7 +660,11 @@ async def _deploy(ra: dict, pp: str) -> bool: complete = result and complete return complete - # Execute the task with auto-retry on failure + # Execute the task with auto-retry on transient failures. + # Only retry on network/API errors — deterministic failures + # and errors after side-effectful work should not be retried + # blindly (e.g. repeat_prompt tasks may have already written + # data to external systems). 
task_name = task.name or f"task-{task_index}" task_complete = False last_task_error: BaseException | None = None @@ -675,7 +679,7 @@ async def _deploy(ra: dict, pp: str) -> bool: break except (KeyboardInterrupt, SystemExit): raise - except Exception as exc: + except (APIConnectionError, APITimeoutError, ConnectionError, TimeoutError) as exc: last_task_error = exc remaining = TASK_RETRY_LIMIT - attempt - 1 if remaining > 0: @@ -688,6 +692,10 @@ async def _deploy(ra: dict, pp: str) -> bool: await asyncio.sleep(backoff) else: logging.error(f"Task {task_name!r} failed after {TASK_RETRY_LIMIT} attempts: {exc}") + except Exception as exc: + last_task_error = exc + logging.error(f"Task {task_name!r} failed (non-retriable): {exc}") + break # If all retries exhausted with an exception, save and re-raise if last_task_error is not None: diff --git a/tests/test_runner.py b/tests/test_runner.py index ca509281..a50c0f2a 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -107,14 +107,10 @@ def test_model_settings_extraction(self): assert params == {"m1": {"temperature": 0.5}} def test_validation_error_on_non_dict_settings(self): - """Non-dict model_settings raises ValueError.""" - at = _mock_available_tools() - cfg = _make_model_config(models={"m1": "p-m1"}) - # Manually set model_settings to a non-dict after construction - object.__setattr__(cfg, "model_settings", "not-a-dict") - at.get_model_config.return_value = cfg - with pytest.raises(ValueError, match="must be a dictionary"): - _resolve_model_config(at, "ref") + """Pydantic rejects non-dict model_settings at parse time.""" + from pydantic import ValidationError + with pytest.raises(ValidationError): + _make_model_config(models={"m1": "p-m1"}, model_settings="not-a-dict") # ===================================================================