Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ ignore = [
"A002", # Argument shadows built-in
"A004", # Import shadows built-in
"FBT001", # Boolean positional arg
"FBT002", # Boolean default value
"N801", # Class name casing
"N802", # Function name casing
"N806", # Variable casing
Expand Down
1 change: 1 addition & 0 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
name: str = "TaskAgent",
instructions: str = "",
handoffs: list[Any] | None = None,
*,
exclude_from_context: bool = False,
mcp_servers: list[Any] | None = None,
model: str = DEFAULT_MODEL,
Expand Down
1 change: 1 addition & 0 deletions src/seclab_taskflow_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def main(
str | None,
typer.Option("-t", "--taskflow", help="Taskflow module path (mutually exclusive with -p)."),
] = None,
*,
list_models: Annotated[
bool,
typer.Option("-l", "--list-models", help="List available tool-call models and exit."),
Expand Down
1 change: 1 addition & 0 deletions src/seclab_taskflow_agent/mcp_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def build_mcp_servers(
available_tools: AvailableTools,
toolboxes: list[str],
blocked_tools: list[str] | None = None,
*,
headless: bool = False,
) -> list[MCPServerEntry]:
"""Build MCP server instances for the given toolboxes.
Expand Down
10 changes: 6 additions & 4 deletions src/seclab_taskflow_agent/mcp_servers/codeql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
self,
codeql_cli=os.getenv("CODEQL_CLI", default="codeql"),
server_options=["--threads=0", "--quiet"],
*,
log_stderr=False,
):
self.server_options = server_options.copy()
Expand Down Expand Up @@ -406,7 +407,7 @@ def _bqrs_to_sarif(self, bqrs_path, query_info, max_paths=10):


class QueryServer(CodeQL):
def __init__(self, database: Path, keep_alive=False, log_stderr=False):
def __init__(self, database: Path, *, keep_alive=False, log_stderr=False):
super().__init__(log_stderr=log_stderr)
self.database = database
self.keep_alive = keep_alive
Expand Down Expand Up @@ -476,7 +477,7 @@ def _file_uri_to_path(uri):
return path, region


def _get_source_prefix(database_path: Path, strip_leading_slash=True) -> str:
def _get_source_prefix(database_path: Path, *, strip_leading_slash=True) -> str:
# grab the source prefix from codeql-database.yml
db_yml_path = Path(database_path) / Path("codeql-database.yml")
with open(db_yml_path) as stream:
Expand All @@ -491,7 +492,7 @@ def _get_source_prefix(database_path: Path, strip_leading_slash=True) -> str:
raise


def list_src_files(database_path: str | Path, as_uri=False, strip_prefix=True):
def list_src_files(database_path: str | Path, *, as_uri=False, strip_prefix=True):
src_path = Path(database_path) / Path("src.zip")
files = shell_command_to_string(["zipinfo", "-1", src_path]).split("\n")
source_prefix = _get_source_prefix(Path(database_path))
Expand All @@ -503,7 +504,7 @@ def list_src_files(database_path: str | Path, as_uri=False, strip_prefix=True):
return files


def search_in_src_archive(database_path: str, search_term: str, as_uri=False, strip_prefix=True):
def search_in_src_archive(database_path: str, search_term: str, *, as_uri=False, strip_prefix=True):
database_path = Path(database_path)
src_path = database_path / Path("src.zip")
results = {}
Expand Down Expand Up @@ -595,6 +596,7 @@ def run_query(
target="",
progress_callback=None,
template_values=None,
*,
# keep the query server alive if desired
keep_alive=True,
log_stderr=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ def __init__(
rpc: RPC,
name: str = "watchdog",
interval: float = 0.1,
*,
daemon: bool = False,
start: bool = True,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/seclab_taskflow_agent/render_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def flush_async_output(task_id: str) -> None:
await render_model_output(data)


async def render_model_output(data: str, log: bool = True, async_task: bool = False, task_id: str | None = None) -> None:
async def render_model_output(data: str, *, log: bool = True, async_task: bool = False, task_id: str | None = None) -> None:
"""Print model output to the console, optionally buffering for async tasks."""
async with async_output_lock:
if async_task and task_id:
Expand Down
4 changes: 2 additions & 2 deletions src/seclab_taskflow_agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ async def deploy_task_agents(
model_settings = ModelSettings(**model_params)

# Build MCP servers and collect server prompts
entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless)
entries = build_mcp_servers(available_tools, toolboxes, blocked_tools, headless=headless)
mcp_params = mcp_client_params(available_tools, toolboxes)
server_prompts = [sp for _, (_, _, sp, _) in mcp_params.items()]

Expand Down Expand Up @@ -585,7 +585,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo
available_tools, global_variables, inputs,
)

async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) -> bool:
async def run_prompts(*, async_task: bool = False, max_concurrent_tasks: int = 5) -> bool:
if run:
await render_model_output("** 🤖🐚 Executing Shell Task\n")
try:
Expand Down
2 changes: 1 addition & 1 deletion src/seclab_taskflow_agent/template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_source(
raise jinja2.TemplateNotFound(template)


def env_function(var_name: str, default: Optional[str] = None, required: bool = True) -> str:
def env_function(var_name: str, default: Optional[str] = None, *, required: bool = True) -> str:
"""Jinja2 function to access environment variables.

Args:
Expand Down
Loading