Skip to content

Commit 4e56693

Browse files
authored
feat(provider): add new provider instance (#17)
1 parent 3be76a4 commit 4e56693

5 files changed

Lines changed: 311 additions & 31 deletions

File tree

docs/internals/llm-types.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ The provider side is split into three layers.
132132

133133
`ProviderCapabilities` is capability discovery. It returns typed trait objects such as `as_native_anthropic_messages()` and `as_native_openai_responses()` instead of booleans, so a provider cannot claim support for a feature without also exposing the methods behind that feature.
134134

135+
### Runtime provider instances
136+
137+
`ProviderInstance` binds a shared provider definition to runtime auth, base URL overrides, and custom headers.
138+
139+
It validates provider default base URLs when no override is supplied, so bad provider metadata surfaces as a normal validation error instead of a process crash.
140+
141+
`ProviderRegistry` stores immutable `Arc<dyn ProviderCapabilities>` definitions so higher layers can resolve a provider by name and cheaply clone the shared definition into request-scoped runtime instances.
142+
143+
The registry builder rejects duplicate provider names up front, which keeps later registration mistakes from silently shadowing an earlier definition.
144+
135145
### Native support traits
136146

137147
`NativeAnthropicMessagesSupport` and `NativeOpenAIResponsesSupport` are optional extensions layered on top of `ChatTransform`.

src/gateway/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pub mod error;
2+
pub mod provider_instance;
23
pub mod traits;
34
pub mod types;

src/gateway/provider_instance.rs

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
use std::{collections::HashMap, fmt, sync::Arc};
2+
3+
use http::HeaderMap;
4+
use reqwest::Url;
5+
6+
use crate::gateway::{
7+
error::{GatewayError, Result},
8+
traits::ProviderCapabilities,
9+
};
10+
11+
/// Authentication material bound to a provider instance at runtime.
12+
#[derive(Clone, Default)]
13+
pub enum ProviderAuth {
14+
ApiKey(String),
15+
#[default]
16+
None,
17+
}
18+
19+
impl fmt::Debug for ProviderAuth {
20+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21+
match self {
22+
Self::ApiKey(_) => f.write_str("ApiKey(REDACTED)"),
23+
Self::None => f.write_str("None"),
24+
}
25+
}
26+
}
27+
28+
/// Runtime provider configuration: definition, auth, and deployment overrides.
29+
#[derive(Clone)]
30+
pub struct ProviderInstance {
31+
pub def: Arc<dyn ProviderCapabilities>,
32+
pub auth: ProviderAuth,
33+
pub base_url_override: Option<Url>,
34+
pub custom_headers: HeaderMap,
35+
}
36+
37+
impl ProviderInstance {
38+
pub fn effective_base_url(&self) -> Result<Url> {
39+
if let Some(base_url) = &self.base_url_override {
40+
return Ok(base_url.clone());
41+
}
42+
43+
self.def.default_base_url().parse().map_err(|error| {
44+
GatewayError::Validation(format!(
45+
"provider {} has invalid default_base_url {}: {}",
46+
self.def.name(),
47+
self.def.default_base_url(),
48+
error
49+
))
50+
})
51+
}
52+
53+
pub fn build_url(&self, model: &str) -> Result<String> {
54+
let base_url = self.effective_base_url()?;
55+
Ok(self.def.build_url(base_url.as_str(), model))
56+
}
57+
58+
pub fn build_headers(&self) -> Result<HeaderMap> {
59+
let mut headers = self.def.build_auth_headers(&self.auth)?;
60+
headers.extend(self.custom_headers.clone());
61+
Ok(headers)
62+
}
63+
}
64+
65+
/// Immutable registry of provider definitions.
66+
pub struct ProviderRegistry {
67+
defs: HashMap<&'static str, Arc<dyn ProviderCapabilities>>,
68+
}
69+
70+
impl ProviderRegistry {
71+
pub fn builder() -> ProviderRegistryBuilder {
72+
ProviderRegistryBuilder {
73+
defs: HashMap::new(),
74+
}
75+
}
76+
77+
pub fn get(&self, name: &str) -> Option<Arc<dyn ProviderCapabilities>> {
78+
self.defs.get(name).cloned()
79+
}
80+
}
81+
82+
pub struct ProviderRegistryBuilder {
83+
defs: HashMap<&'static str, Arc<dyn ProviderCapabilities>>,
84+
}
85+
86+
impl ProviderRegistryBuilder {
87+
pub fn register<P: ProviderCapabilities + 'static>(mut self, provider: P) -> Result<Self> {
88+
if self.defs.contains_key(provider.name()) {
89+
return Err(GatewayError::Validation(format!(
90+
"provider {} is already registered",
91+
provider.name()
92+
)));
93+
}
94+
95+
self.defs.insert(provider.name(), Arc::new(provider));
96+
Ok(self)
97+
}
98+
99+
pub fn build(self) -> ProviderRegistry {
100+
ProviderRegistry { defs: self.defs }
101+
}
102+
}
103+
104+
#[cfg(test)]
105+
mod tests {
106+
use std::{borrow::Cow, sync::Arc};
107+
108+
use http::{
109+
HeaderMap, HeaderValue,
110+
header::{AUTHORIZATION, HeaderName},
111+
};
112+
113+
use super::{ProviderAuth, ProviderInstance, ProviderRegistry};
114+
use crate::gateway::{
115+
error::{GatewayError, Result},
116+
traits::{ChatTransform, ProviderCapabilities, ProviderMeta, StreamReaderKind},
117+
};
118+
119+
struct DummyProvider;
120+
121+
struct InvalidUrlProvider;
122+
123+
struct DuplicateDummyProvider;
124+
125+
impl ProviderMeta for DummyProvider {
126+
fn name(&self) -> &'static str {
127+
"dummy"
128+
}
129+
130+
fn default_base_url(&self) -> &'static str {
131+
"https://api.example.com/"
132+
}
133+
134+
fn chat_endpoint_path(&self, model: &str) -> Cow<'static, str> {
135+
Cow::Owned(format!("/v1/models/{model}/chat"))
136+
}
137+
138+
fn stream_reader_kind(&self) -> StreamReaderKind {
139+
StreamReaderKind::Sse
140+
}
141+
142+
fn build_auth_headers(&self, auth: &ProviderAuth) -> Result<HeaderMap> {
143+
let mut headers = HeaderMap::new();
144+
if let ProviderAuth::ApiKey(api_key) = auth {
145+
let value = HeaderValue::from_str(&format!("Bearer {api_key}"))
146+
.map_err(|error| GatewayError::Validation(error.to_string()))?;
147+
headers.insert(AUTHORIZATION, value);
148+
}
149+
Ok(headers)
150+
}
151+
}
152+
153+
impl ChatTransform for DummyProvider {}
154+
155+
impl ProviderCapabilities for DummyProvider {}
156+
157+
impl ProviderMeta for InvalidUrlProvider {
158+
fn name(&self) -> &'static str {
159+
"invalid-url"
160+
}
161+
162+
fn default_base_url(&self) -> &'static str {
163+
"not a url"
164+
}
165+
166+
fn stream_reader_kind(&self) -> StreamReaderKind {
167+
StreamReaderKind::Sse
168+
}
169+
170+
fn build_auth_headers(&self, _auth: &ProviderAuth) -> Result<HeaderMap> {
171+
Ok(HeaderMap::new())
172+
}
173+
}
174+
175+
impl ChatTransform for InvalidUrlProvider {}
176+
177+
impl ProviderCapabilities for InvalidUrlProvider {}
178+
179+
impl ProviderMeta for DuplicateDummyProvider {
180+
fn name(&self) -> &'static str {
181+
"dummy"
182+
}
183+
184+
fn default_base_url(&self) -> &'static str {
185+
"https://duplicate.example.com"
186+
}
187+
188+
fn stream_reader_kind(&self) -> StreamReaderKind {
189+
StreamReaderKind::Sse
190+
}
191+
192+
fn build_auth_headers(&self, _auth: &ProviderAuth) -> Result<HeaderMap> {
193+
Ok(HeaderMap::new())
194+
}
195+
}
196+
197+
impl ChatTransform for DuplicateDummyProvider {}
198+
199+
impl ProviderCapabilities for DuplicateDummyProvider {}
200+
201+
#[test]
202+
fn provider_auth_debug_redacts_api_key() {
203+
assert_eq!(
204+
format!("{:?}", ProviderAuth::ApiKey("sk-secret".into())),
205+
"ApiKey(REDACTED)"
206+
);
207+
assert_eq!(format!("{:?}", ProviderAuth::None), "None");
208+
}
209+
210+
#[test]
211+
fn provider_instance_build_url_uses_provider_path() {
212+
let instance = ProviderInstance {
213+
def: Arc::new(DummyProvider),
214+
auth: ProviderAuth::None,
215+
base_url_override: None,
216+
custom_headers: HeaderMap::new(),
217+
};
218+
219+
assert_eq!(
220+
instance.build_url("demo-model").unwrap(),
221+
"https://api.example.com/v1/models/demo-model/chat"
222+
);
223+
}
224+
225+
#[test]
226+
fn provider_instance_invalid_default_base_url_returns_validation_error() {
227+
let instance = ProviderInstance {
228+
def: Arc::new(InvalidUrlProvider),
229+
auth: ProviderAuth::None,
230+
base_url_override: None,
231+
custom_headers: HeaderMap::new(),
232+
};
233+
234+
let error = instance.effective_base_url().unwrap_err();
235+
236+
assert!(matches!(
237+
error,
238+
GatewayError::Validation(message)
239+
if message.contains("invalid-url")
240+
&& message.contains("default_base_url")
241+
));
242+
}
243+
244+
#[test]
245+
fn provider_instance_build_headers_merges_auth_and_custom_headers() {
246+
let mut custom_headers = HeaderMap::new();
247+
custom_headers.insert(
248+
HeaderName::from_static("x-trace-id"),
249+
HeaderValue::from_static("trace-123"),
250+
);
251+
let instance = ProviderInstance {
252+
def: Arc::new(DummyProvider),
253+
auth: ProviderAuth::ApiKey("sk-secret".into()),
254+
base_url_override: None,
255+
custom_headers,
256+
};
257+
258+
let headers = instance.build_headers().unwrap();
259+
260+
assert_eq!(headers[AUTHORIZATION], "Bearer sk-secret");
261+
assert_eq!(headers["x-trace-id"], "trace-123");
262+
}
263+
264+
#[test]
265+
fn provider_registry_registers_and_looks_up_definitions() {
266+
let registry = ProviderRegistry::builder()
267+
.register(DummyProvider)
268+
.unwrap()
269+
.build();
270+
271+
let provider = registry.get("dummy").unwrap();
272+
assert_eq!(provider.name(), "dummy");
273+
assert!(registry.get("missing").is_none());
274+
}
275+
276+
#[test]
277+
fn provider_registry_rejects_duplicate_names() {
278+
let error = ProviderRegistry::builder()
279+
.register(DummyProvider)
280+
.unwrap()
281+
.register(DuplicateDummyProvider)
282+
.err()
283+
.unwrap();
284+
285+
assert!(matches!(
286+
error,
287+
GatewayError::Validation(message)
288+
if message.contains("dummy")
289+
&& message.contains("already registered")
290+
));
291+
}
292+
}

src/gateway/traits/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pub use native::{
88
NativeOpenAIResponsesSupport, OpenAIResponsesNativeStreamState,
99
};
1010
pub use provider::{
11-
ChatTransform, CompatQuirks, EmbedTransform, ImageGenTransform, ProviderAuth,
12-
ProviderCapabilities, ProviderMeta, StreamReaderKind, SttTransform, TtsTransform,
11+
ChatTransform, CompatQuirks, EmbedTransform, ImageGenTransform, ProviderCapabilities,
12+
ProviderMeta, StreamReaderKind, SttTransform, TtsTransform,
1313
};
14+
15+
pub use crate::gateway::provider_instance::ProviderAuth;

src/gateway/traits/provider.rs

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,18 @@
1-
use std::{borrow::Cow, fmt};
1+
use std::borrow::Cow;
22

33
use http::HeaderMap;
44
use serde_json::{Map, Value};
55

66
use crate::gateway::{
77
error::{GatewayError, Result},
8+
provider_instance::ProviderAuth,
89
traits::{
910
chat_format::ChatStreamState,
1011
native::{NativeAnthropicMessagesSupport, NativeOpenAIResponsesSupport},
1112
},
1213
types::openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse},
1314
};
1415

15-
/// Authentication material used by provider definitions.
16-
#[derive(Clone, Default)]
17-
pub enum ProviderAuth {
18-
ApiKey(String),
19-
#[default]
20-
None,
21-
}
22-
23-
impl fmt::Debug for ProviderAuth {
24-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25-
match self {
26-
Self::ApiKey(_) => f.write_str("ApiKey(REDACTED)"),
27-
Self::None => f.write_str("None"),
28-
}
29-
}
30-
}
31-
3216
/// Provider metadata with no data transformation logic.
3317
pub trait ProviderMeta: Send + Sync + 'static {
3418
fn name(&self) -> &'static str;
@@ -215,8 +199,8 @@ mod tests {
215199
use http::HeaderMap;
216200
use serde_json::json;
217201

218-
use super::{ChatTransform, CompatQuirks, ProviderAuth, ProviderMeta, StreamReaderKind};
219-
use crate::gateway::traits::chat_format::ChatStreamState;
202+
use super::{ChatTransform, CompatQuirks, ProviderMeta, StreamReaderKind};
203+
use crate::gateway::{provider_instance::ProviderAuth, traits::chat_format::ChatStreamState};
220204

221205
struct DummyProvider;
222206

@@ -243,15 +227,6 @@ mod tests {
243227

244228
impl ChatTransform for DummyProvider {}
245229

246-
#[test]
247-
fn provider_auth_debug_redacts_api_key() {
248-
assert_eq!(
249-
format!("{:?}", ProviderAuth::ApiKey("sk-secret".into())),
250-
"ApiKey(REDACTED)"
251-
);
252-
assert_eq!(format!("{:?}", ProviderAuth::None), "None");
253-
}
254-
255230
#[test]
256231
fn apply_to_request_removes_and_renames_fields() {
257232
let quirks = CompatQuirks {

0 commit comments

Comments
 (0)