11# SPDX-FileCopyrightText: GitHub, Inc.
22# SPDX-License-Identifier: MIT
33
4- """AI API endpoint and token management (CAPI integration)."""
4+ """AI API endpoint and token management.
5+
6+ Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+ custom endpoints). All provider-specific behaviour is captured in a single
8+ ``APIProvider`` dataclass so that adding a new provider only requires one
9+ registry entry instead of changes scattered across multiple match/case blocks.
10+ """
11+
12+ from __future__ import annotations
513
614import json
715import logging
816import os
17+ from collections .abc import Mapping
18+ from dataclasses import dataclass , field
19+ from types import MappingProxyType
20+ from typing import Any
921from urllib .parse import urlparse
1022
1123import httpx
12- from strenum import StrEnum
1324
1425__all__ = [
15- "AI_API_ENDPOINT_ENUM" ,
1626 "COPILOT_INTEGRATION_ID" ,
27+ "APIProvider" ,
1728 "get_AI_endpoint" ,
1829 "get_AI_token" ,
30+ "get_provider" ,
1931 "list_capi_models" ,
2032 "list_tool_call_models" ,
2133 "supports_tool_calls" ,
2234]
2335
# Integration id header required by the Copilot API; override via environment
# when impersonating a different client integration.
COPILOT_INTEGRATION_ID = os.environ.get("COPILOT_INTEGRATION_ID", "vscode-chat")
37+
38+
39+ # ---------------------------------------------------------------------------
40+ # Provider abstraction
41+ # ---------------------------------------------------------------------------
42+
43+ @dataclass (frozen = True )
44+ class APIProvider :
45+ """Encapsulates all endpoint-specific behaviour in one place."""
46+
47+ name : str
48+ base_url : str
49+ models_catalog : str = "/models"
50+ default_model : str = "gpt-4.1"
51+ extra_headers : Mapping [str , str ] = field (default_factory = dict )
52+
53+ def __post_init__ (self ) -> None :
54+ # Ensure base_url ends with / so httpx URL.join() preserves the path
55+ if self .base_url and not self .base_url .endswith ("/" ):
56+ object .__setattr__ (self , "base_url" , self .base_url + "/" )
57+ # Freeze mutable headers so singleton providers can't be mutated
58+ if isinstance (self .extra_headers , dict ):
59+ object .__setattr__ (self , "extra_headers" , MappingProxyType (self .extra_headers ))
60+
61+ # -- response parsing -----------------------------------------------------
2462
25- # Enumeration of currently supported API endpoints.
26- class AI_API_ENDPOINT_ENUM (StrEnum ):
27- AI_API_MODELS_GITHUB = "models.github.ai"
28- AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29- AI_API_OPENAI = "api.openai.com"
63+ def parse_models_list (self , body : Any ) -> list [dict ]:
64+ """Extract the models list from a catalog response body."""
65+ if isinstance (body , list ):
66+ return body
67+ if isinstance (body , dict ):
68+ data = body .get ("data" , [])
69+ return data if isinstance (data , list ) else []
70+ return []
3071
31- def to_url (self ) -> str :
32- """Convert the endpoint to its full URL."""
33- match self :
34- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
35- return f"https://{ self } "
36- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
37- return f"https://{ self } /inference"
38- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
39- return f"https://{ self } /v1"
40- case _:
41- raise ValueError (f"Unsupported endpoint: { self } " )
72+ # -- tool-call capability check -------------------------------------------
4273
74+ def check_tool_calls (self , _model : str , model_info : dict ) -> bool :
75+ """Return True if *model* supports tool calls according to its catalog entry."""
76+ # Default: optimistically assume support when present in catalog
77+ return bool (model_info )
4378
44- COPILOT_INTEGRATION_ID = "vscode-chat"
4579
class _CopilotProvider(APIProvider):
    """GitHub Copilot API (api.githubcopilot.com)."""

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True when the catalog entry advertises tool-call support.

        The Copilot catalog nests the flag under
        ``capabilities.supports.tool_calls``.
        """
        supports = model_info.get("capabilities", {}).get("supports", {})
        # bool() enforces the declared return type: the raw JSON value may
        # be any truthy/falsy type, not necessarily a boolean.
        return bool(supports.get("tool_calls", False))
90+
91+
class _GitHubModelsProvider(APIProvider):
    """GitHub Models API (models.github.ai).

    The catalog returns a bare JSON list, which the base class
    ``parse_models_list`` already handles, so no override is needed here.
    """

    def check_tool_calls(self, _model: str, model_info: dict) -> bool:
        """Return True when "tool-calling" appears in the entry's capabilities."""
        return "tool-calling" in model_info.get("capabilities", [])
103+
104+
class _OpenAIProvider(APIProvider):
    """OpenAI API (api.openai.com).

    The OpenAI /v1/models catalog does not expose capability metadata, so
    we maintain a prefix allowlist of known chat-completion model families.
    """

    _CHAT_PREFIXES = ("gpt-3.5", "gpt-4", "o1", "o3", "o4", "chatgpt-")

    def check_tool_calls(self, model: str, model_info: dict) -> bool:
        """Return True when the model belongs to a known chat-capable family.

        Falls back to the *model* argument when the catalog entry is empty
        or lacks an ``id`` (e.g. the model was not found in the catalog), so
        a lookup for an ad-hoc model name still gives a sensible answer.
        """
        model_id = str(model_info.get("id") or model).lower()
        # str.startswith accepts a tuple, so one call covers the allowlist.
        return model_id.startswith(self._CHAT_PREFIXES)
# ---------------------------------------------------------------------------
# Provider registry — add new providers here
# ---------------------------------------------------------------------------
120+
# Registry of known endpoints, keyed by URL netloc. To support a new
# provider, add one entry here (plus a subclass above only if it needs
# custom capability/parsing logic).
_PROVIDERS: dict[str, APIProvider] = {
    "api.githubcopilot.com": _CopilotProvider(
        name="copilot",
        base_url="https://api.githubcopilot.com",
        default_model="gpt-4.1",
        # Copilot requires an integration-id header on every request.
        extra_headers={"Copilot-Integration-Id": COPILOT_INTEGRATION_ID},
    ),
    "models.github.ai": _GitHubModelsProvider(
        name="github-models",
        base_url="https://models.github.ai/inference",
        # The catalog lives outside the /inference base path.
        models_catalog="/catalog/models",
        default_model="openai/gpt-4.1",
    ),
    "api.openai.com": _OpenAIProvider(
        name="openai",
        base_url="https://api.openai.com/v1",
        models_catalog="/v1/models",
        default_model="gpt-4.1",
    ),
}
141+
def get_provider(endpoint: str | None = None) -> APIProvider:
    """Resolve the ``APIProvider`` for *endpoint* (or the configured URL).

    Args:
        endpoint: Optional endpoint URL; defaults to ``get_AI_endpoint()``.

    Returns:
        A registered provider when the URL's netloc is known, otherwise a
        generic ``APIProvider`` built around the given base URL.
    """
    url = endpoint if endpoint else get_AI_endpoint()
    known = _PROVIDERS.get(urlparse(url).netloc)
    if known is not None:
        return known
    # Unknown endpoint — fall back to a generic provider for that URL.
    return APIProvider(
        name="custom",
        base_url=url,
        default_model="please-set-default-model-via-env",
    )
151+
152+
153+ # ---------------------------------------------------------------------------
154+ # Endpoint / token helpers
155+ # ---------------------------------------------------------------------------
46156
47- # you can also set https://api.githubcopilot.com if you prefer
48- # but beware that your taskflows need to reference the correct model id
49- # since different APIs use their own id schema, use -l with your desired
50- # endpoint to retrieve the correct id names to use for your taskflow
def get_AI_endpoint() -> str:
    """Return the AI API endpoint URL, honoring the AI_API_ENDPOINT env var."""
    # GitHub Models inference is the default when nothing is configured.
    return os.environ.get("AI_API_ENDPOINT", "https://models.github.ai/inference")
@@ -64,82 +170,54 @@ def get_AI_token() -> str:
64170 raise RuntimeError ("AI_API_TOKEN environment variable is not set." )
65171
66172
67- # assume we are >= python 3.9 for our type hints
68- def list_capi_models (token : str ) -> dict [str , dict ]:
69- """Retrieve a dictionary of available CAPI models"""
70- models = {}
173+ # ---------------------------------------------------------------------------
174+ # Model catalog
175+ # ---------------------------------------------------------------------------
176+
def list_capi_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Retrieve available models from the configured API endpoint.

    Args:
        token: Bearer token for authentication.
        endpoint: Optional endpoint URL override (defaults to env config).

    Returns:
        Mapping of model id to its raw catalog entry. Empty on any
        request/HTTP/JSON failure — errors are logged, never raised.
    """
    provider = get_provider(endpoint)
    base = provider.base_url
    models: dict[str, dict] = {}
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {token}",
            **provider.extra_headers,
        }
        r = httpx.get(
            httpx.URL(base).join(provider.models_catalog),
            headers=headers,
        )
        r.raise_for_status()
        for model in provider.parse_models_list(r.json()):
            # Defensive: skip malformed catalog entries so one bad item
            # cannot crash the listing or create a None key.
            if not isinstance(model, dict):
                continue
            model_id = model.get("id")
            if model_id is None:
                continue
            models[model_id] = dict(model)
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError):
        logging.exception("Failed to list models from %s", base)
    return models
119203
120204
def supports_tool_calls(
    model: str,
    models: dict[str, dict],
    endpoint: str | None = None,
) -> bool:
    """Check whether *model* supports tool calls.

    Args:
        model: Model id to look up.
        models: Catalog mapping as returned by ``list_capi_models``.
        endpoint: Optional endpoint URL override (defaults to env config).
    """
    catalog_entry = models.get(model, {})
    provider = get_provider(endpoint)
    return provider.check_tool_calls(model, catalog_entry)
213+
214+
def list_tool_call_models(token: str, endpoint: str | None = None) -> dict[str, dict]:
    """Return only models that support tool calls."""
    catalog = list_capi_models(token, endpoint)
    # Resolve the provider once and filter the catalog with its checker.
    provider = get_provider(endpoint)
    tool_capable: dict[str, dict] = {}
    for model_id, info in catalog.items():
        if provider.check_tool_calls(model_id, info):
            tool_capable[model_id] = info
    return tool_capable
0 commit comments