diff --git a/README.md b/README.md index abd3532..70e80af 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Instead of generating CodeQL queries itself, the CodeQL MCP Server is used to pr ## Requirements -Python >= 3.9 or Docker +Python >= 3.10 or Docker ## Configuration diff --git a/pyproject.toml b/pyproject.toml index f15b4e1..fed0b54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "seclab-taskflow-agent" dynamic = ["version"] description = "A taskflow agent for the SecLab project, enabling secure and automated workflow execution." readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license = "MIT" keywords = [] authors = [ @@ -16,8 +16,6 @@ authors = [ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/release_tools/copy_files.py b/release_tools/copy_files.py index c6fd68c..0c7561c 100644 --- a/release_tools/copy_files.py +++ b/release_tools/copy_files.py @@ -3,8 +3,8 @@ import os import shutil -import sys import subprocess +import sys def read_file_list(list_path): @@ -12,7 +12,7 @@ def read_file_list(list_path): Reads a file containing file paths, ignoring empty lines and lines starting with '#'. Returns a list of relative file paths. """ - with open(list_path, "r") as f: + with open(list_path) as f: lines = [line.strip() for line in f] return [line for line in lines if line and not line.startswith("#")] diff --git a/release_tools/publish_docker.py b/release_tools/publish_docker.py index 494fe9c..ee56e98 100644 --- a/release_tools/publish_docker.py +++ b/release_tools/publish_docker.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import os -import shutil import subprocess import sys diff --git a/src/seclab_taskflow_agent/__main__.py b/src/seclab_taskflow_agent/__main__.py index ff1a07e..b3a0fc5 100644 --- a/src/seclab_taskflow_agent/__main__.py +++ b/src/seclab_taskflow_agent/__main__.py @@ -1,48 +1,46 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import asyncio -from threading import Thread import argparse -import os -import sys -from dotenv import load_dotenv, find_dotenv +import asyncio +import json import logging -from logging.handlers import RotatingFileHandler -from pprint import pprint, pformat +import os +import pathlib import re -import json +import sys import uuid -import pathlib +from collections.abc import Callable +from logging.handlers import RotatingFileHandler +from pprint import pformat -from .agent import DEFAULT_MODEL, TaskRunHooks, TaskAgentHooks +from agents import Agent, RunContextWrapper, TContext, Tool +from agents.agent import ModelSettings # from agents.run import DEFAULT_MAX_TURNS # XXX: this is 10, we need more than that -from agents.exceptions import MaxTurnsExceeded, AgentsException -from agents.agent import ModelSettings -from agents.mcp import MCPServer, MCPServerStdio, MCPServerSse, MCPServerStreamableHttp, create_static_tool_filter +from agents.exceptions import AgentsException, MaxTurnsExceeded from agents.extensions.handoff_prompt import prompt_with_handoff_instructions -from agents import Tool, RunContextWrapper, TContext, Agent -from openai import BadRequestError, APITimeoutError, RateLimitError +from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter +from dotenv import find_dotenv, load_dotenv +from openai import APITimeoutError, BadRequestError, RateLimitError from openai.types.responses import ResponseTextDeltaEvent -from typing import Callable -from .shell_utils import shell_tool_call +from .agent import DEFAULT_MODEL, TaskAgent, TaskAgentHooks, TaskRunHooks +from .available_tools import AvailableTools +from .capi import get_AI_token, list_tool_call_models +from .env_utils import TmpEnv from .mcp_utils import ( DEFAULT_MCP_CLIENT_SESSION_TIMEOUT, - ReconnectingMCPServerStdio, MCPNamespaceWrap, - mcp_client_params, - mcp_system_prompt, + ReconnectingMCPServerStdio, StreamableMCPThread, compress_name, + mcp_client_params, + mcp_system_prompt, ) -from .render_utils import render_model_output, flush_async_output -from .env_utils import TmpEnv -from .agent import TaskAgent -from .capi import list_tool_call_models, get_AI_token -from .available_tools import AvailableTools from .path_utils import log_file_name +from .render_utils import flush_async_output, render_model_output +from .shell_utils import shell_tool_call load_dotenv(find_dotenv(usecwd=True)) @@ -91,7 +89,7 @@ def parse_prompt_args(available_tools: AvailableTools, user_prompt: str | None = args = parser.parse_known_args(user_prompt.split(" ") if user_prompt else None) except SystemExit as e: if e.code == 2: - logging.error(f"User provided incomplete prompt: {user_prompt}") + logging.exception(f"User provided incomplete prompt: {user_prompt}") return None, None, None, None, help_msg p = args[0].p.strip() if args[0].p else None t = args[0].t.strip() if args[0].t else None @@ -258,14 +256,13 @@ async def mcp_session_task(mcp_servers: list, connected: asyncio.Event, cleanup: except Exception as e: print(f"Streamable mcp server process exception: {e}") except asyncio.CancelledError: - logging.error(f"Timeout on cleanup for mcp server: {server._name}") + logging.exception(f"Timeout on cleanup for mcp server: {server._name}") finally: mcp_servers.remove(s) except RuntimeError as e: - logging.error(f"RuntimeError in mcp session task: {e}") + logging.exception("RuntimeError in mcp session task") except asyncio.CancelledError as e: - logging.error(f"Timeout on main session task: {e}") - pass + logging.exception("Timeout on main session task") finally: mcp_servers.clear() @@ -353,17 +350,17 @@ async def _run_streamed(): return except APITimeoutError: if not max_retry: - logging.error(f"Max retries for APITimeoutError reached") + logging.exception("Max retries for APITimeoutError reached") raise max_retry -= 1 except RateLimitError: if rate_limit_backoff == MAX_RATE_LIMIT_BACKOFF: - raise APITimeoutError(f"Max rate limit backoff reached") + raise APITimeoutError("Max rate limit backoff reached") if rate_limit_backoff > MAX_RATE_LIMIT_BACKOFF: rate_limit_backoff = MAX_RATE_LIMIT_BACKOFF else: rate_limit_backoff += rate_limit_backoff - logging.error(f"Hit rate limit ... holding for {rate_limit_backoff}") + logging.exception(f"Hit rate limit ... holding for {rate_limit_backoff}") await asyncio.sleep(rate_limit_backoff) await _run_streamed() @@ -372,16 +369,16 @@ async def _run_streamed(): # raise exceptions up to here for anything that indicates a task failure except MaxTurnsExceeded as e: await render_model_output(f"** 🤖❗ Max Turns Reached: {e}\n", async_task=async_task, task_id=task_id) - logging.error(f"Exceeded max_turns: {max_turns}") + logging.exception(f"Exceeded max_turns: {max_turns}") except AgentsException as e: await render_model_output(f"** 🤖❗ Agent Exception: {e}\n", async_task=async_task, task_id=task_id) - logging.error(f"Agent Exception: {e}") + logging.exception("Agent Exception") except BadRequestError as e: await render_model_output(f"** 🤖❗ Request Error: {e}\n", async_task=async_task, task_id=task_id) - logging.error(f"Bad Request: {e}") + logging.exception("Bad Request") except APITimeoutError as e: await render_model_output(f"** 🤖❗ Timeout Error: {e}\n", async_task=async_task, task_id=task_id) - logging.error(f"Bad Request: {e}") + logging.exception("Bad Request") if async_task: await flush_async_output(task_id) @@ -392,14 +389,14 @@ async def _run_streamed(): # signal mcp sessions task that it can disconnect our servers start_cleanup.set() cleanup_attempts_left = len(mcp_servers) - while cleanup_attempts_left and len(mcp_servers): + while cleanup_attempts_left and mcp_servers: try: cleanup_attempts_left -= 1 await asyncio.wait_for(mcp_sessions, timeout=MCP_CLEANUP_TIMEOUT) - except asyncio.TimeoutError as e: + except asyncio.TimeoutError: continue except Exception as e: - logging.error(f"Exception in mcp server cleanup task: {e}") + logging.exception("Exception in mcp server cleanup task") async def main(available_tools: AvailableTools, p: str | None, t: str | None, cli_globals: dict, prompt: str | None): @@ -581,7 +578,7 @@ def preprocess_prompt(prompt: str, tag: str, kv: Callable[[str], dict], kv_subke async def run_prompts(async_task=False, max_concurrent_tasks=5): # if this is a shell task, execute that and append the results if run: - await render_model_output(f"** 🤖🐚 Executing Shell Task\n") + await render_model_output("** 🤖🐚 Executing Shell Task\n") # this allows e.g. shell based jq output to become available for repeat prompts try: result = shell_tool_call(run).content[0].model_dump_json() @@ -589,7 +586,7 @@ async def run_prompts(async_task=False, max_concurrent_tasks=5): return True except RuntimeError as e: await render_model_output(f"** 🤖❗ Shell Task Exception: {e}\n") - logging.error(f"Shell task error: {e}") + logging.exception("Shell task error") return False tasks = [] diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 3e6fc75..d3b1bf8 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -2,34 +2,32 @@ # SPDX-License-Identifier: MIT # https://openai.github.io/openai-agents-python/agents/ -import os import logging -from dotenv import load_dotenv, find_dotenv +import os from collections.abc import Callable from typing import Any from urllib.parse import urlparse -from openai import AsyncOpenAI -from agents.agent import ModelSettings, ToolsToFinalOutputResult, FunctionToolResult -from agents.run import DEFAULT_MAX_TURNS -from agents.run import RunHooks from agents import ( Agent, - Runner, AgentHooks, - RunHooks, - result, - function_tool, - Tool, + OpenAIChatCompletionsModel, RunContextWrapper, + RunHooks, + Runner, TContext, - OpenAIChatCompletionsModel, - set_default_openai_client, + Tool, + result, set_default_openai_api, + set_default_openai_client, set_tracing_disabled, ) +from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult +from agents.run import DEFAULT_MAX_TURNS, RunHooks +from dotenv import find_dotenv, load_dotenv +from openai import AsyncOpenAI -from .capi import COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token, AI_API_ENDPOINT_ENUM +from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token # grab our secrets from .env, this must be in .gitignore load_dotenv(find_dotenv(usecwd=True)) diff --git a/src/seclab_taskflow_agent/available_tools.py b/src/seclab_taskflow_agent/available_tools.py index 15808a7..d73f9d3 100644 --- a/src/seclab_taskflow_agent/available_tools.py +++ b/src/seclab_taskflow_agent/available_tools.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -from enum import Enum -import logging import importlib.resources +from enum import Enum + import yaml diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 2fc1d0d..7f29b63 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -2,13 +2,14 @@ # SPDX-License-Identifier: MIT # CAPI specific interactions -import httpx import json import logging import os -from strenum import StrEnum from urllib.parse import urlparse +import httpx +from strenum import StrEnum + # Enumeration of currently supported API endpoints. class AI_API_ENDPOINT_ENUM(StrEnum): @@ -96,11 +97,11 @@ def list_capi_models(token: str) -> dict[str, dict]: for model in models_list: models[model.get("id")] = dict(model) except httpx.RequestError as e: - logging.error(f"Request error: {e}") + logging.exception("Request error") except json.JSONDecodeError as e: - logging.error(f"JSON error: {e}") + logging.exception("JSON error") except httpx.HTTPStatusError as e: - logging.error(f"HTTP error: {e}") + logging.exception("HTTP error") return models diff --git a/src/seclab_taskflow_agent/env_utils.py b/src/seclab_taskflow_agent/env_utils.py index 48eceb3..509155b 100644 --- a/src/seclab_taskflow_agent/env_utils.py +++ b/src/seclab_taskflow_agent/env_utils.py @@ -1,8 +1,8 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import re import os +import re def swap_env(s): diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py index 6314072..d2bd10b 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py @@ -2,16 +2,18 @@ # SPDX-License-Identifier: MIT # a query-server2 codeql client -import subprocess -import re import json -from pathlib import Path +import os +import re +import subprocess import tempfile import time -from urllib.parse import urlparse, unquote -import os import zipfile +from pathlib import Path +from urllib.parse import unquote, urlparse + import yaml + from seclab_taskflow_agent.path_utils import log_file_name # this is a local fork of https://github.com/riga/jsonrpyc modified for our purposes @@ -276,7 +278,7 @@ def _search_path(self): def _search_paths_from_codeql_config(self, config="~/.config/codeql/config"): try: - with open(config, "r") as f: + with open(config) as f: match = re.search(r"^--search-path(\s+|=)\s*(.*)", f.read()) if match and match.group(2): return match.group(2).split(":") @@ -412,20 +414,19 @@ def __enter__(self): global _ACTIVE_CODEQL_SERVERS if self.database in _ACTIVE_CODEQL_SERVERS: return _ACTIVE_CODEQL_SERVERS[self.database] - else: - if not self.active_connection: - self._server_start() - print("Waiting for server start ...") - while not self.active_connection: - time.sleep(WAIT_INTERVAL) - if not self.active_database: - self._server_register_database(self.database) - print("Waiting for database registration ...") - while not self.active_database: - time.sleep(WAIT_INTERVAL) - if self.keep_alive: - _ACTIVE_CODEQL_SERVERS[self.database] = self - return self + if not self.active_connection: + self._server_start() + print("Waiting for server start ...") + while not self.active_connection: + time.sleep(WAIT_INTERVAL) + if not self.active_database: + self._server_register_database(self.database) + print("Waiting for database registration ...") + while not self.active_database: + time.sleep(WAIT_INTERVAL) + if self.keep_alive: + _ACTIVE_CODEQL_SERVERS[self.database] = self + return self def __exit__(self, exc_type, exc_val, exc_tb): if self.database not in _ACTIVE_CODEQL_SERVERS: @@ -534,7 +535,7 @@ def _file_from_src_archive(relative_path: str | Path, database_path: str | Path, # fall back to relative path if resolved_path does not exist (might be a build dep file) if str(resolved_path) not in files: resolved_path = Path(relative_path) - file_data = shell_command_to_string(["unzip", "-p", src_path, f"{str(resolved_path)}"]) + file_data = shell_command_to_string(["unzip", "-p", src_path, f"{resolved_path!s}"]) if region: def region_from_file(): diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py b/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py index 3407dbc..9a793eb 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__init__.py @@ -1,35 +1,32 @@ -# coding: utf-8 from __future__ import annotations __all__: list[str] = [] +import io +import json import os +import re import sys -import json -import io -import time import threading -import re -from typing import Any, Callable, Type, Protocol, Optional - -from typing_extensions import TypeAlias +import time +from collections.abc import Callable +from typing import Any, Optional, Protocol, Type, TypeAlias # package infos from .__meta__ import ( # noqa - __doc__, __author__, - __email__, + __contact__, __copyright__, __credits__, - __contact__, + __doc__, + __email__, __license__, __status__, __version__, ) - -Callback: TypeAlias = Callable[[Optional[Exception], Optional[Any]], None] +Callback: TypeAlias = Callable[[Exception | None, Any | None], None] class InputStream(Protocol): @@ -60,7 +57,7 @@ def write(self, b: str) -> int: ... def flush(self) -> None: ... -class Spec(object): +class Spec: """ This class wraps methods that create JSON-RPC 2.0 compatible string representations of request, response and error objects. All methods are class members, so you might never want to @@ -246,7 +243,7 @@ def error( return err -class RPC(object): +class RPC: """ The main class of *jsonrpyc*. Instances of this class wrap an input stream *stdin* and an output stream *stdout* in order to communicate with other services. A service is not even forced to be @@ -351,13 +348,13 @@ def __init__( if stdin is None: stdin = sys.stdin self.original_stdin = stdin - self.stdin = io.open(stdin.fileno(), "rb") + self.stdin = open(stdin.fileno(), "rb") # open output stream if stdout is None: stdout = sys.stdout self.original_stdout = stdout - self.stdout = io.open(stdout.fileno(), "wb") + self.stdout = open(stdout.fileno(), "wb") # other attributes self._i = -1 @@ -394,7 +391,7 @@ def call( *, callback: Callback | None = None, block: int = 0, - timeout: float | int = 0, + timeout: float = 0, params: dict | None = None, ) -> int: """ @@ -667,7 +664,7 @@ def __init__( self, rpc: RPC, name: str = "watchdog", - interval: float | int = 0.1, + interval: float = 0.1, daemon: bool = False, start: bool = True, ) -> None: @@ -729,12 +726,12 @@ def run(self) -> None: break # Keep linter happy - if self.rpc.original_stdin and self.rpc.original_stdin.closed: # type: ignore[attr-defined] # noqa + if self.rpc.original_stdin and self.rpc.original_stdin.closed: # type: ignore[attr-defined] break try: line = self.rpc.stdin.readline() - except IOError: + except OSError: line = None if line: @@ -821,11 +818,11 @@ def __str__(self) -> str: return self.message -error_map_code: dict[int, Type[RPCError]] = {} -error_map_code_range: dict[tuple[int, int], Type[RPCError]] = {} +error_map_code: dict[int, type[RPCError]] = {} +error_map_code_range: dict[tuple[int, int], type[RPCError]] = {} -def register_error(cls: Type[RPCError]) -> Type[RPCError]: +def register_error(cls: type[RPCError]) -> type[RPCError]: """ Decorator that registers a new RPC error derived from :py:class:`RPCError`. The purpose of error registration is to have a mapping of error codes/code ranges to error classes for faster @@ -855,7 +852,7 @@ class MyCustomRPCError(RPCError): return cls -def get_error(code: int) -> Type[RPCError]: +def get_error(code: int) -> type[RPCError]: """ Returns the RPC error class that was previously registered to *code*. A ``ValueError`` is raised if no error class was found for *code*. diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__meta__.py b/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__meta__.py index 2c192eb..0ee8a89 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__meta__.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/jsonrpyc/__meta__.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Minimal python RPC implementation in a single file based on the JSON-RPC 2.0 specs from diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py index ad8acf2..a5fdf7b 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py @@ -1,21 +1,19 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import logging -from .client import run_query, file_from_uri, list_src_files, _debug_log, search_in_src_archive -from pydantic import Field - -# from mcp.server.fastmcp import FastMCP, Context -from fastmcp import FastMCP, Context # use FastMCP 2.0 -from pathlib import Path -import os import csv import json -import time +import logging import re -from urllib.parse import urlparse, unquote -import zipfile -from seclab_taskflow_agent.path_utils import mcp_data_dir, log_file_name +from pathlib import Path + +# from mcp.server.fastmcp import FastMCP, Context +from fastmcp import FastMCP # use FastMCP 2.0 +from pydantic import Field + +from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir + +from .client import _debug_log, file_from_uri, list_src_files, run_query, search_in_src_archive logging.basicConfig( level=logging.DEBUG, @@ -137,8 +135,7 @@ def get_file_contents( try: # fix up any incorrectly formatted relative path uri if not file_uri.startswith("file:///"): - if file_uri.startswith("file://"): - file_uri = file_uri[len("file://") :] + file_uri = file_uri.removeprefix("file://") file_uri = "file:///" + file_uri.lstrip("/") results = _get_file_contents(database_path, file_uri) except Exception as e: diff --git a/src/seclab_taskflow_agent/mcp_servers/echo/echo.py b/src/seclab_taskflow_agent/mcp_servers/echo/echo.py index 9ffd470..9cd3bf3 100644 --- a/src/seclab_taskflow_agent/mcp_servers/echo/echo.py +++ b/src/seclab_taskflow_agent/mcp_servers/echo/echo.py @@ -5,6 +5,7 @@ # from mcp.server.fastmcp import FastMCP from fastmcp import FastMCP # move to FastMCP 2.0 + from seclab_taskflow_agent.path_utils import log_file_name logging.basicConfig( diff --git a/src/seclab_taskflow_agent/mcp_servers/logbook/logbook.py b/src/seclab_taskflow_agent/mcp_servers/logbook/logbook.py index b01abc2..9a11bf6 100644 --- a/src/seclab_taskflow_agent/mcp_servers/logbook/logbook.py +++ b/src/seclab_taskflow_agent/mcp_servers/logbook/logbook.py @@ -1,13 +1,14 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT +import json import logging +from pathlib import Path # from mcp.server.fastmcp import FastMCP from fastmcp import FastMCP # move to FastMCP 2.0 -import json -from pathlib import Path -from seclab_taskflow_agent.path_utils import mcp_data_dir, log_file_name + +from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir logging.basicConfig( level=logging.DEBUG, @@ -48,7 +49,7 @@ def inflate_log(): ensure_log() global LOG global LOGBOOK - with open(LOGBOOK, "r") as logbook: + with open(LOGBOOK) as logbook: LOG = json.loads(logbook.read()) diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py index ca7ba9b..1ebcc63 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache.py @@ -1,17 +1,18 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT +import json import logging +import os +from typing import Any # from mcp.server.fastmcp import FastMCP from fastmcp import FastMCP # move to FastMCP 2.0 -import json -from pathlib import Path -import os -from typing import Any + +from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir + from .memcache_backend.dictionary_file import MemcacheDictionaryFileBackend from .memcache_backend.sqlite import SqliteBackend -from seclab_taskflow_agent.path_utils import mcp_data_dir, log_file_name logging.basicConfig( level=logging.DEBUG, diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py index 48d9d01..04f9a8e 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/dictionary_file.py @@ -1,11 +1,12 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -from .backend import Backend import json from pathlib import Path from typing import Any +from .backend import Backend + class MemcacheDictionaryFileBackend(Backend): """A simple dictionary file backend for a memory cache.""" @@ -32,7 +33,7 @@ def _deflate_memory(self): def _inflate_memory(self): self._ensure_memory() - with open(self.memory, "r") as memory: + with open(self.memory) as memory: self.memcache = json.loads(memory.read()) def with_memory(self, f): @@ -68,8 +69,7 @@ def _delete_state(key: str) -> str: if key in self.memcache: del self.memcache[key] return f"Deleted key `{key}` from memory cache." - else: - return f"Key `{key}` not found in memory cache." + return f"Key `{key}` not found in memory cache." return _delete_state(key) @@ -87,11 +87,10 @@ def _add_state(key: str, value: Any) -> str: if type(existing) == type(value) and hasattr(existing, "__add__"): self.memcache[key] = existing + value return f"Updated and added to value in memory for key: `{key}`" - elif type(existing) == list: + if type(existing) == list: self.memcache[key].append(value) return f"Updated and added to value in memory for key: `{key}`" - else: - return f"Error: unsupported types for memcache add `{type(existing)} + {type(value)}` for key `{key}`" + return f"Error: unsupported types for memcache add `{type(existing)} + {type(value)}` for key `{key}`" return _add_state(key, value) diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sql_models.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sql_models.py index 5c0618e..89e6509 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sql_models.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sql_models.py @@ -1,9 +1,8 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -from sqlalchemy import String, Text, Integer, ForeignKey, Column -from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship -from typing import Optional +from sqlalchemy import Text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class Base(DeclarativeBase): diff --git a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py index 612ade7..be922c3 100644 --- a/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py +++ b/src/seclab_taskflow_agent/mcp_servers/memcache/memcache_backend/sqlite.py @@ -1,15 +1,16 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT +import json import os from pathlib import Path +from typing import Any + from sqlalchemy import create_engine from sqlalchemy.orm import Session -from typing import Any -import json -from .sql_models import KeyValue, Base from .backend import Backend +from .sql_models import Base, KeyValue class SqliteBackend(Backend): @@ -45,7 +46,7 @@ def get_state(self, key: str) -> Any: for r in results[1:]: existing.append(r) return existing - elif hasattr(existing, "__add__"): + if hasattr(existing, "__add__"): try: for r in results[1:]: existing += r @@ -78,8 +79,7 @@ def delete_state(self, key: str) -> str: session.commit() if result: return f"Deleted key `{key}` from memory cache." - else: - return f"Key `{key}` not found in memory cache." + return f"Key `{key}` not found in memory cache." def clear_cache(self) -> str: with Session(self.engine) as session: diff --git a/src/seclab_taskflow_agent/mcp_utils.py b/src/seclab_taskflow_agent/mcp_utils.py index d924d82..be4afa7 100644 --- a/src/seclab_taskflow_agent/mcp_utils.py +++ b/src/seclab_taskflow_agent/mcp_utils.py @@ -1,25 +1,24 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import logging import asyncio -from threading import Thread, Event +import hashlib import json -import subprocess -from typing import Optional, Callable -import shutil -import time +import logging import os +import shutil import socket -import signal -import hashlib +import subprocess +import time +from collections.abc import Callable +from threading import Event, Thread from urllib.parse import urlparse -from mcp.types import CallToolResult, TextContent from agents.mcp import MCPServerStdio +from mcp.types import CallToolResult, TextContent +from .available_tools import AvailableTools, AvailableToolType from .env_utils import swap_env -from .available_tools import AvailableToolType, AvailableTools DEFAULT_MCP_CLIENT_SESSION_TIMEOUT = 120 @@ -42,10 +41,10 @@ def __init__( self, cmd, url: str = "", - on_output: Optional[Callable[[str], None]] = None, - on_error: Optional[Callable[[str], None]] = None, + on_output: Callable[[str], None] | None = None, + on_error: Callable[[str], None] | None = None, poll_interval: float = 0.5, - env: Optional[dict[str, str]] = None, + env: dict[str, str] | None = None, ): super().__init__(daemon=True) self.url = url @@ -58,7 +57,7 @@ def __init__( self._stop_event = Event() self.process = None self.exit_code = None - self.exception: Optional[BaseException] = None + self.exception: BaseException | None = None async def async_wait_for_connection(self, timeout=30.0, poll_interval=0.5): parsed = urlparse(self.url) @@ -145,7 +144,7 @@ def stop(self): def is_running(self): return self.process and self.process.poll() is None - def join_and_raise(self, timeout: Optional[float] = None): + def join_and_raise(self, timeout: float | None = None): self.join(timeout) if self.is_alive(): raise RuntimeError("Process thread did not exit within timeout.") @@ -264,7 +263,7 @@ def confirm_tool(self, tool_name, args): ) if yn in ["yes", "y"]: return True - elif yn in ["no", "n"]: + if yn in ["no", "n"]: return False async def call_tool(self, *args, **kwargs): @@ -326,7 +325,7 @@ def mcp_client_params(available_tools: AvailableTools, requested_toolboxes: list for k, v in dict(optional_headers).items(): try: optional_headers[k] = swap_env(v) - except LookupError as e: + except LookupError: del optional_headers[k] if isinstance(headers, dict): if isinstance(optional_headers, dict): @@ -354,7 +353,7 @@ def mcp_client_params(available_tools: AvailableTools, requested_toolboxes: list for k, v in dict(optional_headers).items(): try: optional_headers[k] = swap_env(v) - except LookupError as e: + except LookupError: del optional_headers[k] if isinstance(headers, dict): if isinstance(optional_headers, dict): @@ -411,9 +410,9 @@ def mcp_system_prompt( server_prompts: list[str] = [], ): """Return a well constructed system prompt""" - prompt = """ + prompt = f""" {system_prompt} -""".format(system_prompt=system_prompt) +""" if tools: prompt += """ @@ -457,12 +456,12 @@ def mcp_system_prompt( """.format(server_prompts="\n\n".join(server_prompts)) if task: - prompt += """ + prompt += f""" # Primary Task to Complete {task} -""".format(task=task) +""" return prompt diff --git a/src/seclab_taskflow_agent/path_utils.py b/src/seclab_taskflow_agent/path_utils.py index 31e0254..9e6cfa5 100644 --- a/src/seclab_taskflow_agent/path_utils.py +++ b/src/seclab_taskflow_agent/path_utils.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import platformdirs import os from pathlib import Path +import platformdirs + def mcp_data_dir(packagename: str, mcpname: str, env_override: str | None) -> Path: """ diff --git a/src/seclab_taskflow_agent/render_utils.py b/src/seclab_taskflow_agent/render_utils.py index 98440d6..4e410cf 100644 --- a/src/seclab_taskflow_agent/render_utils.py +++ b/src/seclab_taskflow_agent/render_utils.py @@ -1,8 +1,9 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT -import logging import asyncio +import logging + from .path_utils import log_file_name async_output = {} @@ -19,9 +20,8 @@ async def flush_async_output(task_id: str): async with async_output_lock: if task_id not in async_output: raise ValueError(f"No async output for task: {task_id}") - else: - data = async_output[task_id] - del async_output[task_id] + data = async_output[task_id] + del async_output[task_id] await render_model_output(f"** 🤖✏️ Output for async task: {task_id}\n\n") await render_model_output(data) diff --git a/src/seclab_taskflow_agent/shell_utils.py b/src/seclab_taskflow_agent/shell_utils.py index ff8a497..7c7504c 100644 --- a/src/seclab_taskflow_agent/shell_utils.py +++ b/src/seclab_taskflow_agent/shell_utils.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: 2025 GitHub # SPDX-License-Identifier: MIT +import logging import subprocess import tempfile -import logging from mcp.types import CallToolResult, TextContent diff --git a/tests/test_api_endpoint_config.py b/tests/test_api_endpoint_config.py index a386c33..97498d9 100644 --- a/tests/test_api_endpoint_config.py +++ b/tests/test_api_endpoint_config.py @@ -5,10 +5,12 @@ Test API endpoint configuration. """ -import pytest import os from urllib.parse import urlparse -from seclab_taskflow_agent.capi import get_AI_endpoint, AI_API_ENDPOINT_ENUM, list_capi_models + +import pytest + +from seclab_taskflow_agent.capi import AI_API_ENDPOINT_ENUM, get_AI_endpoint, list_capi_models class TestAPIEndpoint: diff --git a/tests/test_cli_parser.py b/tests/test_cli_parser.py index 62bcab6..20fda8c 100644 --- a/tests/test_cli_parser.py +++ b/tests/test_cli_parser.py @@ -6,6 +6,7 @@ """ import pytest + from seclab_taskflow_agent.available_tools import AvailableTools diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py index ad8e4fe..4c7a01e 100644 --- a/tests/test_yaml_parser.py +++ b/tests/test_yaml_parser.py @@ -8,6 +8,7 @@ """ import pytest + from seclab_taskflow_agent.available_tools import AvailableTools