Skip to content

Commit b4b89c0

Browse files
committed
Refactor capi.py: replace scattered match/case with provider pattern
Consolidate endpoint-specific logic (catalog path, response parsing, tool-call detection, headers) into an APIProvider dataclass with a hostname-keyed registry. Adding a new endpoint is now a single registry entry instead of changes across three match/case blocks.

- Remove AI_API_ENDPOINT_ENUM StrEnum and strenum dependency
- Only send Copilot-Integration-Id header to Copilot endpoints
- Accept optional endpoint parameter in public functions
- Drop fragile "gpt-" substring heuristic for OpenAI tool-call check
- Update tests to use new provider API
1 parent 622d035 commit b4b89c0

File tree

4 files changed

+209
-160
lines changed

4 files changed

+209
-160
lines changed

src/seclab_taskflow_agent/agent.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import os
77
from collections.abc import Callable
88
from typing import Any
9-
from urllib.parse import urlparse
109

1110
from agents import (
1211
Agent,
@@ -26,7 +25,7 @@
2625
from dotenv import find_dotenv, load_dotenv
2726
from openai import AsyncOpenAI
2827

29-
from .capi import AI_API_ENDPOINT_ENUM, COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token
28+
from .capi import get_AI_endpoint, get_AI_token, get_provider
3029

3130
__all__ = [
3231
"DEFAULT_MODEL",
@@ -39,17 +38,8 @@
3938
load_dotenv(find_dotenv(usecwd=True))
4039

4140
api_endpoint = get_AI_endpoint()
42-
match urlparse(api_endpoint).netloc:
43-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
44-
default_model = "gpt-4o"
45-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
46-
default_model = "openai/gpt-4o"
47-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
48-
default_model = "gpt-4o"
49-
case _:
50-
default_model = "please-set-default-model-via-env"
51-
52-
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=default_model)
41+
_default_provider = get_provider(api_endpoint)
42+
DEFAULT_MODEL = os.getenv("COPILOT_DEFAULT_MODEL", default=_default_provider.default_model)
5343

5444

5545
class TaskRunHooks(RunHooks):
@@ -186,10 +176,12 @@ def __init__(
186176
else:
187177
resolved_token = get_AI_token()
188178

179+
# Only send provider-specific headers to matching endpoints
180+
provider = get_provider(resolved_endpoint)
189181
client = AsyncOpenAI(
190182
base_url=resolved_endpoint,
191183
api_key=resolved_token,
192-
default_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
184+
default_headers=provider.extra_headers or None,
193185
)
194186
set_tracing_disabled(True)
195187
self.run_hooks = run_hooks or TaskRunHooks()

src/seclab_taskflow_agent/capi.py

Lines changed: 155 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,142 @@
11
# SPDX-FileCopyrightText: GitHub, Inc.
22
# SPDX-License-Identifier: MIT
33

4-
"""AI API endpoint and token management (CAPI integration)."""
4+
"""AI API endpoint and token management.
5+
6+
Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+
custom endpoints). All provider-specific behaviour is captured in a single
8+
``APIProvider`` dataclass so that adding a new provider only requires one
9+
registry entry instead of changes scattered across multiple match/case blocks.
10+
"""
11+
12+
from __future__ import annotations
513

614
import json
715
import logging
816
import os
17+
from dataclasses import dataclass, field
18+
from typing import Any
919
from urllib.parse import urlparse
1020

1121
import httpx
12-
from strenum import StrEnum
1322

1423
__all__ = [
15-
"AI_API_ENDPOINT_ENUM",
1624
"COPILOT_INTEGRATION_ID",
25+
"APIProvider",
1726
"get_AI_endpoint",
1827
"get_AI_token",
28+
"get_provider",
1929
"list_capi_models",
2030
"list_tool_call_models",
2131
"supports_tool_calls",
2232
]
2333

34+
COPILOT_INTEGRATION_ID = "vscode-chat"


# ---------------------------------------------------------------------------
# Provider abstraction
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class APIProvider:
    """Encapsulates all endpoint-specific behaviour in one place.

    Adding support for a new endpoint only requires a new entry in the
    ``_PROVIDERS`` registry below (optionally with a small subclass when the
    catalog response shape or capability encoding differs).
    """

    # Short human-readable provider identifier (e.g. "copilot", "openai").
    name: str
    # Full base URL used for API requests.
    base_url: str
    # Catalog path, resolved against the endpoint URL when listing models.
    models_catalog: str = "models"
    # Fallback model id when COPILOT_DEFAULT_MODEL is not set.
    default_model: str = "gpt-4o"
    # Provider-specific request headers (e.g. the Copilot integration id).
    extra_headers: dict[str, str] = field(default_factory=dict)

    # -- response parsing -----------------------------------------------------

    def parse_models_list(self, body: Any) -> list[dict]:
        """Extract the models list from a catalog response body.

        Accepts either a bare JSON list or an OpenAI-style ``{"data": [...]}``
        envelope; anything else (including a malformed envelope whose "data"
        value is not a list) yields an empty list.
        """
        if isinstance(body, list):
            return body
        if isinstance(body, dict):
            data = body.get("data", [])
            # Guard the declared list[dict] contract: a malformed envelope
            # where "data" is not a list must not leak through.
            return data if isinstance(data, list) else []
        return []

    # -- tool-call capability check -------------------------------------------

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        """Return True if *model* supports tool calls according to its catalog entry."""
        # Default: optimistically assume support when present in catalog
        return bool(model_info)


class _CopilotProvider(APIProvider):
    """GitHub Copilot API (api.githubcopilot.com)."""

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        # Copilot catalog entries nest the flag under capabilities.supports.
        # bool() normalizes whatever value the catalog stores there.
        return bool(
            model_info
            .get("capabilities", {})
            .get("supports", {})
            .get("tool_calls", False)
        )


class _GitHubModelsProvider(APIProvider):
    """GitHub Models API (models.github.ai).

    The base class already accepts a bare-list catalog response, so no
    ``parse_models_list`` override is needed here.
    """

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        # Models API lists capabilities as a flat list of strings.
        return "tool-calling" in model_info.get("capabilities", [])


# ---------------------------------------------------------------------------
# Provider registry — add new providers here
# ---------------------------------------------------------------------------

_PROVIDERS: dict[str, APIProvider] = {
    "api.githubcopilot.com": _CopilotProvider(
        name="copilot",
        base_url="https://api.githubcopilot.com",
        default_model="gpt-4o",
        extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
    ),
    "models.github.ai": _GitHubModelsProvider(
        name="github-models",
        base_url="https://models.github.ai/inference",
        models_catalog="catalog/models",
        default_model="openai/gpt-4o",
    ),
    "api.openai.com": APIProvider(
        name="openai",
        base_url="https://api.openai.com/v1",
        default_model="gpt-4o",
    ),
}


def get_provider(endpoint: str | None = None) -> APIProvider:
    """Return the ``APIProvider`` for the given (or configured) endpoint URL.

    Args:
        endpoint: Endpoint URL override; when None the configured endpoint
            from ``get_AI_endpoint()`` is used.

    Unknown hostnames yield a generic "custom" provider that keeps the
    OpenAI-style defaults and uses the given URL as its base.
    """
    url = endpoint or get_AI_endpoint()
    netloc = urlparse(url).netloc
    provider = _PROVIDERS.get(netloc)
    if provider is not None:
        return provider
    # Unknown endpoint — return a generic provider with the given base URL
    return APIProvider(name="custom", base_url=url)
134+
135+
136+
# ---------------------------------------------------------------------------
137+
# Endpoint / token helpers
138+
# ---------------------------------------------------------------------------
46139

47-
# you can also set https://api.githubcopilot.com if you prefer
48-
# but beware that your taskflows need to reference the correct model id
49-
# since different APIs use their own id schema, use -l with your desired
50-
# endpoint to retrieve the correct id names to use for your taskflow
51140
def get_AI_endpoint() -> str:
    """Return the configured AI API endpoint URL.

    Reads the ``AI_API_ENDPOINT`` environment variable, falling back to the
    GitHub Models inference endpoint when unset.
    """
    fallback = "https://models.github.ai/inference"
    return os.getenv("AI_API_ENDPOINT", default=fallback)
@@ -64,82 +153,54 @@ def get_AI_token() -> str:
64153
raise RuntimeError("AI_API_TOKEN environment variable is not set.")
65154

66155

67-
# assume we are >= python 3.9 for our type hints
68-
def list_capi_models(token: str) -> dict[str, dict]:
69-
"""Retrieve a dictionary of available CAPI models"""
70-
models = {}
156+
# ---------------------------------------------------------------------------
157+
# Model catalog
158+
# ---------------------------------------------------------------------------
159+
160+
def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Retrieve available models from the configured API endpoint.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).

    Returns:
        Mapping of model id to its catalog entry. Returns an empty (or
        partial) mapping on request/HTTP/JSON failure; errors are logged,
        never raised.
    """
    url = endpoint or get_AI_endpoint()
    provider = get_provider(url)
    models: dict[str, dict] = {}
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
            **provider.extra_headers,
        }
        # NOTE(review): httpx.URL.join() resolves the catalog path per
        # RFC 3986, replacing the base URL's last path segment. That is what
        # makes "catalog/models" correct for models.github.ai/inference, but
        # it also means api.openai.com/v1 joins to /models rather than
        # /v1/models — confirm against the OpenAI endpoint.
        r = httpx.get(
            httpx.URL(url).join(provider.models_catalog),
            headers=headers,
        )
        r.raise_for_status()
        for model in provider.parse_models_list(r.json()):
            # Skip malformed catalog entries: non-dict items would raise an
            # (uncaught) AttributeError, and an entry without an "id" would
            # insert a useless None key into the result.
            if isinstance(model, dict) and model.get("id") is not None:
                models[model.get("id")] = dict(model)
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
        logging.exception("Failed to list models from %s", url)
    return models
119186

120187

121-
def supports_tool_calls(model: str, models: dict[str, dict]) -> bool:
122-
"""Check whether the given model supports tool calls."""
123-
api_endpoint = get_AI_endpoint()
124-
netloc = urlparse(api_endpoint).netloc
125-
match netloc:
126-
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
127-
return models.get(model, {}).get("capabilities", {}).get("supports", {}).get("tool_calls", False)
128-
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
129-
return "tool-calling" in models.get(model, {}).get("capabilities", [])
130-
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
131-
return "gpt-" in model.lower()
132-
case _:
133-
# Unknown endpoint — optimistically assume tool-call support
134-
# if the model is present in the catalog.
135-
return model in models
136-
137-
138-
def list_tool_call_models(token: str) -> dict[str, dict]:
188+
def supports_tool_calls(
    model: str,
    models: dict[str, dict],
    endpoint: str | None = None,
) -> bool:
    """Return True when *model* advertises tool-call support.

    The capability check is delegated to the provider resolved from
    *endpoint* (or the configured endpoint when omitted); *models* is the
    catalog mapping returned by ``list_capi_models``.
    """
    catalog_entry = models.get(model, {})
    return get_provider(endpoint).check_tool_calls(model, catalog_entry)
196+
197+
198+
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Return only models that support tool calls.

    Fetches the full catalog via ``list_capi_models`` and filters it using
    the provider's capability check (provider resolved once, not per model).
    """
    provider = get_provider(endpoint)
    tool_models: dict[str, dict] = {}
    for model_id, info in list_capi_models(token, endpoint).items():
        if provider.check_tool_calls(model_id, info):
            tool_models[model_id] = info
    return tool_models

0 commit comments

Comments
 (0)