Skip to content

Commit 2b3d8ad

Browse files
authored
feat: Tower based Middlewares (#539)
* initial implementation of middleware * allow error to be Into<OpenAIError> * documentation and default reqwest service * namspaced in middleware * park tower error handling for now TODO * working compilation of all crates * support for middleware in wasm * impl Default for ReqwestService * add tower example * add tower-wasm example * introduce Boxed variant * add both display and debug * update comment * updated doc * updated doc * refactor: rename http_executor to executor; create a middleware dir for doc.rs * move retry stuff in its own module * update examples * document and log retry headers * add openai retry layer * fix logical error in openai layer * optimize openai retry layer * return early on success * update doc * updated doc * add middleware to full * add middleware section * nest retry stuff inside ::middleware::retry * change of module path * updated doc * updated doc * update doc * updated doc * fix the warning * update tower example * cleanup * middleware.md * only compile it for wasm * cleanup client * cleanup client * cleanup * renmae midddlware * update readme * cleanup streaming and form creation * fix test * add larger delay for retrieval to work * make image-edit example work * allow tower user input to be random * update readme * update workflows to build middleware feature * fix warnings * fix warnings * only keep wasm32-unknown-unknown for workflows checks * bump async-openai-macros version * update doc comment * remove duplicate From tower BoxError for wasm/non-wasm * remove redundant doc comment * cleanup comment * in readme use link to github instead of relative path - so that it works on crates.io * remove comment * update lib.rs WASM section * add docs on middleware error handling
1 parent 02870d9 commit 2b3d8ad

35 files changed

Lines changed: 2062 additions & 256 deletions

File tree

.github/workflows/pr-checks.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,9 @@ jobs:
112112

113113
- name: Run clippy with feature ${{ matrix.feature }}
114114
run: cargo clippy --no-default-features --features ${{ matrix.feature }} -- -D warnings
115+
116+
- name: Build with feature ${{ matrix.feature }} and middleware
117+
if: ${{ !contains(matrix.feature, 'types') }}
118+
env:
119+
RUSTFLAGS: "-D warnings"
120+
run: cargo build --no-default-features --features ${{ matrix.feature }},middleware --verbose

.github/workflows/wasm-pr-checks.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
target: [wasm32-unknown-unknown, wasm32-wasip1]
17+
target: [wasm32-unknown-unknown] #, wasm32-wasip1 (doesnt work with changes for middlewares)
1818
feature:
1919
[
2020
byot,
@@ -98,3 +98,9 @@ jobs:
9898
env:
9999
RUSTFLAGS: "-D warnings"
100100
run: cargo build --no-default-features --features ${{ matrix.feature }} --verbose --target ${{ matrix.target }}
101+
102+
- name: Build with feature ${{ matrix.feature }}, middleware, and target ${{ matrix.target }}
103+
if: ${{ !contains(matrix.feature, 'types') }}
104+
env:
105+
RUSTFLAGS: "-D warnings"
106+
run: cargo build --no-default-features --features ${{ matrix.feature }},middleware --verbose --target ${{ matrix.target }}

async-openai-macros/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "async-openai-macros"
3-
version = "0.1.1"
3+
version = "0.2.0"
44
authors = ["Himanshu Neema"]
55
keywords = ["openai", "macros", "ai"]
66
description = "Macros for async-openai"
@@ -14,6 +14,9 @@ readme = "README.md"
1414
[lib]
1515
proc-macro = true
1616

17+
[features]
18+
middleware = []
19+
1720
[dependencies]
1821
syn = { version = "2.0", features = ["full"] }
1922
quote = "1.0"

async-openai-macros/src/lib.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use proc_macro::TokenStream;
22
use quote::{quote, ToTokens};
33
use syn::{
44
parse::{Parse, ParseStream},
5-
parse_macro_input,
5+
parse_macro_input, parse_quote,
66
punctuated::Punctuated,
77
token::Comma,
88
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
@@ -57,6 +57,7 @@ pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
5757
let input = parse_macro_input!(item as ItemFn);
5858
let mut new_generics = Generics::default();
5959
let mut param_count = 0;
60+
let middleware_enabled = cfg!(feature = "middleware");
6061

6162
// Process function arguments
6263
let mut new_params = Vec::new();
@@ -82,6 +83,17 @@ pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
8283
{
8384
type_param.bounds.extend(vec![bound.clone()]);
8485
}
86+
let needs_middleware_replay_bounds =
87+
bounds_args.bounds.iter().any(|(name, bound)| {
88+
name == &generic_name
89+
&& bound.to_token_stream().to_string().contains("Clone")
90+
&& !bound.to_token_stream().to_string().contains("Display")
91+
});
92+
if middleware_enabled && needs_middleware_replay_bounds {
93+
type_param
94+
.bounds
95+
.push(parse_quote!(crate::middleware::MiddlewareInput));
96+
}
8597

8698
new_params.push(GenericParam::Type(type_param));
8799
param_count += 1;

async-openai/Cargo.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ chat-completion = ["chat-completion-types", "_api"]
4949
assistant = ["assistant-types", "_api", ]
5050
administration = ["administration-types", "_api"]
5151
completions = ["completion-types", "_api"]
52+
middleware = ["dep:tower", "_api", "async-openai-macros?/middleware"]
5253

5354
# Type feature flags - these enable only the types
5455
response-types = ["dep:derive_builder"]
@@ -128,12 +129,12 @@ full = [
128129
"completions",
129130
"types",
130131
"byot",
132+
"middleware",
131133
]
132134

133135
# Internal feature to enable API dependencies
134136
_api = [
135137
"dep:async-openai-macros",
136-
"dep:backoff",
137138
"dep:base64",
138139
"dep:bytes",
139140
"dep:futures",
@@ -144,6 +145,7 @@ _api = [
144145
"dep:tokio-stream",
145146
"dep:tokio-util",
146147
"dep:tracing",
148+
"dep:tower",
147149
"dep:secrecy",
148150
"dep:eventsource-stream",
149151
"dep:serde_urlencoded",
@@ -160,7 +162,7 @@ bytes = { version = "1.11", optional = true }
160162

161163
# API dependencies - only needed when API features are enabled
162164
# We use a feature gate to enable these when any API feature is enabled
163-
async-openai-macros = { path = "../async-openai-macros", version = "0.1.1", optional = true }
165+
async-openai-macros = { path = "../async-openai-macros", version = "0.2.0", optional = true }
164166
base64 = { version = "0.22", optional = true }
165167
rand = { version = "0.9", optional = true }
166168
reqwest = { version = "0.13", features = [
@@ -174,14 +176,14 @@ tracing = { version = "0.1", optional = true }
174176
secrecy = { version = "0.10", features = ["serde"], optional = true }
175177
serde_urlencoded = { version = "0.7", optional = true }
176178
url = { version = "2.5", optional = true }
179+
tower = { version = "0.5", features = ["limit", "retry", "timeout", "util"], optional = true }
177180
## For Webhook signature verification
178181
hmac = { version = "0.12", optional = true, default-features = false}
179182
sha2 = { version = "0.10", optional = true, default-features = false }
180183
hex = { version = "0.4", optional = true, default-features = false }
181184

182185
## API Non-WASM dependencies (streaming and retry is not implemented for WASM yet)
183186
[target.'cfg(not(target_family = "wasm"))'.dependencies]
184-
backoff = { version = "0.4.0", features = ["tokio"], optional = true }
185187
futures = { version = "0.3", optional = true }
186188
tokio = { version = "1", features = ["fs", "macros"], optional = true }
187189
tokio-stream = { version = "0.1", optional = true }
@@ -195,6 +197,7 @@ tokio-tungstenite = { version = "0.28", optional = true, default-features = fals
195197
getrandom = { version = "0.3", features = ["wasm_js"] }
196198

197199
[dev-dependencies]
200+
http = "1"
198201
tokio-test = "0.4"
199202
serde_json = "1"
200203

async-openai/MIDDLEWARE.md

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Tower based middlewares
2+
3+
Enable the `middleware` feature to customize the HTTP execution path with Tower services and layers.
4+
5+
The middleware boundary is intentionally below the API groups and above the concrete HTTP transport, an example middleware stack:
6+
7+
```text
8+
async-openai API groups
9+
responses(), chat(), files(), ...
10+
|
11+
v
12+
HttpRequestFactory
13+
|
14+
v
15+
+----- concurrency_limit ------+
16+
| +------- timeout ----------+ |
17+
| | +-- OpenAIRetryLayer --+ | |
18+
| | | | | |
19+
| | | ReqwestService or | | |
20+
| | | custom service | | |
21+
| | | | | |
22+
| | +-- OpenAIRetryLayer --+ | |
23+
| +------- timeout ----------+ |
24+
+----- concurrency_limit ------+
25+
|
26+
v
27+
reqwest::Response
28+
```
29+
30+
The request value passed through tower is `HttpRequestFactory`, not `reqwest::Request`. This is deliberate: `reqwest::Request` is not generally cloneable once it contains a streaming body, but retry middleware needs a way to replay a request. The factory is cheap to clone and rebuilds a fresh `reqwest::Request` for each attempt.
31+
32+
## Use the Default `ReqwestService`
33+
34+
`ReqwestService` is a tower service backed by `reqwest::Client`. It is used by default to make outbound HTTP requests.
35+
36+
```rust
37+
use async_openai::{Client, config::OpenAIConfig};
38+
use async_openai::middleware::{retry::OpenAIRetryLayer, ReqwestService};
39+
use std::time::Duration;
40+
41+
let service = tower::ServiceBuilder::new()
42+
.concurrency_limit(8)
43+
.timeout(Duration::from_secs(30))
44+
.layer(OpenAIRetryLayer::default())
45+
.service(ReqwestService::new(reqwest::Client::new()));
46+
47+
let client = Client::with_config(OpenAIConfig::default())
48+
.with_http_service(service);
49+
```
50+
51+
## Use a Custom Service
52+
53+
You can replace `ReqwestService` entirely. This is useful for logging, metrics, tests, mocks, alternate transports, or policy layers that want to inspect the generated request before sending it.
54+
55+
```rust
56+
use async_openai::{Client, config::OpenAIConfig, error::OpenAIError};
57+
use async_openai::middleware::HttpRequestFactory;
58+
use tower::service_fn;
59+
60+
let service = service_fn(|factory: HttpRequestFactory| async move {
61+
let request = factory.build().await?;
62+
63+
// here you can inspect, modify, or log the request, route it somewhere else,
64+
// or return a synthetic response for testing.
65+
66+
println!("sending {} {}", request.method(), request.url());
67+
68+
reqwest::Client::new()
69+
.execute(request)
70+
.await
71+
.map_err(OpenAIError::Reqwest)
72+
});
73+
74+
let client = Client::with_config(OpenAIConfig::default())
75+
.with_http_service(service);
76+
```
77+
78+
## Retry layer
79+
80+
`middleware::retry::OpenAIRetryLayer` is a Tower layer and `middleware::retry::SimpleRetryPolicy` is a Tower retry policy.
81+
82+
Both attempt retries with exponential backoff on `429`, `5xx` and connection errors and respects `Retry-After` header.
83+
84+
The difference is that upon seeing 429, `OpenAIRetryLayer` consumes response body to check if it is a rate limit (retryable error) or insufficient quota (permanent error). The default async-openai client uses this layer internally for library's default retry behavior.
85+
86+
The retry boundary is `HttpRequestFactory`. Retrying clones the factory and rebuilds a fresh `reqwest::Request` for each attempt instead of cloning a built request. That matters because `reqwest::Request` is not Clone.
87+
88+
`middleware::retry::SimpleRetryPolicy` uses `middleware::retry::should_retry` to determine if a request should be retried.
89+
90+
Custom tower retry policies can call `middleware::retry::should_retry` to reuse the same retry classification while changing delay behavior.
91+
92+
On native targets retries wait using `tokio::time::sleep`. On WASM retries are immediate.
93+
94+
## Native and WASM bounds
95+
96+
The conceptual middleware boundary stays the same; only the platform thread-safety bounds differ.
97+
98+
On native targets, middleware services installed with `Client::with_http_service` must be `Send + Sync + 'static` and return `Send + 'static` futures.
99+
100+
On WASM targets, middleware services and futures must be `'static`.
101+
102+
## Bring Your Own Types Interaction
103+
104+
With the `byot` feature, generated `*_byot` methods keep minimal trait bounds. When `middleware` feature is enabled additional `MiddlewareInput` bounds are added based on native or WASM targets so the input can be stored long enough to rebuild a fresh request for retries.
105+
106+
## Error Handling
107+
108+
`OpenAIError::Boxed` is available only when the `middleware` feature is enabled.
109+
110+
Custom middleware services installed with `Client::with_http_service` may use any error type that implements `Into<OpenAIError>`. This lets middleware preserve structured errors when it has a dedicated `OpenAIError` conversion.
111+
112+
Tower's `BoxError` converts into `OpenAIError::Boxed`, which is useful for generic tower layers whose concrete error type is erased. Callers can still downcast the boxed error when they know the original error type.

async-openai/README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ Features that makes `async-openai` unique:
4444
- Ergonomic builder pattern for all request objects.
4545
- Granular feature flags to enable any types or apis: good for faster compilation and crate reuse.
4646
- Microsoft Azure OpenAI Service (only for APIs matching OpenAI spec).
47-
- WASM (doesn't include streaming and retry support yet)
47+
- WASM (doesn't support streaming yet)
48+
- Middleware support with [tower](https://crates.io/crates/tower) ecosystem
4849

4950
## Usage
5051

@@ -236,6 +237,10 @@ fn chat_completion(client: &Client<Box<dyn Config>>) {
236237
}
237238
```
238239

240+
## Middleware
241+
242+
Middleware is supported via Tower ecosystem, which can be enabled with `middleware` feature. See [middleware](https://github.com/64bit/async-openai/blob/main/async-openai/MIDDLEWARE.md) for more detail.
243+
239244
## Contributing
240245

241246
🎉 Thank you for taking the time to contribute and improve the project. I'd be happy to have you!

async-openai/src/assistants/runs.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,14 @@ impl<'c, C: Config> Runs<'c, C> {
7676
request.stream = Some(true);
7777
}
7878

79-
Ok(self
80-
.client
79+
self.client
8180
.post_stream_mapped_raw_events(
8281
&format!("/threads/{}/runs", self.thread_id),
8382
request,
8483
&self.request_options,
8584
TryFrom::try_from,
8685
)
87-
.await)
86+
.await
8887
}
8988

9089
/// Retrieves a run.
@@ -170,8 +169,7 @@ impl<'c, C: Config> Runs<'c, C> {
170169
request.stream = Some(true);
171170
}
172171

173-
Ok(self
174-
.client
172+
self.client
175173
.post_stream_mapped_raw_events(
176174
&format!(
177175
"/threads/{}/runs/{run_id}/submit_tool_outputs",
@@ -181,7 +179,7 @@ impl<'c, C: Config> Runs<'c, C> {
181179
&self.request_options,
182180
TryFrom::try_from,
183181
)
184-
.await)
182+
.await
185183
}
186184

187185
/// Cancels a run that is `in_progress`

async-openai/src/assistants/threads.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,14 @@ impl<'c, C: Config> Threads<'c, C> {
7676

7777
request.stream = Some(true);
7878
}
79-
Ok(self
80-
.client
79+
self.client
8180
.post_stream_mapped_raw_events(
8281
"/threads/runs",
8382
request,
8483
&self.request_options,
8584
TryFrom::try_from,
8685
)
87-
.await)
86+
.await
8887
}
8988

9089
/// Create a thread.

async-openai/src/audio/speech.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ impl<'c, C: Config> Speech<'c, C> {
6060

6161
request.stream_format = Some(StreamFormat::SSE);
6262
}
63-
Ok(self
64-
.client
63+
self.client
6564
.post_stream("/audio/speech", request, &self.request_options)
66-
.await)
65+
.await
6766
}
6867
}

0 commit comments

Comments
 (0)