Skip to content

Commit 8ee1e8b

Browse files
committed
feat: make MaxRetries configurable in OpenAI provider (#115)
Add MaxRetries *int to config.OpenAI. When set, option.WithMaxRetries is passed to the SDK client in both responses and chat completions interceptors. Nil preserves the SDK default (2 retries); 0 disables retries entirely. Update TestClientAndConnectionError and TestUpstreamError to set MaxRetries=0, eliminating retry delays and speeding up these tests.
1 parent 1d53ad0 commit 8ee1e8b

4 files changed

Lines changed: 20 additions & 8 deletions

File tree

config/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ type OpenAI struct {
4343
CircuitBreaker *CircuitBreaker
4444
SendActorHeaders bool
4545
ExtraHeaders map[string]string
46+
// MaxRetries controls the number of automatic retries the SDK will perform
47+
// on transient errors. If nil, the SDK default (2) is used.
48+
// Set to 0 to disable retries entirely.
49+
MaxRetries *int
4650
}
4751

4852
type Copilot struct {

intercept/chatcompletions/base.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ type interceptionBase struct {
4747

4848
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
4949
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
50+
if i.cfg.MaxRetries != nil {
51+
opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries))
52+
}
5053

5154
// Add extra headers if configured.
5255
// Some providers require additional headers that are not added by the SDK.

intercept/responses/base.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ type responsesInterceptionBase struct {
5555

5656
func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
5757
opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)}
58+
if i.cfg.MaxRetries != nil {
59+
opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries))
60+
}
5861

5962
// Add extra headers if configured.
6063
// Some providers require additional headers that are not added by the SDK.

internal/integrationtest/responses_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,6 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) {
596596
}
597597
}
598598

599-
// TODO set MaxRetries to speed up this test
600-
// option.WithMaxRetries(0), in base responses interceptor
601-
// https://github.com/coder/aibridge/issues/115
602599
func TestClientAndConnectionError(t *testing.T) {
603600
t.Parallel()
604601

@@ -642,7 +639,11 @@ func TestClientAndConnectionError(t *testing.T) {
642639
t.Cleanup(cancel)
643640

644641
// tc.addr may be an intentionally invalid URL; use withCustomProvider.
645-
bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))))
642+
// MaxRetries is set to 0 to disable SDK retries and speed up the test.
643+
cfg := openAICfg(tc.addr, apiKey)
644+
maxRetries := 0
645+
cfg.MaxRetries = &maxRetries
646+
bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(cfg)))
646647

647648
reqBytes := responsesRequestBytes(t, tc.streaming)
648649
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
@@ -660,9 +661,6 @@ func TestClientAndConnectionError(t *testing.T) {
660661
}
661662
}
662663

663-
// TODO set MaxRetries to speed up this test
664-
// option.WithMaxRetries(0), in base responses interceptor
665-
// https://github.com/coder/aibridge/issues/115
666664
func TestUpstreamError(t *testing.T) {
667665
t.Parallel()
668666

@@ -721,7 +719,11 @@ func TestUpstreamError(t *testing.T) {
721719
}))
722720
t.Cleanup(upstream.Close)
723721

724-
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
722+
// MaxRetries is set to 0 to disable SDK retries and speed up the test.
723+
cfg := openAICfg(upstream.URL, apiKey)
724+
maxRetries := 0
725+
cfg.MaxRetries = &maxRetries
726+
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(provider.NewOpenAI(cfg)))
725727

726728
reqBytes := responsesRequestBytes(t, tc.streaming)
727729
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)

0 commit comments

Comments
 (0)