Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
from collections.abc import Callable
from typing import Any
from urllib.parse import urlparse

from agents import (
Agent,
Expand All @@ -26,7 +25,7 @@
from dotenv import find_dotenv, load_dotenv
from openai import AsyncOpenAI

from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token
from .capi import get_AI_endpoint, get_AI_token, get_provider

__all__ = [
"DEFAULT_MODEL",
Expand All @@ -39,17 +38,8 @@
load_dotenv(find_dotenv(usecwd=True))

api_endpoint = get_AI_endpoint()
# All endpoint-specific defaults live in the provider registry (capi.get_provider),
# so adding a new endpoint no longer requires touching this module.
_default_provider = get_provider(api_endpoint)
# An explicit COPILOT_DEFAULT_MODEL env var always wins over the provider default.
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=_default_provider.default_model)


class TaskRunHooks(RunHooks):
Expand Down Expand Up @@ -186,10 +176,12 @@ def __init__(
else:
resolved_token = get_AI_token()

# Only send provider-specific headers to matching endpoints
provider = get_provider(resolved_endpoint)
client = AsyncOpenAI(
base_url=resolved_endpoint,
api_key=resolved_token,
default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
default_headers=provider.extra_headers or None,
)
set_tracing_disabled(True)
self.run_hooks = run_hooks or TaskRunHooks()
Expand Down
266 changes: 172 additions & 94 deletions src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,159 @@
# SPDX-FileCopyrightText: GitHub, Inc.
# SPDX-License-Identifier: MIT

"""AI API endpoint and token management (CAPI integration)."""
"""AI API endpoint and token management.

Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
custom endpoints). All provider-specific behaviour is captured in a single
``APIProvider`` dataclass so that adding a new provider only requires one
registry entry instead of changes scattered across multiple match/case blocks.
"""

from __future__ import annotations

import json
import logging
import os
from collections.abc import Mapping
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any
from urllib.parse import urlparse

import httpx
from strenum import StrEnum

# Public API of this module. NOTE: "AI_API_ENDPOINT_ENUM" must not be listed —
# the enum was replaced by the APIProvider registry, and exporting a missing
# name makes `from ... import *` raise AttributeError.
__all__ = [
    "APIProvider",
    "COPILOT_INTEGRATION_ID",
    "get_AI_endpoint",
    "get_AI_token",
    "get_provider",
    "list_capi_models",
    "list_tool_call_models",
    "supports_tool_calls",
]

# Value for the "Copilot-Integration-Id" header sent to the Copilot API.
# Defaults to "vscode-chat"; override via the COPILOT_INTEGRATION_ID env var.
COPILOT_INTEGRATION_ID = os.getenv("COPILOT_INTEGRATION_ID", "vscode-chat")


# ---------------------------------------------------------------------------
# Provider abstraction
# ---------------------------------------------------------------------------

@dataclass(frozen=True)
class APIProvider:
    """Encapsulates all endpoint-specific behaviour in one place.

    Attributes:
        name: Short identifier for the provider (e.g. ``"copilot"``).
        base_url: Root URL requests are made against; normalised to end in "/".
        models_catalog: Path of the model-catalog endpoint, joined onto base_url.
        default_model: Model id used when ``COPILOT_DEFAULT_MODEL`` is not set.
        extra_headers: Provider-specific HTTP headers (frozen after init).
    """

    name: str
    base_url: str
    models_catalog: str = "/models"
    default_model: str = "gpt-4.1"
    extra_headers: Mapping[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Ensure base_url ends with / so httpx URL.join() preserves the path.
        # (frozen dataclass: mutate via object.__setattr__)
        if self.base_url and not self.base_url.endswith("/"):
            object.__setattr__(self, "base_url", self.base_url + "/")
        # Freeze mutable headers so singleton providers can't be mutated.
        if isinstance(self.extra_headers, dict):
            object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers))

    # -- response parsing -----------------------------------------------------

    def parse_models_list(self, body: Any) -> list[dict]:
        """Extract the models list from a catalog response body.

        Accepts either a bare JSON list or an OpenAI-style ``{"data": [...]}``
        envelope; any other shape yields an empty list.
        """
        if isinstance(body, list):
            return body
        if isinstance(body, dict):
            data = body.get("data", [])
            return data if isinstance(data, list) else []
        return []

    # -- tool-call capability check -------------------------------------------

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if *model* supports tool calls according to its catalog entry."""
        # Default: optimistically assume support when present in catalog.
        return bool(model_info)

class _CopilotProvider(APIProvider):
    """GitHub Copilot API (api.githubcopilot.com)."""

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if the catalog entry advertises tool-call support."""
        # Copilot catalog shape: {"capabilities": {"supports": {"tool_calls": ...}}}.
        # Wrap in bool() so a truthy non-bool catalog value still honours the
        # declared -> bool contract.
        return bool(
            model_info
            .get("capabilities", {})
            .get("supports", {})
            .get("tool_calls", False)
        )


class _GitHubModelsProvider(APIProvider):
    """GitHub Models API (models.github.ai).

    The catalog endpoint returns a bare JSON list, which the base
    ``APIProvider.parse_models_list`` already handles, so no parsing
    override is needed here.
    """

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if the catalog entry lists the 'tool-calling' capability."""
        # Models API shape: {"capabilities": ["tool-calling", ...]}.
        return "tool-calling" in model_info.get("capabilities", [])


class _OpenAIProvider(APIProvider):
    """OpenAI API (api.openai.com).

    The OpenAI /v1/models catalog does not expose capability metadata, so
    we maintain a prefix allowlist of known chat-completion model families.
    """

    # Model-id prefixes of families assumed to support tool calls.
    _CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True if the model id matches a known chat-capable family."""
        # str.startswith accepts a tuple of prefixes — one call instead of
        # a Python-level any() loop. An absent "id" yields "" → no match.
        return model_info.get("id", "").lower().startswith(self._CHAT_PREFIXES)
# ---------------------------------------------------------------------------
# Provider registry — add new providers here
# ---------------------------------------------------------------------------

# Each provider is keyed by the netloc of its canonical base URL, which is
# exactly what get_provider() extracts from a configured endpoint.
_PROVIDERS: dict[str, APIProvider] = {
    urlparse(provider.base_url).netloc: provider
    for provider in (
        _CopilotProvider(
            name="copilot",
            base_url="https://api.githubcopilot.com",
            default_model="gpt-4.1",
            extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
        ),
        _GitHubModelsProvider(
            name="github-models",
            base_url="https://models.github.ai/inference",
            models_catalog="/catalog/models",
            default_model="openai/gpt-4.1",
        ),
        _OpenAIProvider(
            name="openai",
            base_url="https://api.openai.com/v1",
            models_catalog="/v1/models",
            default_model="gpt-4.1",
        ),
    )
}

def get_provider(endpoint: str | None = None) -> APIProvider:
    """Return the ``APIProvider`` for the given (or configured) endpoint URL."""
    target = endpoint if endpoint else get_AI_endpoint()
    try:
        return _PROVIDERS[urlparse(target).netloc]
    except KeyError:
        # Unknown endpoint — fall back to a generic provider for that URL.
        return APIProvider(
            name="custom",
            base_url=target,
            default_model="please-set-default-model-via-env",
        )


# ---------------------------------------------------------------------------
# Endpoint / token helpers
# ---------------------------------------------------------------------------

# you can also set https://api.githubcopilot.com if you prefer
# but beware that your taskflows need to reference the correct model id
# since different APIs use their own id schema, use -l with your desired
# endpoint to retrieve the correct id names to use for your taskflow
def get_AI_endpoint() -> str:
    """Return the configured AI API endpoint URL (``AI_API_ENDPOINT`` env var)."""
    # The GitHub Models inference endpoint is the fallback when unset.
    configured = os.getenv("AI_API_ENDPOINT")
    return configured if configured is not None else "https://models.github.ai/inference"
Expand All @@ -64,82 +170,54 @@ def get_AI_token() -> str:
raise RuntimeError("AI_API_TOKEN environment variable is not set.")


# ---------------------------------------------------------------------------
# Model catalog
# ---------------------------------------------------------------------------

def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Retrieve available models from the configured API endpoint.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).

    Returns:
        Mapping of model id -> catalog entry; empty on any request/parse error
        (best-effort, errors are logged rather than raised).
    """
    provider = get_provider(endpoint)
    base = provider.base_url
    models: dict[str, dict] = {}
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
            # Provider-specific headers (e.g. Copilot-Integration-Id) are only
            # sent to the endpoint that expects them.
            **provider.extra_headers,
        }
        # NOTE: models_catalog values start with "/" so URL.join resolves them
        # against the host root (e.g. models.github.ai/catalog/models), not
        # relative to base_url's path — do not drop the leading slash.
        r = httpx.get(
            httpx.URL(base).join(provider.models_catalog),
            headers=headers,
        )
        r.raise_for_status()
        # Each provider knows its own catalog response shape.
        for model in provider.parse_models_list(r.json()):
            models[model.get("id")] = dict(model)
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
        logging.exception("Failed to list models from %s", base)
    return models


def supports_tool_calls(
    model: str,
    models: dict[str, dict],
    endpoint: str | None = None,
) -> bool:
    """Check whether *model* supports tool calls.

    Args:
        model: Model id to check.
        models: Catalog mapping as returned by ``list_capi_models``.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    provider = get_provider(endpoint)
    # A model missing from the catalog is passed as {}; each provider's
    # check_tool_calls decides what that means.
    return provider.check_tool_calls(model, models.get(model, {}))


def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Return only the catalog models that support tool calls.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    models = list_capi_models(token, endpoint)
    # Resolve the provider once here rather than once per model via
    # supports_tool_calls().
    provider = get_provider(endpoint)
    return {
        mid: info
        for mid, info in models.items()
        if provider.check_tool_calls(mid, info)
    }
Loading
Loading