Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions release_tools/copy_files.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
# SPDX-FileCopyrightText: 2025 GitHub
# SPDX-License-Identifier: MIT
# ruff: noqa: T201, S607

import os
import shutil
import sys
import subprocess
import sys


def read_file_list(list_path):
"""
Reads a file containing file paths, ignoring empty lines and lines starting with '#'.
Returns a list of relative file paths.
"""
with open(list_path, "r") as f:
with open(list_path) as f:
lines = [line.strip() for line in f]
return [line for line in lines if line and not line.startswith("#")]

Expand Down
6 changes: 3 additions & 3 deletions release_tools/publish_docker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2025 GitHub
# SPDX-License-Identifier: MIT
# ruff: noqa: T201, S607

import os
import shutil
import subprocess
import sys


def get_image_digest(image_name, tag):
result = subprocess.run(
["docker", "buildx", "imagetools", "inspect", f"{image_name}:{tag}"],
Expand All @@ -28,7 +28,7 @@ def build_and_push_image(dest_dir, image_name, tag):
print(f"Pushed {image_name}:{tag}")
digest = get_image_digest(image_name, tag)
print(f"Image digest: {digest}")
with open("/tmp/digest.txt", "w") as f:
with open("/tmp/digest.txt", "w") as f: # noqa: S108
f.write(digest)

if __name__ == "__main__":
Expand Down
169 changes: 103 additions & 66 deletions src/seclab_taskflow_agent/__main__.py

Large diffs are not rendered by default.

56 changes: 36 additions & 20 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,32 @@
# SPDX-License-Identifier: MIT

# https://openai.github.io/openai-agents-python/agents/
import os
import logging
from dotenv import load_dotenv, find_dotenv
import os
from collections.abc import Callable
from typing import Any
from urllib.parse import urlparse

from openai import AsyncOpenAI
from agents.agent import ModelSettings, ToolsToFinalOutputResult, FunctionToolResult
from agents import (
Agent,
AgentHooks,
OpenAIChatCompletionsModel,
RunContextWrapper,
RunHooks,
Runner,
TContext,
Tool,
result,
set_default_openai_api,
set_default_openai_client,
set_tracing_disabled,
)
from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult
from agents.run import DEFAULT_MAX_TURNS
from agents.run import RunHooks
from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled
from dotenv import find_dotenv, load_dotenv
from openai import AsyncOpenAI

from .capi import COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token, AI_API_ENDPOINT_ENUM
from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token

# grab our secrets from .env, this must be in .gitignore
load_dotenv(find_dotenv(usecwd=True))
Expand Down Expand Up @@ -52,7 +64,7 @@ async def on_agent_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext]) -> None:
logging.debug(f"TaskRunHooks on_agent_start: {agent.name}")
logging.debug("TaskRunHooks on_agent_start: %s", agent.name)
if self._on_agent_start:
await self._on_agent_start(context, agent)

Expand All @@ -61,7 +73,7 @@ async def on_agent_end(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
output: Any) -> None:
logging.debug(f"TaskRunHooks on_agent_end: {agent.name}")
logging.debug("TaskRunHooks on_agent_end: %s", agent.name)
if self._on_agent_end:
await self._on_agent_end(context, agent, output)

Expand All @@ -70,7 +82,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool) -> None:
logging.debug(f"TaskRunHooks on_tool_start: {tool.name}")
logging.debug("TaskRunHooks on_tool_start: %s", tool.name)
if self._on_tool_start:
await self._on_tool_start(context, agent, tool)

Expand All @@ -80,7 +92,7 @@ async def on_tool_end(
agent: Agent[TContext],
tool: Tool,
result: str) -> None:
logging.debug(f"TaskRunHooks on_tool_end: {tool.name} ")
logging.debug("TaskRunHooks on_tool_end: %s", tool.name)
if self._on_tool_end:
await self._on_tool_end(context, agent, tool, result)

Expand All @@ -103,15 +115,15 @@ async def on_handoff(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
source: Agent[TContext]) -> None:
logging.debug(f"TaskAgentHooks on_handoff: {source.name} -> {agent.name}")
logging.debug("TaskAgentHooks on_handoff: %s -> %s", source.name, agent.name)
if self._on_handoff:
await self._on_handoff(context, agent, source)

async def on_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext]) -> None:
logging.debug(f"TaskAgentHooks on_start: {agent.name}")
logging.debug("TaskAgentHooks on_start: %s", agent.name)
if self._on_start:
await self._on_start(context, agent)

Expand All @@ -120,7 +132,7 @@ async def on_end(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
output: Any) -> None:
logging.debug(f"TaskAgentHooks on_end: {agent.name}")
logging.debug("TaskAgentHooks on_end: %s", agent.name)
if self._on_end:
await self._on_end(context, agent, output)

Expand All @@ -129,7 +141,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool) -> None:
logging.debug(f"TaskAgentHooks on_tool_start: {tool.name}")
logging.debug("TaskAgentHooks on_tool_start: %s", tool.name)
if self._on_tool_start:
await self._on_tool_start(context, agent, tool)

Expand All @@ -139,21 +151,25 @@ async def on_tool_end(
agent: Agent[TContext],
tool: Tool,
result: str) -> None:
logging.debug(f"TaskAgentHooks on_tool_end: {tool.name}")
logging.debug("TaskAgentHooks on_tool_end: %s", tool.name)
if self._on_tool_end:
await self._on_tool_end(context, agent, tool, result)

class TaskAgent:
def __init__(self,
name: str = 'TaskAgent',
instructions: str = '',
handoffs: list = [],
handoffs: list | None = None,
exclude_from_context: bool = False,
mcp_servers: dict = [],
mcp_servers: dict | None = None,
model: str = DEFAULT_MODEL,
model_settings: ModelSettings | None = None,
run_hooks: TaskRunHooks | None = None,
agent_hooks: TaskAgentHooks | None = None):
if handoffs is None:
handoffs = []
if mcp_servers is None:
mcp_servers = {}
client = AsyncOpenAI(base_url=api_endpoint,
api_key=get_AI_token(),
default_headers={'Copilot-Integration-Id': COPILOT_INTEGRATION_ID})
Expand All @@ -167,13 +183,13 @@ def __init__(self,
# openai/openai-agents-python/blob/main/examples/agent_patterns

# when we want to exclude tool results from context, we receive results here instead of sending to LLM
def _ToolsToFinalOutputFunction(context: RunContextWrapper[TContext],
def _tools_to_final_output_function(context: RunContextWrapper[TContext],
results: list[FunctionToolResult]) -> ToolsToFinalOutputResult:
return ToolsToFinalOutputResult(True, "Excluding tool results from LLM context")

self.agent = Agent(name=name,
instructions=instructions,
tool_use_behavior=_ToolsToFinalOutputFunction if exclude_from_context else 'run_llm_again',
tool_use_behavior=_tools_to_final_output_function if exclude_from_context else 'run_llm_again',
model=OpenAIChatCompletionsModel(model=model, openai_client=client),
handoffs=handoffs,
mcp_servers=mcp_servers,
Expand Down
23 changes: 12 additions & 11 deletions src/seclab_taskflow_agent/available_tools.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# SPDX-FileCopyrightText: 2025 GitHub
# SPDX-License-Identifier: MIT

from enum import Enum
import logging
import importlib.resources
from enum import Enum

import yaml


class BadToolNameError(Exception):
pass

class VersionException(Exception):
class VersionError(Exception):
pass

class FileTypeException(Exception):
class FileTypeError(Exception):
pass

class AvailableToolType(Enum):
Expand Down Expand Up @@ -71,19 +72,19 @@ def get_tool(self, tooltype: AvailableToolType, toolname: str):
header = y['seclab-taskflow-agent']
version = header['version']
if version != 1:
raise VersionException(str(version))
filetype = header['filetype']
raise VersionError(str(version))
filetype = header['filetype']
if filetype != tooltype.value:
raise FileTypeException(
raise FileTypeError(
f'Error in {f}: expected filetype to be {tooltype}, but it\'s {filetype}.')
if tooltype not in self.__yamlcache:
self.__yamlcache[tooltype] = {}
self.__yamlcache[tooltype][toolname] = y
return y
except ModuleNotFoundError as e:
raise BadToolNameError(f'Cannot load {toolname}: {e}')
except FileNotFoundError:
raise BadToolNameError(f'Cannot load {toolname}: {e}') from e
except FileNotFoundError as e:
# deal with editor temp files etc. that might have disappeared
raise BadToolNameError(f'Cannot load {toolname} because {f} is not a valid file.')
raise BadToolNameError(f'Cannot load {toolname} because {f} is not a valid file.') from e
except ValueError as e:
raise BadToolNameError(f'Cannot load {toolname}: {e}')
raise BadToolNameError(f'Cannot load {toolname}: {e}') from e
16 changes: 9 additions & 7 deletions src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# SPDX-License-Identifier: MIT

# CAPI specific interactions
import httpx
import json
import logging
import os
from strenum import StrEnum
from urllib.parse import urlparse

import httpx
from strenum import StrEnum


# Enumeration of currently supported API endpoints.
class AI_API_ENDPOINT_ENUM(StrEnum):
AI_API_MODELS_GITHUB = 'models.github.ai'
Expand All @@ -35,10 +37,10 @@ def to_url(self):
# but beware that your taskflows need to reference the correct model id
# since different APIs use their own id schema, use -l with your desired
# endpoint to retrieve the correct id names to use for your taskflow
def get_AI_endpoint():
def get_AI_endpoint(): # noqa: N802
return os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')

def get_AI_token():
def get_AI_token(): # noqa: N802
"""
Get the token for the AI API from the environment.
The environment variable can be named either AI_API_TOKEN
Expand Down Expand Up @@ -87,11 +89,11 @@ def list_capi_models(token: str) -> dict[str, dict]:
for model in models_list:
models[model.get('id')] = dict(model)
except httpx.RequestError as e:
logging.error(f"Request error: {e}")
logging.exception("Request error: %s", e)
except json.JSONDecodeError as e:
logging.error(f"JSON error: {e}")
logging.exception("JSON error: %s", e)
except httpx.HTTPStatusError as e:
logging.error(f"HTTP error: {e}")
logging.exception("HTTP error: %s", e)
return models

def supports_tool_calls(model: str, models: dict) -> bool:
Expand Down
5 changes: 3 additions & 2 deletions src/seclab_taskflow_agent/env_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# SPDX-FileCopyrightText: 2025 GitHub
# SPDX-License-Identifier: MIT

import re
import os
import re


def swap_env(s):
match = re.search(r"{{\s*(env)\s+([A-Z0-9_]+)\s*}}", s)
Expand All @@ -20,7 +21,7 @@ def __enter__(self):
os.environ[k] = swap_env(v)

def __exit__(self, exc_type, exc_val, exc_tb):
for k, v in self.env.items():
for k, _v in self.env.items():
del os.environ[k]
if k in self.restore_env:
os.environ[k] = self.restore_env[k]
Loading