11# SPDX-FileCopyrightText: GitHub, Inc.
22# SPDX-License-Identifier: MIT
33
4- """AI API endpoint and token management (CAPI integration)."""
4+ """AI API endpoint and token management.
5+
6+ Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+ custom endpoints). All provider-specific behaviour is captured in a single
8+ ``APIProvider`` dataclass so that adding a new provider only requires one
9+ registry entry instead of changes scattered across multiple match/case blocks.
10+ """
11+
12+ from __future__ import annotations
513
614import json
715import logging
816import os
17+ from dataclasses import dataclass , field
18+ from typing import Any
919from urllib .parse import urlparse
1020
1121import httpx
12- from strenum import StrEnum
1322
# Public API of this module; star-imports expose exactly these names.
__all__ = [
    "COPILOT_INTEGRATION_ID",
    "APIProvider",
    "get_AI_endpoint",
    "get_AI_token",
    "get_provider",
    "list_capi_models",
    "list_tool_call_models",
    "supports_tool_calls",
]

# Integration id identifying this client to the Copilot API; attached as a
# request header through the Copilot provider's extra_headers.
COPILOT_INTEGRATION_ID = "vscode-chat"
2435
25- # Enumeration of currently supported API endpoints.
26- class AI_API_ENDPOINT_ENUM (StrEnum ):
27- AI_API_MODELS_GITHUB = "models.github.ai"
28- AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29- AI_API_OPENAI = "api.openai.com"
3036
31- def to_url (self ) -> str :
32- """Convert the endpoint to its full URL."""
33- match self :
34- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
35- return f"https://{ self } "
36- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
37- return f"https://{ self } /inference"
38- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
39- return f"https://{ self } /v1"
40- case _:
41- raise ValueError (f"Unsupported endpoint: { self } " )
# ---------------------------------------------------------------------------
# Provider abstraction
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class APIProvider:
    """Encapsulates all endpoint-specific behaviour in one place."""

    name: str                              # short provider identifier, e.g. "openai"
    base_url: str                          # fully-qualified API root URL
    models_catalog: str = "models"         # catalog path relative to base_url
    default_model: str = "gpt-4o"          # fallback model id for this provider
    extra_headers: dict[str, str] = field(default_factory=dict)  # per-provider HTTP headers

    # -- response parsing -----------------------------------------------------

    def parse_models_list(self, body: Any) -> list[dict]:
        """Extract the models list from a catalog response body.

        Accepts either a bare JSON array or an OpenAI-style ``{"data": [...]}``
        envelope; any other shape yields an empty list.
        """
        if isinstance(body, dict):
            return body.get("data", [])
        return body if isinstance(body, list) else []

    # -- tool-call capability check -------------------------------------------

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        """Return True if *model* supports tool calls according to its catalog entry.

        Default policy is optimistic: any non-empty catalog entry counts as
        support. Providers with precise capability metadata override this.
        """
        return bool(model_info)
4367
44- COPILOT_INTEGRATION_ID = "vscode-chat"
4568
class _CopilotProvider(APIProvider):
    """GitHub Copilot API (api.githubcopilot.com)."""

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        """Return True if the Copilot catalog entry advertises tool-call support.

        Copilot nests the flag under ``capabilities.supports.tool_calls``.
        The raw catalog value is coerced with ``bool()`` so this method
        honours its declared return type even when the catalog stores a
        truthy non-bool value.
        """
        supports = model_info.get("capabilities", {}).get("supports", {})
        return bool(supports.get("tool_calls", False))
80+
class _GitHubModelsProvider(APIProvider):
    """GitHub Models API (models.github.ai).

    The Models catalog returns a bare JSON array, which the base
    ``parse_models_list`` already handles, so no parsing override is needed
    here — only the capability check differs.
    """

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        """Return True when the catalog entry lists the "tool-calling" capability."""
        return "tool-calling" in model_info.get("capabilities", [])
93+
# ---------------------------------------------------------------------------
# Provider registry — add new providers here
# ---------------------------------------------------------------------------

_PROVIDERS: dict[str, APIProvider] = {
    # GitHub Copilot: OpenAI-style catalog plus a mandatory integration header.
    "api.githubcopilot.com": _CopilotProvider(
        name="copilot",
        base_url="https://api.githubcopilot.com",
        default_model="gpt-4o",
        extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
    ),
    # GitHub Models: catalog lives at a non-standard path; model ids are
    # namespaced (e.g. "openai/gpt-4o").
    "models.github.ai": _GitHubModelsProvider(
        name="github-models",
        base_url="https://models.github.ai/inference",
        models_catalog="catalog/models",
        default_model="openai/gpt-4o",
    ),
    # Plain OpenAI: base-class behaviour is sufficient.
    "api.openai.com": APIProvider(
        name="openai",
        base_url="https://api.openai.com/v1",
        default_model="gpt-4o",
    ),
}

# NOTE(review): _DEFAULT_PROVIDER is not referenced by get_provider(), which
# constructs its own fallback instance — confirm whether any other code uses
# this constant before removing it.
_DEFAULT_PROVIDER = APIProvider(
    name="custom",
    base_url="",  # filled at lookup time
    default_model="gpt-4o",
)
123+
124+
def get_provider(endpoint: str | None = None) -> APIProvider:
    """Return the ``APIProvider`` for the given (or configured) endpoint URL.

    Args:
        endpoint: Optional endpoint URL; falls back to ``get_AI_endpoint()``.

    Returns:
        The registered provider for the URL's host, or a generic "custom"
        provider wrapping the URL when the host is not in the registry.
    """
    url = endpoint or get_AI_endpoint()
    host = urlparse(url).netloc
    try:
        return _PROVIDERS[host]
    except KeyError:
        # Unknown endpoint — return a generic provider with the given base URL.
        return APIProvider(name="custom", base_url=url)


# ---------------------------------------------------------------------------
# Endpoint / token helpers
# ---------------------------------------------------------------------------
46139
def get_AI_endpoint() -> str:
    """Return the configured AI API endpoint URL.

    Reads the ``AI_API_ENDPOINT`` environment variable, defaulting to the
    GitHub Models inference endpoint when it is unset.
    """
    return os.environ.get("AI_API_ENDPOINT", "https://models.github.ai/inference")
@@ -64,82 +153,54 @@ def get_AI_token() -> str:
64153 raise RuntimeError ("AI_API_TOKEN environment variable is not set." )
65154
66155
# ---------------------------------------------------------------------------
# Model catalog
# ---------------------------------------------------------------------------

def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Retrieve available models from the configured API endpoint.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).

    Returns:
        Mapping of model id to its catalog entry. Returns an empty mapping on
        any request/HTTP/JSON failure (errors are logged, never raised).
        Catalog entries that are not dicts or lack an ``"id"`` are skipped
        instead of being stored under a ``None`` key.
    """
    url = endpoint or get_AI_endpoint()
    provider = get_provider(url)
    models: dict[str, dict] = {}
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
            **provider.extra_headers,
        }
        r = httpx.get(
            httpx.URL(url).join(provider.models_catalog),
            headers=headers,
        )
        r.raise_for_status()
        for model in provider.parse_models_list(r.json()):
            # Guard against malformed catalog entries: previously an entry
            # without an "id" would be stored under the key None, and a
            # non-dict entry would raise an uncaught AttributeError.
            if isinstance(model, dict) and model.get("id"):
                models[model["id"]] = dict(model)
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
        logging.exception("Failed to list models from %s", url)
    return models
119186
120187
def supports_tool_calls(
    model: str,
    models: dict[str, dict],
    endpoint: str | None = None,
) -> bool:
    """Check whether *model* supports tool calls.

    Args:
        model: Model id to look up in *models*.
        models: Catalog mapping as returned by :func:`list_capi_models`.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    entry = models.get(model, {})
    return get_provider(endpoint).check_tool_calls(model, entry)
196+
197+
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Return only models that support tool calls.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    catalog = list_capi_models(token, endpoint)
    provider = get_provider(endpoint)
    tool_models: dict[str, dict] = {}
    for model_id, info in catalog.items():
        if provider.check_tool_calls(model_id, info):
            tool_models[model_id] = info
    return tool_models
0 commit comments