Skip to content

Commit 1897f51

Browse files
authored
Merge pull request #187 from GitHubSecurityLab/anticomputer/capi-cleanup
Refactor capi.py: provider pattern for endpoint-specific logic
2 parents 78aeae9 + a0e4d3a commit 1897f51

File tree

4 files changed

+238
-158
lines changed

src/seclab_taskflow_agent/agent.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import os
77
from collections.abc import Callable
88
from typing import Any
9-
from urllib.parse import urlparse
109

1110
from agents import (
1211
Agent,
@@ -26,7 +25,7 @@
2625
from dotenv import find_dotenv, load_dotenv
2726
from openai import AsyncOpenAI
2827

29-
from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token
28+
from .capi import get_AI_endpoint, get_AI_token, get_provider
3029

3130
__all__ = [
3231
"DEFAULT_MODEL",
@@ -39,17 +38,8 @@
3938
load_dotenv(find_dotenv(usecwd=True))
4039

4140
api_endpoint = get_AI_endpoint()
42-
match urlparse(api_endpoint).netloc:
43-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
44-
default_model = "gpt-4.1"
45-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
46-
default_model = "openai/gpt-4.1"
47-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
48-
default_model = "gpt-4.1"
49-
case _:
50-
default_model = "please-set-default-model-via-env"
51-
52-
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=default_model)
41+
_default_provider = get_provider(api_endpoint)
42+
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=_default_provider.default_model)
5343

5444

5545
class TaskRunHooks(RunHooks):
@@ -186,10 +176,12 @@ def __init__(
186176
else:
187177
resolved_token = get_AI_token()
188178

179+
# Only send provider-specific headers to matching endpoints
180+
provider = get_provider(resolved_endpoint)
189181
client = AsyncOpenAI(
190182
base_url=resolved_endpoint,
191183
api_key=resolved_token,
192-
default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
184+
default_headers=provider.extra_headers or None,
193185
)
194186
set_tracing_disabled(True)
195187
self.run_hooks = run_hooks or TaskRunHooks()

src/seclab_taskflow_agent/capi.py

Lines changed: 172 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,159 @@
11
# SPDX-FileCopyrightText: GitHub, Inc.
22
# SPDX-License-Identifier: MIT
33

4-
"""AI API endpoint and token management (CAPI integration)."""
4+
"""AI API endpoint and token management.
5+
6+
Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+
custom endpoints). All provider-specific behaviour is captured in a single
8+
``APIProvider`` dataclass so that adding a new provider only requires one
9+
registry entry instead of changes scattered across multiple match/case blocks.
10+
"""
11+
12+
from __future__ import annotations
513

614
import json
715
import logging
816
import os
17+
from collections.abc import Mapping
18+
from dataclasses import dataclass, field
19+
from types import MappingProxyType
20+
from typing import Any
921
from urllib.parse import urlparse
1022

1123
import httpx
12-
from strenum import StrEnum
1324

1425
__all__ = [
15-
"AI_API_ENDPOINT_ENUM",
1626
"COPILOT_INTEGRATION_ID",
27+
"APIProvider",
1728
"get_AI_endpoint",
1829
"get_AI_token",
30+
"get_provider",
1931
"list_capi_models",
2032
"list_tool_call_models",
2133
"supports_tool_calls",
2234
]
2335

36+
COPILOT_INTEGRATION_ID = os.getenv("COPILOT_INTEGRATION_ID", "vscode-chat")
37+
38+
39+
# ---------------------------------------------------------------------------
40+
# Provider abstraction
41+
# ---------------------------------------------------------------------------
42+
43+
@dataclass(frozen=True)
44+
class APIProvider:
45+
"""Encapsulates all endpoint-specific behaviour in one place."""
46+
47+
name: str
48+
base_url: str
49+
models_catalog: str = "/models"
50+
default_model: str = "gpt-4.1"
51+
extra_headers: Mapping[str, str] = field(default_factory=dict)
52+
53+
def __post_init__(self) -> None:
54+
# Ensure base_url ends with / so httpx URL.join() preserves the path
55+
if self.base_url and not self.base_url.endswith("/"):
56+
object.__setattr__(self, "base_url", self.base_url + "/")
57+
# Freeze mutable headers so singleton providers can't be mutated
58+
if isinstance(self.extra_headers, dict):
59+
object.__setattr__(self, "extra_headers", MappingProxyType(self.extra_headers))
60+
61+
# -- response parsing -----------------------------------------------------
2462

25-
# Enumeration of currently supported API endpoints.
26-
class AI_API_ENDPOINT_ENUM(StrEnum):
27-
AI_API_MODELS_GITHUB = "models.github.ai"
28-
AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29-
AI_API_OPENAI = "api.openai.com"
63+
def parse_models_list(self, body: Any) -> list[dict]:
64+
"""Extract the models list from a catalog response body."""
65+
if isinstance(body, list):
66+
return body
67+
if isinstance(body, dict):
68+
data = body.get("data", [])
69+
return data if isinstance(data, list) else []
70+
return []
3071

31-
def to_url(self) -> str:
32-
"""Convert the endpoint to its full URL."""
33-
match self:
34-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
35-
return f"https://{self}"
36-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
37-
return f"https://{self}/inference"
38-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
39-
return f"https://{self}/v1"
40-
case _:
41-
raise ValueError(f"Unsupported endpoint: {self}")
72+
# -- tool-call capability check -------------------------------------------
4273

74+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
75+
"""Return True if *model* supports tool calls according to its catalog entry."""
76+
# Default: optimistically assume support when present in catalog
77+
return bool(model_info)
4378

44-
COPILOT_INTEGRATION_ID = "vscode-chat"
4579

80+
class _CopilotProvider(APIProvider):
81+
"""GitHub Copilot API (api.githubcopilot.com)."""
82+
83+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
84+
return (
85+
model_info
86+
.get("capabilities", {})
87+
.get("supports", {})
88+
.get("tool_calls", False)
89+
)
90+
91+
92+
class _GitHubModelsProvider(APIProvider):
93+
"""GitHub Models API (models.github.ai)."""
94+
95+
def parse_models_list(self, body: Any) -> list[dict]:
96+
# Models API returns a bare list, not {"data": [...]}
97+
if isinstance(body, list):
98+
return body
99+
return super().parse_models_list(body)
100+
101+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
102+
return "tool-calling" in model_info.get("capabilities", [])
103+
104+
105+
class _OpenAIProvider(APIProvider):
106+
"""OpenAI API (api.openai.com).
107+
108+
The OpenAI /v1/models catalog does not expose capability metadata, so
109+
we maintain a prefix allowlist of known chat-completion model families.
110+
"""
111+
112+
_CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")
113+
114+
def check_tool_calls(self, _model: str, model_info: dict) -> bool:
115+
model_id = model_info.get("id", "").lower()
116+
return any(model_id.startswith(p) for p in self._CHAT_PREFIXES)
117+
# ---------------------------------------------------------------------------
118+
# Provider registry — add new providers here
119+
# ---------------------------------------------------------------------------
120+
121+
_PROVIDERS: dict[str, APIProvider] = {
122+
"api.githubcopilot.com": _CopilotProvider(
123+
name="copilot",
124+
base_url="https://api.githubcopilot.com",
125+
default_model="gpt-4.1",
126+
extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
127+
),
128+
"models.github.ai": _GitHubModelsProvider(
129+
name="github-models",
130+
base_url="https://models.github.ai/inference",
131+
models_catalog="/catalog/models",
132+
default_model="openai/gpt-4.1",
133+
),
134+
"api.openai.com": _OpenAIProvider(
135+
name="openai",
136+
base_url="https://api.openai.com/v1",
137+
models_catalog="/v1/models",
138+
default_model="gpt-4.1",
139+
),
140+
}
141+
142+
def get_provider(endpoint: str | None = None) -> APIProvider:
143+
"""Return the ``APIProvider`` for the given (or configured) endpoint URL."""
144+
url = endpoint or get_AI_endpoint()
145+
netloc = urlparse(url).netloc
146+
provider = _PROVIDERS.get(netloc)
147+
if provider is not None:
148+
return provider
149+
# Unknown endpoint — return a generic provider with the given base URL
150+
return APIProvider(name="custom", base_url=url, default_model="please-set-default-model-via-env")
151+
152+
153+
# ---------------------------------------------------------------------------
154+
# Endpoint / token helpers
155+
# ---------------------------------------------------------------------------
46156

47-
# you can also set https://api.githubcopilot.com if you prefer
48-
# but beware that your taskflows need to reference the correct model id
49-
# since different APIs use their own id schema, use -l with your desired
50-
# endpoint to retrieve the correct id names to use for your taskflow
51157
def get_AI_endpoint() -> str:
52158
"""Return the configured AI API endpoint URL."""
53159
return os.getenv("AI_API_ENDPOINT", default="https://models.github.ai/inference")
@@ -64,82 +170,54 @@ def get_AI_token() -> str:
64170
raise RuntimeError("AI_API_TOKEN environment variable is not set.")
65171

66172

67-
# assume we are >= python 3.9 for our type hints
68-
def list_capi_models(token: str) -> dict[str, dict]:
69-
"""Retrieve a dictionary of available CAPI models"""
70-
models = {}
173+
# ---------------------------------------------------------------------------
174+
# Model catalog
175+
# ---------------------------------------------------------------------------
176+
177+
def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
178+
"""Retrieve available models from the configured API endpoint.
179+
180+
Args:
181+
token: Bearer token for authentication.
182+
endpoint: Optional endpoint URL override (defaults to env config).
183+
"""
184+
provider = get_provider(endpoint)
185+
base = provider.base_url
186+
models: dict[str, dict] = {}
71187
try:
72-
api_endpoint = get_AI_endpoint()
73-
netloc = urlparse(api_endpoint).netloc
74-
match netloc:
75-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
76-
models_catalog = "models"
77-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
78-
models_catalog = "catalog/models"
79-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
80-
models_catalog = "models"
81-
case _:
82-
# Unknown endpoint — try the OpenAI-style models catalog
83-
models_catalog = "models"
188+
headers = {
189+
"Accept": "application/json",
190+
"Authorization": f"Bearer {token}",
191+
**provider.extra_headers,
192+
}
84193
r = httpx.get(
85-
httpx.URL(api_endpoint).join(models_catalog),
86-
headers={
87-
"Accept": "application/json",
88-
"Authorization": f"Bearer {token}",
89-
"Copilot-Integration-Id": COPILOT_INTEGRATION_ID,
90-
},
194+
httpx.URL(base).join(provider.models_catalog),
195+
headers=headers,
91196
)
92197
r.raise_for_status()
93-
# CAPI vs Models API
94-
match netloc:
95-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
96-
models_list = r.json().get("data", [])
97-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
98-
models_list = r.json()
99-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
100-
models_list = r.json().get("data", [])
101-
case _:
102-
# Unknown endpoint — try common response shapes
103-
body = r.json()
104-
if isinstance(body, dict):
105-
models_list = body.get("data", [])
106-
elif isinstance(body, list):
107-
models_list = body
108-
else:
109-
models_list = []
110-
for model in models_list:
198+
for model in provider.parse_models_list(r.json()):
111199
models[model.get("id")] = dict(model)
112-
except httpx.RequestError:
113-
logging.exception("Request error")
114-
except json.JSONDecodeError:
115-
logging.exception("JSON error")
116-
except httpx.HTTPStatusError:
117-
logging.exception("HTTP error")
200+
except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
201+
logging.exception("Failed to list models from %s", base)
118202
return models
119203

120204

121-
def supports_tool_calls(model: str, models: dict[str, dict]) -> bool:
122-
"""Check whether the given model supports tool calls."""
123-
api_endpoint = get_AI_endpoint()
124-
netloc = urlparse(api_endpoint).netloc
125-
match netloc:
126-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
127-
return models.get(model, {}).get("capabilities", {}).get("supports", {}).get("tool_calls", False)
128-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
129-
return "tool-calling" in models.get(model, {}).get("capabilities", [])
130-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
131-
return "gpt-" in model.lower()
132-
case _:
133-
# Unknown endpoint — optimistically assume tool-call support
134-
# if the model is present in the catalog.
135-
return model in models
136-
137-
138-
def list_tool_call_models(token: str) -> dict[str, dict]:
205+
def supports_tool_calls(
206+
model: str,
207+
models: dict[str, dict],
208+
endpoint: str | None = None,
209+
) -> bool:
210+
"""Check whether *model* supports tool calls."""
211+
provider = get_provider(endpoint)
212+
return provider.check_tool_calls(model, models.get(model, {}))
213+
214+
215+
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
139216
"""Return only models that support tool calls."""
140-
models = list_capi_models(token)
141-
tool_models: dict[str, dict] = {}
142-
for model in models:
143-
if supports_tool_calls(model, models) is True:
144-
tool_models[model] = models[model]
145-
return tool_models
217+
models = list_capi_models(token, endpoint)
218+
provider = get_provider(endpoint)
219+
return {
220+
mid: info
221+
for mid, info in models.items()
222+
if provider.check_tool_calls(mid, info)
223+
}

0 commit comments

Comments (0)