Skip to content

Commit a95ba1a

Browse files
authored
feat(provider): add provider macros (#19)
1 parent 7200e16 commit a95ba1a

7 files changed

Lines changed: 392 additions & 0 deletions

File tree

docs/internals/llm-types.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ The provider side is split into three layers.
134134

135135
`ProviderCapabilities` is capability discovery. It returns typed trait objects such as `as_native_anthropic_messages()` and `as_native_openai_responses()` instead of booleans, so a provider cannot claim support for a feature without also exposing the methods behind that feature.
136136

137+
For OpenAI-compatible providers, the concrete definition can stay very small. The `provider!()` macro generates `ProviderMeta`, a default `ChatTransform`, and an empty `ProviderCapabilities` implementation from a declarative block of base URL, auth shape, stream reader kind, and compatibility quirks.
138+
139+
`OpenAIDef` remains hand-written because it needs its own default quirk profile, while `DeepSeek` is the first macro-generated provider in the new stack.
140+
137141
### Runtime provider instances
138142

139143
`ProviderInstance` binds a shared provider definition to runtime auth, base URL overrides, and custom headers.

src/gateway/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod error;
22
pub mod formats;
33
pub mod provider_instance;
4+
pub mod providers;
45
pub mod traits;
56
pub mod types;

src/gateway/provider_instance.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,26 @@ impl fmt::Debug for ProviderAuth {
2525
}
2626
}
2727

28+
impl ProviderAuth {
29+
pub fn api_key(&self) -> Result<&str> {
30+
match self {
31+
Self::ApiKey(api_key) => Ok(api_key),
32+
Self::None => Err(GatewayError::Validation(
33+
"missing ProviderAuth::ApiKey value".into(),
34+
)),
35+
}
36+
}
37+
38+
pub fn api_key_for(&self, provider: &str) -> Result<&str> {
39+
self.api_key().map_err(|error| match error {
40+
GatewayError::Validation(message) => {
41+
GatewayError::Validation(format!("provider {}: {}", provider, message))
42+
}
43+
other => other,
44+
})
45+
}
46+
}
47+
2848
/// Runtime provider configuration: definition, auth, and deployment overrides.
2949
#[derive(Clone)]
3050
pub struct ProviderInstance {
@@ -207,6 +227,33 @@ mod tests {
207227
assert_eq!(format!("{:?}", ProviderAuth::None), "None");
208228
}
209229

230+
#[test]
231+
fn provider_auth_api_key_requires_api_key_variant() {
232+
assert_eq!(
233+
ProviderAuth::ApiKey("sk-secret".into()).api_key().unwrap(),
234+
"sk-secret"
235+
);
236+
237+
let error = ProviderAuth::None.api_key().unwrap_err();
238+
assert!(matches!(
239+
error,
240+
GatewayError::Validation(message)
241+
if message.contains("ProviderAuth::ApiKey")
242+
));
243+
}
244+
245+
#[test]
246+
fn provider_auth_api_key_for_adds_provider_context() {
247+
let error = ProviderAuth::None.api_key_for("deepseek").unwrap_err();
248+
249+
assert!(matches!(
250+
error,
251+
GatewayError::Validation(message)
252+
if message.contains("deepseek")
253+
&& message.contains("ProviderAuth::ApiKey")
254+
));
255+
}
256+
210257
#[test]
211258
fn provider_instance_build_url_uses_provider_path() {
212259
let instance = ProviderInstance {

src/gateway/providers/deepseek.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
use crate::gateway::providers::macros::provider;
2+
3+
provider!(DeepSeek {
4+
display_name: "deepseek",
5+
base_url: "https://api.deepseek.com",
6+
auth: bearer,
7+
});

src/gateway/providers/macros.rs

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
macro_rules! provider {
2+
(@chat_path) => {};
3+
(@chat_path $path:literal) => {
4+
fn chat_endpoint_path(&self, _model: &str) -> std::borrow::Cow<'static, str> {
5+
std::borrow::Cow::Borrowed($path)
6+
}
7+
};
8+
9+
(@stream) => {};
10+
(@stream $kind:expr) => {
11+
fn stream_reader_kind(&self) -> $crate::gateway::traits::StreamReaderKind {
12+
$kind
13+
}
14+
};
15+
16+
(@auth bearer) => {
17+
fn build_auth_headers(
18+
&self,
19+
auth: &$crate::gateway::provider_instance::ProviderAuth,
20+
) -> $crate::gateway::error::Result<http::HeaderMap> {
21+
let mut headers = http::HeaderMap::new();
22+
let value = http::HeaderValue::from_str(&format!(
23+
"Bearer {}",
24+
auth.api_key_for(self.name())?
25+
))
26+
.map_err(|error| $crate::gateway::error::GatewayError::Validation(error.to_string()))?;
27+
headers.insert(http::header::AUTHORIZATION, value);
28+
Ok(headers)
29+
}
30+
};
31+
32+
(@auth api_key_header($header:literal)) => {
33+
fn build_auth_headers(
34+
&self,
35+
auth: &$crate::gateway::provider_instance::ProviderAuth,
36+
) -> $crate::gateway::error::Result<http::HeaderMap> {
37+
const HEADER_NAME: http::header::HeaderName =
38+
http::header::HeaderName::from_static($header);
39+
let mut headers = http::HeaderMap::new();
40+
let value = http::HeaderValue::from_str(auth.api_key_for(self.name())?)
41+
.map_err(|error| $crate::gateway::error::GatewayError::Validation(error.to_string()))?;
42+
headers.insert(HEADER_NAME, value);
43+
Ok(headers)
44+
}
45+
};
46+
47+
(@quirks { $($field:ident : $value:expr),* $(,)? }) => {
48+
fn default_quirks(&self) -> $crate::gateway::traits::CompatQuirks {
49+
$crate::gateway::traits::CompatQuirks {
50+
$($field: $value,)*
51+
..$crate::gateway::traits::CompatQuirks::NONE
52+
}
53+
}
54+
};
55+
56+
(@impl_provider
57+
$name:ident,
58+
$display:literal,
59+
$base:literal,
60+
[$($path:tt)?],
61+
[$($stream_kind:tt)?],
62+
[$($auth_kind:tt)+]
63+
) => {
64+
pub struct $name;
65+
66+
impl $crate::gateway::traits::ProviderMeta for $name {
67+
fn name(&self) -> &'static str {
68+
$display
69+
}
70+
71+
fn default_base_url(&self) -> &'static str {
72+
$base
73+
}
74+
75+
provider!(@chat_path $($path)?);
76+
provider!(@stream $($stream_kind)?);
77+
provider!(@auth $($auth_kind)+);
78+
}
79+
80+
impl $crate::gateway::traits::ProviderCapabilities for $name {}
81+
};
82+
83+
(
84+
$name:ident {
85+
display_name: $display:literal,
86+
base_url: $base:literal,
87+
$(chat_path: $path:literal,)?
88+
$(stream: $stream_kind:expr,)?
89+
auth: bearer,
90+
quirks: { $($quirk_field:ident : $quirk_value:expr),* $(,)? }
91+
}
92+
) => {
93+
provider!(@impl_provider $name, $display, $base, [$($path)?], [$($stream_kind)?], [bearer]);
94+
95+
impl $crate::gateway::traits::ChatTransform for $name {
96+
provider!(@quirks { $($quirk_field : $quirk_value),* });
97+
}
98+
};
99+
100+
(
101+
$name:ident {
102+
display_name: $display:literal,
103+
base_url: $base:literal,
104+
$(chat_path: $path:literal,)?
105+
$(stream: $stream_kind:expr,)?
106+
auth: bearer $(,)?
107+
}
108+
) => {
109+
provider!(@impl_provider $name, $display, $base, [$($path)?], [$($stream_kind)?], [bearer]);
110+
111+
impl $crate::gateway::traits::ChatTransform for $name {}
112+
};
113+
114+
(
115+
$name:ident {
116+
display_name: $display:literal,
117+
base_url: $base:literal,
118+
$(chat_path: $path:literal,)?
119+
$(stream: $stream_kind:expr,)?
120+
auth: api_key_header($header:literal),
121+
quirks: { $($quirk_field:ident : $quirk_value:expr),* $(,)? }
122+
}
123+
) => {
124+
provider!(@impl_provider $name, $display, $base, [$($path)?], [$($stream_kind)?], [api_key_header($header)]);
125+
126+
impl $crate::gateway::traits::ChatTransform for $name {
127+
provider!(@quirks { $($quirk_field : $quirk_value),* });
128+
}
129+
};
130+
131+
(
132+
$name:ident {
133+
display_name: $display:literal,
134+
base_url: $base:literal,
135+
$(chat_path: $path:literal,)?
136+
$(stream: $stream_kind:expr,)?
137+
auth: api_key_header($header:literal) $(,)?
138+
}
139+
) => {
140+
provider!(@impl_provider $name, $display, $base, [$($path)?], [$($stream_kind)?], [api_key_header($header)]);
141+
142+
impl $crate::gateway::traits::ChatTransform for $name {}
143+
};
144+
}
145+
146+
pub(crate) use provider;
147+
148+
#[cfg(test)]
149+
mod tests {
150+
use std::borrow::Cow;
151+
152+
use crate::gateway::{
153+
provider_instance::ProviderAuth,
154+
traits::{ChatTransform, ProviderMeta, StreamReaderKind},
155+
};
156+
157+
provider!(MacroTestProvider {
158+
display_name: "macro-test",
159+
base_url: "https://provider.example.com",
160+
chat_path: "/custom/chat",
161+
stream: StreamReaderKind::JsonArrayStream,
162+
auth: api_key_header("x-api-key"),
163+
quirks: {
164+
unsupported_params: &["seed"],
165+
tool_args_may_be_object: true,
166+
}
167+
});
168+
169+
#[test]
170+
fn macro_generated_provider_exposes_expected_metadata() {
171+
let provider = MacroTestProvider;
172+
173+
assert_eq!(provider.name(), "macro-test");
174+
assert_eq!(provider.default_base_url(), "https://provider.example.com");
175+
assert_eq!(
176+
provider.chat_endpoint_path("ignored"),
177+
Cow::Borrowed("/custom/chat")
178+
);
179+
assert_eq!(
180+
provider.stream_reader_kind(),
181+
StreamReaderKind::JsonArrayStream
182+
);
183+
}
184+
185+
#[test]
186+
fn macro_generated_provider_builds_auth_headers_and_quirks() {
187+
let provider = MacroTestProvider;
188+
let headers = provider
189+
.build_auth_headers(&ProviderAuth::ApiKey("secret-key".into()))
190+
.unwrap();
191+
let quirks = provider.default_quirks();
192+
193+
assert_eq!(headers["x-api-key"], "secret-key");
194+
assert_eq!(quirks.unsupported_params, &["seed"]);
195+
assert!(quirks.tool_args_may_be_object);
196+
}
197+
198+
#[test]
199+
fn macro_generated_provider_reports_provider_name_for_missing_api_key() {
200+
let provider = MacroTestProvider;
201+
let error = provider
202+
.build_auth_headers(&ProviderAuth::None)
203+
.unwrap_err();
204+
205+
assert!(matches!(
206+
error,
207+
crate::gateway::error::GatewayError::Validation(message)
208+
if message.contains("macro-test")
209+
&& message.contains("ProviderAuth::ApiKey")
210+
));
211+
}
212+
}

src/gateway/providers/mod.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
pub mod deepseek;
2+
pub mod macros;
3+
pub mod openai;
4+
5+
pub use deepseek::DeepSeek;
6+
pub use openai::OpenAIDef;
7+
8+
use crate::gateway::{error::Result, provider_instance::ProviderRegistry};
9+
10+
pub fn default_provider_registry() -> Result<ProviderRegistry> {
11+
let builder = ProviderRegistry::builder()
12+
.register(OpenAIDef)?
13+
.register(DeepSeek)?;
14+
Ok(builder.build())
15+
}
16+
17+
#[cfg(test)]
18+
mod tests {
19+
use super::default_provider_registry;
20+
21+
#[test]
22+
fn default_provider_registry_registers_builtin_providers() {
23+
let registry = default_provider_registry().unwrap();
24+
25+
assert_eq!(registry.get("openai").unwrap().name(), "openai");
26+
assert_eq!(registry.get("deepseek").unwrap().name(), "deepseek");
27+
assert!(registry.get("missing").is_none());
28+
}
29+
}

0 commit comments

Comments
 (0)