Skip to content
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ dependencies = [
"starlette==0.49.1",
"strenum==0.4.15",
"tqdm==4.67.1",
"tenacity>=8.0.0",
"typer==0.16.0",
"types-requests==2.32.4.20250611",
"typing-inspection==0.4.1",
Expand Down
20 changes: 6 additions & 14 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,16 @@
)
from agents.agent import FunctionToolResult, ModelSettings, ToolsToFinalOutputResult
from agents.run import DEFAULT_MAX_TURNS
from dotenv import find_dotenv, load_dotenv
from openai import AsyncOpenAI

from .capi import get_AI_endpoint, get_AI_token, get_provider
from .capi import get_AI_endpoint, get_AI_token, get_default_model, get_provider

__all__ = [
"DEFAULT_MODEL",
"TaskAgent",
"TaskAgentHooks",
"TaskRunHooks",
]

# grab our secrets from .env, this must be in .gitignore
load_dotenv(find_dotenv(usecwd=True))

api_endpoint = get_AI_endpoint()
_default_provider = get_provider(api_endpoint)
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=_default_provider.default_model)


class TaskRunHooks(RunHooks):
"""RunHooks that monitor the entire lifetime of a runner, including across Agent handoffs."""
Expand Down Expand Up @@ -152,7 +143,7 @@ def __init__(
handoffs: list[Any] | None = None,
exclude_from_context: bool = False,
mcp_servers: list[Any] | None = None,
model: str = DEFAULT_MODEL,
model: str | None = None,
model_settings: ModelSettings | None = None,
api_type: str = "chat_completions",
endpoint: str | None = None,
Expand All @@ -168,7 +159,8 @@ def __init__(
token: Optional env var name whose value is used as the API key.
"""
# Resolve per-model endpoint and token, falling back to defaults
resolved_endpoint = endpoint or api_endpoint
resolved_endpoint = endpoint or get_AI_endpoint()
resolved_model = model or get_default_model(resolved_endpoint)
if token:
resolved_token = os.getenv(token, "")
if not resolved_token:
Expand All @@ -194,9 +186,9 @@ def _ToolsToFinalOutputFunction(

# Select model class based on api_type
if api_type == "responses":
model_impl = OpenAIResponsesModel(model=model, openai_client=client)
model_impl = OpenAIResponsesModel(model=resolved_model, openai_client=client)
else:
model_impl = OpenAIChatCompletionsModel(model=model, openai_client=client)
model_impl = OpenAIChatCompletionsModel(model=resolved_model, openai_client=client)

self.agent = Agent(
name=name,
Expand Down
6 changes: 6 additions & 0 deletions src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"APIProvider",
"get_AI_endpoint",
"get_AI_token",
"get_default_model",
"get_provider",
"list_capi_models",
"list_tool_call_models",
Expand Down Expand Up @@ -173,6 +174,11 @@ def get_provider(endpoint: str | None = None) -> APIProvider:
# Unknown endpoint — return a generic provider with the given base URL
return APIProvider(name="custom", base_url=url, default_model="please-set-default-model-via-env")

def get_default_model(endpoint: str | None = None) -> str:
    """Pick the default model name for *endpoint*.

    The provider is resolved first via ``get_provider(endpoint)``; its
    ``default_model`` is the fallback. If the ``COPILOT_DEFAULT_MODEL``
    environment variable is set, its value is returned instead.

    Args:
        endpoint: Endpoint passed through to ``get_provider``; may be ``None``.

    Returns:
        The environment override when present, else the provider default.
    """
    # Resolve the provider before reading the environment so provider lookup
    # happens on every call, whether or not the override is set.
    provider_default = get_provider(endpoint).default_model
    return os.environ.get("COPILOT_DEFAULT_MODEL", provider_default)


# ---------------------------------------------------------------------------
# Endpoint / token helpers
Expand Down
4 changes: 4 additions & 0 deletions src/seclab_taskflow_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from typing import Annotated

import typer
from dotenv import find_dotenv, load_dotenv

# Load .env early, before any provider/env-var reads happen in imported modules.
load_dotenv(find_dotenv(usecwd=True))

from .available_tools import AvailableTools
from .banner import get_banner
Expand Down
145 changes: 90 additions & 55 deletions src/seclab_taskflow_agent/mcp_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

from __future__ import annotations

__all__ = ["MCP_CLEANUP_TIMEOUT", "build_mcp_servers", "mcp_session_task"]
__all__ = ["MCP_CLEANUP_TIMEOUT", "build_mcp_servers", "mcp_session_task", "register_transport"]

import asyncio
import logging
from typing import TYPE_CHECKING

from typing import TYPE_CHECKING, Any, Callable

from agents.mcp import MCPServerSse, MCPServerStdio, MCPServerStreamableHttp, create_static_tool_filter

Expand Down Expand Up @@ -41,6 +42,84 @@ def __init__(self, server: MCPNamespaceWrap, process: StreamableMCPThread | None
self.name = name


# Type alias for a builder function
MCPServerBuilder = Callable[[str, dict[str, Any], Any, int, list[str]], MCPServerEntry]

MCP_TRANSPORT_REGISTRY: dict[str, MCPServerBuilder] = {}


def register_transport(kind: str) -> Callable[[MCPServerBuilder], MCPServerBuilder]:
    """Return a decorator that registers an MCP transport builder under *kind*.

    Args:
        kind: Transport identifier used as the key in ``MCP_TRANSPORT_REGISTRY``.

    Returns:
        A decorator that stores the builder in ``MCP_TRANSPORT_REGISTRY`` and
        returns the builder unchanged.

    Raises:
        ValueError: Raised by the returned decorator, at decoration time, when
            a builder is already registered for *kind*.
    """

    def decorator(builder: MCPServerBuilder) -> MCPServerBuilder:
        if kind in MCP_TRANSPORT_REGISTRY:
            existing = MCP_TRANSPORT_REGISTRY[kind]
            # Not every callable has __name__ (functools.partial objects and
            # callable instances do not); fall back to the type name so the
            # duplicate is reported as ValueError rather than AttributeError.
            existing_name = getattr(existing, "__name__", None) or type(existing).__name__
            raise ValueError(
                f"MCP transport {kind!r} is already registered by {existing_name!r}. "
                "Use a unique kind name or unregister the existing builder first."
            )
        MCP_TRANSPORT_REGISTRY[kind] = builder
        return builder

    return decorator


@register_transport("stdio")
def _build_stdio(
    tb: str,
    params: dict[str, Any],
    tool_filter: Any,
    client_session_timeout: int,
    confirms: list[str],
) -> MCPServerEntry:
    """Build the server entry for a ``stdio`` toolbox.

    Args:
        tb: Toolbox name; used as the server name and the entry name.
        params: Transport parameters. A truthy ``"reconnecting"`` key selects
            ``ReconnectingMCPServerStdio`` instead of ``MCPServerStdio``.
        tool_filter: Tool filter forwarded to the server constructor.
        client_session_timeout: Forwarded as ``client_session_timeout_seconds``.
        confirms: Forwarded as the first argument of ``MCPNamespaceWrap``.

    Returns:
        An ``MCPServerEntry`` wrapping the server, with ``None`` as its process.
    """
    # Both server classes take the same keyword arguments, so choose the class
    # first and construct it once.
    server_cls = ReconnectingMCPServerStdio if params.get("reconnecting", False) else MCPServerStdio
    stdio_server = server_cls(
        name=tb,
        params=params,
        tool_filter=tool_filter,
        client_session_timeout_seconds=client_session_timeout,
        cache_tools_list=True,
    )
    wrapped = MCPNamespaceWrap(confirms, stdio_server)
    return MCPServerEntry(wrapped, None, name=tb)


@register_transport("sse")
def _build_sse(
    tb: str,
    params: dict[str, Any],
    tool_filter: Any,
    client_session_timeout: int,
    confirms: list[str],
) -> MCPServerEntry:
    """Build the server entry for an ``sse`` toolbox.

    Args:
        tb: Toolbox name; used as the server name and the entry name.
        params: Transport parameters forwarded to ``MCPServerSse``.
        tool_filter: Tool filter forwarded to the server constructor.
        client_session_timeout: Forwarded as ``client_session_timeout_seconds``.
        confirms: Forwarded as the first argument of ``MCPNamespaceWrap``.

    Returns:
        An ``MCPServerEntry`` wrapping the server, with ``None`` as its process.
    """
    wrapped = MCPNamespaceWrap(
        confirms,
        MCPServerSse(
            name=tb,
            params=params,
            tool_filter=tool_filter,
            client_session_timeout_seconds=client_session_timeout,
        ),
    )
    return MCPServerEntry(wrapped, None, name=tb)


@register_transport("streamable")
def _build_streamable(
    tb: str,
    params: dict[str, Any],
    tool_filter: Any,
    client_session_timeout: int,
    confirms: list[str],
) -> MCPServerEntry:
    """Build the server entry for a ``streamable`` toolbox.

    When ``params`` contains a ``"command"`` key, a ``StreamableMCPThread`` is
    also constructed from ``params["command"]``, ``params["url"]`` and
    ``params["env"]`` and attached to the entry as its process. Without a
    ``"command"`` key the entry's process is ``None``.

    Args:
        tb: Toolbox name; used as the server name and the entry name.
        params: Transport parameters forwarded to ``MCPServerStreamableHttp``.
        tool_filter: Tool filter forwarded to the server constructor.
        client_session_timeout: Forwarded as ``client_session_timeout_seconds``.
        confirms: Forwarded as the first argument of ``MCPNamespaceWrap``.

    Returns:
        An ``MCPServerEntry`` wrapping the HTTP server plus the optional thread.

    Raises:
        KeyError: If ``"command"`` is present but ``"url"`` or ``"env"`` is not.
    """

    def _spawn_companion() -> StreamableMCPThread:
        # Both output streams of the companion are logged at INFO level.
        return StreamableMCPThread(
            params["command"],
            url=params["url"],
            env=params["env"],
            on_output=lambda line: logging.info(f"Streamable MCP Server stdout: {line}"),
            on_error=lambda line: logging.info(f"Streamable MCP Server stderr: {line}"),
        )

    companion = _spawn_companion() if "command" in params else None
    http_server = MCPServerStreamableHttp(
        name=tb,
        params=params,
        tool_filter=tool_filter,
        client_session_timeout_seconds=client_session_timeout,
    )
    return MCPServerEntry(MCPNamespaceWrap(confirms, http_server), companion, name=tb)


def build_mcp_servers(
available_tools: AvailableTools,
toolboxes: list[str],
Expand All @@ -66,59 +145,15 @@ def build_mcp_servers(
if headless:
confirms = []
client_session_timeout = client_session_timeout or DEFAULT_MCP_CLIENT_SESSION_TIMEOUT
server_proc = None

match params["kind"]:
case "stdio":
if params.get("reconnecting", False):
mcp_server = ReconnectingMCPServerStdio(
name=tb,
params=params,
tool_filter=tool_filter,
client_session_timeout_seconds=client_session_timeout,
cache_tools_list=True,
)
else:
mcp_server = MCPServerStdio(
name=tb,
params=params,
tool_filter=tool_filter,
client_session_timeout_seconds=client_session_timeout,
cache_tools_list=True,
)
case "sse":
mcp_server = MCPServerSse(
name=tb,
params=params,
tool_filter=tool_filter,
client_session_timeout_seconds=client_session_timeout,
)
case "streamable":
if "command" in params:

def _print_out(line: str) -> None:
logging.info(f"Streamable MCP Server stdout: {line}")

def _print_err(line: str) -> None:
logging.info(f"Streamable MCP Server stderr: {line}")

server_proc = StreamableMCPThread(
params["command"],
url=params["url"],
env=params["env"],
on_output=_print_out,
on_error=_print_err,
)
mcp_server = MCPServerStreamableHttp(
name=tb,
params=params,
tool_filter=tool_filter,
client_session_timeout_seconds=client_session_timeout,
)
case _:
raise ValueError(f"Unsupported MCP transport: {params['kind']}")

entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb))
kind = params.get("kind")
if kind is None:
raise ValueError(f"Missing 'kind' key in MCP params for toolbox {tb!r}")
builder = MCP_TRANSPORT_REGISTRY.get(kind)
if builder is None:
raise ValueError(f"Unsupported MCP transport: {kind!r}")

entry = builder(tb, params, tool_filter, client_session_timeout, confirms)
entries.append(entry)

return entries

Expand Down
99 changes: 99 additions & 0 deletions src/seclab_taskflow_agent/model_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: GitHub, Inc.
# SPDX-License-Identifier: MIT

"""Model resolution logic for Taskflow.

Extracts model configuration and task-level overrides.
"""

from __future__ import annotations

from typing import Any

from .available_tools import AvailableTools
from .capi import get_default_model
from .models import ModelConfigDocument, TaskDefinition


def resolve_model_config(
    available_tools: AvailableTools,
    model_config_ref: str,
) -> tuple[list[str], dict[str, str], dict[str, dict[str, Any]], str]:
    """Fetch a model config document and unpack it into lookup tables.

    Args:
        available_tools: Object whose ``get_model_config`` loads the document.
        model_config_ref: Reference name of the model config document.

    Returns:
        ``(model_keys, model_dict, models_params, api_type)``:
        ``model_dict`` is the document's ``models`` mapping (``{}`` when
        empty or unset), ``model_keys`` is the list of its keys,
        ``models_params`` is the document's ``model_settings`` mapping
        (``{}`` when empty or unset), and ``api_type`` is taken from the
        document as-is.

    Raises:
        ValueError: If ``model_settings`` names a model that is not a key of
            ``models``.
    """
    config_doc: ModelConfigDocument = available_tools.get_model_config(model_config_ref)
    model_dict: dict[str, str] = config_doc.models or {}
    models_params: dict[str, dict[str, Any]] = config_doc.model_settings or {}
    model_keys: list[str] = list(model_dict)

    # Every settings entry must belong to a declared model.
    orphaned = set(models_params) - set(model_keys)
    if not orphaned:
        return model_keys, model_dict, models_params, config_doc.api_type
    raise ValueError(
        f"Settings section of model_config file {model_config_ref} contains models not in the model section: {orphaned}"
    )


def resolve_task_model(
    task: TaskDefinition,
    model_keys: list[str],
    model_dict: dict[str, str],
    models_params: dict[str, dict[str, Any]],
    default_api_type: str = "chat_completions",
) -> tuple[str, dict[str, Any], str, str | None, str | None]:
    """Work out which model a task runs on and with which settings.

    Settings come from two layers: the config-level entry for the task's
    logical model name (if any), then the task's own ``model_settings``.
    The task layer wins on conflicts. The keys ``api_type``, ``endpoint``
    and ``token`` are lifted out of both layers and returned separately.

    Args:
        task: Task whose ``model``, ``model_settings`` and ``name`` are read.
        model_keys: Logical model names known to the model config.
        model_dict: Maps logical model names to provider model IDs.
        models_params: Per-logical-model settings from the model config.
        default_api_type: API type used when neither layer sets ``api_type``.

    Returns:
        ``(model_id, model_settings, api_type, endpoint, token)``; *endpoint*
        and *token* are ``None`` when neither layer provides them.

    Raises:
        ValueError: If the task's ``model_settings`` is truthy but not a dict.
    """
    raw_overrides: dict[str, Any] | Any = task.model_settings or {}
    if not isinstance(raw_overrides, dict):
        raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary")
    # Work on a copy so popping engine-level keys never mutates the task.
    overrides = dict(raw_overrides)

    # The task-level endpoint is read (not popped) before the model is chosen
    # so that the fallback default model is looked up for that endpoint.
    model_name: str = task.model or get_default_model(overrides.get("endpoint", None))

    # A logical name known to the config is translated to its provider ID;
    # any other name is passed through unchanged with no config settings.
    merged: dict[str, Any] = {}
    if model_name in model_keys:
        if model_name in models_params:
            merged = models_params[model_name].copy()
        model_name = model_dict[model_name]

    # Lift the engine-level keys out of each layer in precedence order:
    # config first, then task, so a task value replaces a config value.
    engine: dict[str, Any] = {"api_type": default_api_type, "endpoint": None, "token": None}
    for layer in (merged, overrides):
        for key in engine:
            engine[key] = layer.pop(key, engine[key])

    merged.update(overrides)
    return model_name, merged, engine["api_type"], engine["endpoint"], engine["token"]
Loading
Loading