11# SPDX-FileCopyrightText: GitHub, Inc.
22# SPDX-License-Identifier: MIT
33
4- """AI API endpoint and token management (CAPI integration)."""
4+ """AI API endpoint and token management.
5+
6+ Supports multiple API providers (GitHub Copilot, GitHub Models, OpenAI, and
7+ custom endpoints). All provider-specific behaviour is captured in a single
8+ ``APIProvider`` dataclass so that adding a new provider only requires one
9+ registry entry instead of changes scattered across multiple match/case blocks.
10+ """
11+
12+ from __future__ import annotations
513
614import json
715import logging
816import os
17+ from collections .abc import Mapping
18+ from dataclasses import dataclass , field
19+ from types import MappingProxyType
20+ from typing import Any
921from urllib .parse import urlparse
1022
1123import httpx
12- from strenum import StrEnum
1324
1425__all__ = [
15- "AI_API_ENDPOINT_ENUM" ,
1626 "COPILOT_INTEGRATION_ID" ,
27+ "APIProvider" ,
1728 "get_AI_endpoint" ,
1829 "get_AI_token" ,
30+ "get_provider" ,
1931 "list_capi_models" ,
2032 "list_tool_call_models" ,
2133 "supports_tool_calls" ,
2234]
2335
36+ COPILOT_INTEGRATION_ID = os .getenv ("COPILOT_INTEGRATION_ID" , "vscode-chat" )
37+
38+
39+ # ---------------------------------------------------------------------------
40+ # Provider abstraction
41+ # ---------------------------------------------------------------------------
42+
43+ @dataclass (frozen = True )
44+ class APIProvider :
45+ """Encapsulates all endpoint-specific behaviour in one place."""
46+
47+ name : str
48+ base_url : str
49+ models_catalog : str = "/models"
50+ default_model : str = "gpt-4.1"
51+ extra_headers : Mapping [str , str ] = field (default_factory = dict )
52+
53+ def __post_init__ (self ) -> None :
54+ # Ensure base_url ends with / so httpx URL.join() preserves the path
55+ if self .base_url and not self .base_url .endswith ("/" ):
56+ object .__setattr__ (self , "base_url" , self .base_url + "/" )
57+ # Freeze mutable headers so singleton providers can't be mutated
58+ if isinstance (self .extra_headers , dict ):
59+ object .__setattr__ (self , "extra_headers" , MappingProxyType (self .extra_headers ))
60+
61+ # -- response parsing -----------------------------------------------------
2462
25- # Enumeration of currently supported API endpoints.
26- class AI_API_ENDPOINT_ENUM (StrEnum ):
27- AI_API_MODELS_GITHUB = "models.github.ai"
28- AI_API_GITHUBCOPILOT = "api.githubcopilot.com"
29- AI_API_OPENAI = "api.openai.com"
63+ def parse_models_list (self , body : Any ) -> list [dict ]:
64+ """Extract the models list from a catalog response body."""
65+ if isinstance (body , list ):
66+ return body
67+ if isinstance (body , dict ):
68+ data = body .get ("data" , [])
69+ return data if isinstance (data , list ) else []
70+ return []
3071
31- def to_url (self ) -> str :
32- """Convert the endpoint to its full URL."""
33- match self :
34- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
35- return f"https://{ self } "
36- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
37- return f"https://{ self } /inference"
38- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
39- return f"https://{ self } /v1"
40- case _:
41- raise ValueError (f"Unsupported endpoint: { self } " )
72+ # -- tool-call capability check -------------------------------------------
4273
74+ def check_tool_calls (self , _model : str , model_info : dict ) -> bool :
75+ """Return True if *model* supports tool calls according to its catalog entry."""
76+ # Default: optimistically assume support when present in catalog
77+ return bool (model_info )
4378
44- COPILOT_INTEGRATION_ID = "vscode-chat"
4579
80+ class _CopilotProvider (APIProvider ):
81+ """GitHub Copilot API (api.githubcopilot.com)."""
82+
83+ def check_tool_calls (self , _model : str , model_info : dict ) -> bool :
84+ return (
85+ model_info
86+ .get ("capabilities" , {})
87+ .get ("supports" , {})
88+ .get ("tool_calls" , False )
89+ )
90+
91+
92+ class _GitHubModelsProvider (APIProvider ):
93+ """GitHub Models API (models.github.ai)."""
94+
95+ def parse_models_list (self , body : Any ) -> list [dict ]:
96+ # Models API returns a bare list, not {"data": [...]}
97+ if isinstance (body , list ):
98+ return body
99+ return super ().parse_models_list (body )
100+
101+ def check_tool_calls (self , _model : str , model_info : dict ) -> bool :
102+ return "tool-calling" in model_info .get ("capabilities" , [])
103+
104+
105+ class _OpenAIProvider (APIProvider ):
106+ """OpenAI API (api.openai.com).
107+
108+ The OpenAI /v1/models catalog does not expose capability metadata, so
109+ we maintain a prefix allowlist of known chat-completion model families.
110+ """
111+
112+ _CHAT_PREFIXES = ("gpt-3.5" , "gpt-4" , "o1" , "o3" , "o4" , "chatgpt-" )
113+
114+ def check_tool_calls (self , _model : str , model_info : dict ) -> bool :
115+ model_id = model_info .get ("id" , "" ).lower ()
116+ return any (model_id .startswith (p ) for p in self ._CHAT_PREFIXES )
117+ # ---------------------------------------------------------------------------
118+ # Provider registry — add new providers here
119+ # ---------------------------------------------------------------------------
120+
121+ _PROVIDERS : dict [str , APIProvider ] = {
122+ "api.githubcopilot.com" : _CopilotProvider (
123+ name = "copilot" ,
124+ base_url = "https://api.githubcopilot.com" ,
125+ default_model = "gpt-4.1" ,
126+ extra_headers = {"Copilot-Integration-Id" : COPILOT_INTEGRATION_ID },
127+ ),
128+ "models.github.ai" : _GitHubModelsProvider (
129+ name = "github-models" ,
130+ base_url = "https://models.github.ai/inference" ,
131+ models_catalog = "/catalog/models" ,
132+ default_model = "openai/gpt-4.1" ,
133+ ),
134+ "api.openai.com" : _OpenAIProvider (
135+ name = "openai" ,
136+ base_url = "https://api.openai.com/v1" ,
137+ models_catalog = "/v1/models" ,
138+ default_model = "gpt-4.1" ,
139+ ),
140+ }
141+
142+ def get_provider (endpoint : str | None = None ) -> APIProvider :
143+ """Return the ``APIProvider`` for the given (or configured) endpoint URL."""
144+ url = endpoint or get_AI_endpoint ()
145+ netloc = urlparse (url ).netloc
146+ provider = _PROVIDERS .get (netloc )
147+ if provider is not None :
148+ return provider
149+ # Unknown endpoint — return a generic provider with the given base URL
150+ return APIProvider (name = "custom" , base_url = url , default_model = "please-set-default-model-via-env" )
151+
152+
153+ # ---------------------------------------------------------------------------
154+ # Endpoint / token helpers
155+ # ---------------------------------------------------------------------------
46156
47- # you can also set https://api.githubcopilot.com if you prefer
48- # but beware that your taskflows need to reference the correct model id
49- # since different APIs use their own id schema, use -l with your desired
50- # endpoint to retrieve the correct id names to use for your taskflow
51157def get_AI_endpoint () -> str :
52158 """Return the configured AI API endpoint URL."""
53159 return os .getenv ("AI_API_ENDPOINT" , default = "https://models.github.ai/inference" )
@@ -65,82 +171,54 @@ def get_AI_token() -> str:
65171 raise RuntimeError (msg )
66172
67173
68- # assume we are >= python 3.9 for our type hints
69- def list_capi_models (token : str ) -> dict [str , dict ]:
70- """Retrieve a dictionary of available CAPI models"""
71- models = {}
174+ # ---------------------------------------------------------------------------
175+ # Model catalog
176+ # ---------------------------------------------------------------------------
177+
178+ def list_capi_models (token : str , endpoint : str | None = None ) -> dict [str , dict ]:
179+ """Retrieve available models from the configured API endpoint.
180+
181+ Args:
182+ token: Bearer token for authentication.
183+ endpoint: Optional endpoint URL override (defaults to env config).
184+ """
185+ provider = get_provider (endpoint )
186+ base = provider .base_url
187+ models : dict [str , dict ] = {}
72188 try :
73- api_endpoint = get_AI_endpoint ()
74- netloc = urlparse (api_endpoint ).netloc
75- match netloc :
76- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
77- models_catalog = "models"
78- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
79- models_catalog = "catalog/models"
80- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
81- models_catalog = "models"
82- case _:
83- # Unknown endpoint — try the OpenAI-style models catalog
84- models_catalog = "models"
189+ headers = {
190+ "Accept" : "application/json" ,
191+ "Authorization" : f"Bearer { token } " ,
192+ ** provider .extra_headers ,
193+ }
85194 r = httpx .get (
86- httpx .URL (api_endpoint ).join (models_catalog ),
87- headers = {
88- "Accept" : "application/json" ,
89- "Authorization" : f"Bearer { token } " ,
90- "Copilot-Integration-Id" : COPILOT_INTEGRATION_ID ,
91- },
195+ httpx .URL (base ).join (provider .models_catalog ),
196+ headers = headers ,
92197 )
93198 r .raise_for_status ()
94- # CAPI vs Models API
95- match netloc :
96- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
97- models_list = r .json ().get ("data" , [])
98- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
99- models_list = r .json ()
100- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
101- models_list = r .json ().get ("data" , [])
102- case _:
103- # Unknown endpoint — try common response shapes
104- body = r .json ()
105- if isinstance (body , dict ):
106- models_list = body .get ("data" , [])
107- elif isinstance (body , list ):
108- models_list = body
109- else :
110- models_list = []
111- for model in models_list :
199+ for model in provider .parse_models_list (r .json ()):
112200 models [model .get ("id" )] = dict (model )
113- except httpx .RequestError :
114- logging .exception ("Request error" )
115- except json .JSONDecodeError :
116- logging .exception ("JSON error" )
117- except httpx .HTTPStatusError :
118- logging .exception ("HTTP error" )
201+ except (httpx .RequestError , httpx .HTTPStatusError , json .JSONDecodeError ):
202+ logging .exception ("Failed to list models from %s" , base )
119203 return models
120204
121205
122- def supports_tool_calls (model : str , models : dict [str , dict ]) -> bool :
123- """Check whether the given model supports tool calls."""
124- api_endpoint = get_AI_endpoint ()
125- netloc = urlparse (api_endpoint ).netloc
126- match netloc :
127- case AI_API_ENDPOINT_ENUM .AI_API_GITHUBCOPILOT :
128- return models .get (model , {}).get ("capabilities" , {}).get ("supports" , {}).get ("tool_calls" , False )
129- case AI_API_ENDPOINT_ENUM .AI_API_MODELS_GITHUB :
130- return "tool-calling" in models .get (model , {}).get ("capabilities" , [])
131- case AI_API_ENDPOINT_ENUM .AI_API_OPENAI :
132- return "gpt-" in model .lower ()
133- case _:
134- # Unknown endpoint — optimistically assume tool-call support
135- # if the model is present in the catalog.
136- return model in models
137-
138-
139- def list_tool_call_models (token : str ) -> dict [str , dict ]:
206+ def supports_tool_calls (
207+ model : str ,
208+ models : dict [str , dict ],
209+ endpoint : str | None = None ,
210+ ) -> bool :
211+ """Check whether *model* supports tool calls."""
212+ provider = get_provider (endpoint )
213+ return provider .check_tool_calls (model , models .get (model , {}))
214+
215+
216+ def list_tool_call_models (token : str , endpoint : str | None = None ) -> dict [str , dict ]:
140217 """Return only models that support tool calls."""
141- models = list_capi_models (token )
142- tool_models : dict [str , dict ] = {}
143- for model in models :
144- if supports_tool_calls (model , models ) is True :
145- tool_models [model ] = models [model ]
146- return tool_models
218+ models = list_capi_models (token , endpoint )
219+ provider = get_provider (endpoint )
220+ return {
221+ mid : info
222+ for mid , info in models .items ()
223+ if provider .check_tool_calls (mid , info )
224+ }
0 commit comments