Skip to content

Commit 406f98f

Browse files
authored
feat: add path of providers base url to pass through requests (#159)
When requst is being passed though the change makes it so paht of the forwarded request incudles path of configured base url. Example: providers base url: `http://some.domain/some/path` pass though route: `/route` url of passed though request before: `http://some.domain/route` after the change: `http://some.domain/some/path/route`
1 parent da78629 commit 406f98f

8 files changed

Lines changed: 311 additions & 110 deletions

apidump_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func TestAPIDump(t *testing.T) {
100100
reqBody := files[fixtureRequest]
101101

102102
// Setup mock upstream server.
103-
srv := newMockServer(ctx, t, files, nil)
103+
srv := newMockServer(ctx, t, files, nil, nil)
104104
t.Cleanup(srv.Close)
105105

106106
// Create temp dir for API dumps.

bridge_integration_test.go

Lines changed: 168 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestAnthropicMessages(t *testing.T) {
109109

110110
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
111111
t.Cleanup(cancel)
112-
srv := newMockServer(ctx, t, files, nil)
112+
srv := newMockServer(ctx, t, files, nil, nil)
113113
t.Cleanup(srv.Close)
114114

115115
recorderClient := &testutil.MockRecorder{}
@@ -379,7 +379,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
379379

380380
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
381381
t.Cleanup(cancel)
382-
srv := newMockServer(ctx, t, files, nil)
382+
srv := newMockServer(ctx, t, files, nil, nil)
383383
t.Cleanup(srv.Close)
384384

385385
recorderClient := &testutil.MockRecorder{}
@@ -483,7 +483,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
483483
t.Cleanup(cancel)
484484

485485
// Setup mock server with response mutator for multi-turn interaction.
486-
srv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte {
486+
srv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte {
487487
if reqCount == 1 {
488488
// First request gets the tool call response
489489
return resp
@@ -556,91 +556,128 @@ func TestOpenAIChatCompletions(t *testing.T) {
556556
func TestSimple(t *testing.T) {
557557
t.Parallel()
558558

559+
getAnthropicResponseID := func(streaming bool, resp *http.Response) (string, error) {
560+
if streaming {
561+
decoder := ssestream.NewDecoder(resp)
562+
stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil)
563+
var message anthropic.Message
564+
for stream.Next() {
565+
event := stream.Current()
566+
if err := message.Accumulate(event); err != nil {
567+
return "", fmt.Errorf("accumulate event: %w", err)
568+
}
569+
}
570+
if stream.Err() != nil {
571+
return "", fmt.Errorf("stream error: %w", stream.Err())
572+
}
573+
return message.ID, nil
574+
}
575+
576+
body, err := io.ReadAll(resp.Body)
577+
if err != nil {
578+
return "", fmt.Errorf("read body: %w", err)
579+
}
580+
581+
var message anthropic.Message
582+
if err := json.Unmarshal(body, &message); err != nil {
583+
return "", fmt.Errorf("unmarshal response: %w", err)
584+
}
585+
return message.ID, nil
586+
}
587+
588+
getOpenAIResponseID := func(streaming bool, resp *http.Response) (string, error) {
589+
if streaming {
590+
// Parse the response stream.
591+
decoder := oaissestream.NewDecoder(resp)
592+
stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil)
593+
var message openai.ChatCompletionAccumulator
594+
for stream.Next() {
595+
chunk := stream.Current()
596+
message.AddChunk(chunk)
597+
}
598+
if stream.Err() != nil {
599+
return "", fmt.Errorf("stream error: %w", stream.Err())
600+
}
601+
return message.ID, nil
602+
}
603+
604+
// Parse & unmarshal the response.
605+
body, err := io.ReadAll(resp.Body)
606+
if err != nil {
607+
return "", fmt.Errorf("read body: %w", err)
608+
}
609+
610+
var message openai.ChatCompletion
611+
if err := json.Unmarshal(body, &message); err != nil {
612+
return "", fmt.Errorf("unmarshal response: %w", err)
613+
}
614+
return message.ID, nil
615+
}
616+
617+
// Common configuration functions for each provider type.
618+
configureAnthropic := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
619+
t.Helper()
620+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
621+
providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
622+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
623+
}
624+
625+
configureOpenAI := func(t *testing.T, addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
626+
t.Helper()
627+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
628+
providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))}
629+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
630+
}
631+
559632
testCases := []struct {
560633
name string
561634
fixture []byte
562-
configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error)
563-
getResponseIDFunc func(bool, *http.Response) (string, error)
635+
basePath string
636+
expectedPath string
637+
configureFunc func(*testing.T, string, aibridge.Recorder) (*aibridge.RequestBridge, error)
638+
getResponseIDFunc func(streaming bool, resp *http.Response) (string, error)
564639
createRequest func(*testing.T, string, []byte) *http.Request
565640
expectedMsgID string
566641
}{
567642
{
568-
name: config.ProviderAnthropic,
569-
fixture: fixtures.AntSimple,
570-
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
571-
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
572-
provider := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)}
573-
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
574-
},
575-
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
576-
if streaming {
577-
decoder := ssestream.NewDecoder(resp)
578-
stream := ssestream.NewStream[anthropic.MessageStreamEventUnion](decoder, nil)
579-
var message anthropic.Message
580-
for stream.Next() {
581-
event := stream.Current()
582-
if err := message.Accumulate(event); err != nil {
583-
return "", fmt.Errorf("accumulate event: %w", err)
584-
}
585-
}
586-
if stream.Err() != nil {
587-
return "", fmt.Errorf("stream error: %w", stream.Err())
588-
}
589-
return message.ID, nil
590-
}
591-
592-
body, err := io.ReadAll(resp.Body)
593-
if err != nil {
594-
return "", fmt.Errorf("read body: %w", err)
595-
}
596-
597-
var message anthropic.Message
598-
if err := json.Unmarshal(body, &message); err != nil {
599-
return "", fmt.Errorf("unmarshal response: %w", err)
600-
}
601-
return message.ID, nil
602-
},
603-
createRequest: createAnthropicMessagesReq,
604-
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
643+
name: config.ProviderAnthropic,
644+
fixture: fixtures.AntSimple,
645+
basePath: "",
646+
expectedPath: "/v1/messages",
647+
configureFunc: configureAnthropic,
648+
getResponseIDFunc: getAnthropicResponseID,
649+
createRequest: createAnthropicMessagesReq,
650+
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
605651
},
606652
{
607-
name: config.ProviderOpenAI,
608-
fixture: fixtures.OaiChatSimple,
609-
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
610-
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
611-
providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))}
612-
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
613-
},
614-
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
615-
if streaming {
616-
// Parse the response stream.
617-
decoder := oaissestream.NewDecoder(resp)
618-
stream := oaissestream.NewStream[openai.ChatCompletionChunk](decoder, nil)
619-
var message openai.ChatCompletionAccumulator
620-
for stream.Next() {
621-
chunk := stream.Current()
622-
message.AddChunk(chunk)
623-
}
624-
if stream.Err() != nil {
625-
return "", fmt.Errorf("stream error: %w", stream.Err())
626-
}
627-
return message.ID, nil
628-
}
629-
630-
// Parse & unmarshal the response.
631-
body, err := io.ReadAll(resp.Body)
632-
if err != nil {
633-
return "", fmt.Errorf("read body: %w", err)
634-
}
635-
636-
var message openai.ChatCompletion
637-
if err := json.Unmarshal(body, &message); err != nil {
638-
return "", fmt.Errorf("unmarshal response: %w", err)
639-
}
640-
return message.ID, nil
641-
},
642-
createRequest: createOpenAIChatCompletionsReq,
643-
expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N",
653+
name: config.ProviderOpenAI,
654+
fixture: fixtures.OaiChatSimple,
655+
basePath: "",
656+
expectedPath: "/chat/completions",
657+
configureFunc: configureOpenAI,
658+
getResponseIDFunc: getOpenAIResponseID,
659+
createRequest: createOpenAIChatCompletionsReq,
660+
expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N",
661+
},
662+
{
663+
name: config.ProviderAnthropic + "_baseURL_path",
664+
fixture: fixtures.AntSimple,
665+
basePath: "/api",
666+
expectedPath: "/api/v1/messages",
667+
configureFunc: configureAnthropic,
668+
getResponseIDFunc: getAnthropicResponseID,
669+
createRequest: createAnthropicMessagesReq,
670+
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
671+
},
672+
{
673+
name: config.ProviderOpenAI + "_baseURL_path",
674+
fixture: fixtures.OaiChatSimple,
675+
basePath: "/api",
676+
expectedPath: "/api/chat/completions",
677+
configureFunc: configureOpenAI,
678+
getResponseIDFunc: getOpenAIResponseID,
679+
createRequest: createOpenAIChatCompletionsReq,
680+
expectedMsgID: "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N",
644681
},
645682
}
646683

@@ -671,12 +708,14 @@ func TestSimple(t *testing.T) {
671708
// Given: a mock API server and a Bridge through which the requests will flow.
672709
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
673710
t.Cleanup(cancel)
674-
srv := newMockServer(ctx, t, files, nil)
711+
srv := newMockServer(ctx, t, files, func(r *http.Request) {
712+
require.Equal(t, tc.expectedPath, r.URL.Path)
713+
}, nil)
675714
t.Cleanup(srv.Close)
676715

677716
recorderClient := &testutil.MockRecorder{}
678717

679-
b, err := tc.configureFunc(srv.URL, recorderClient)
718+
b, err := tc.configureFunc(t, srv.URL+tc.basePath, recorderClient)
680719
require.NoError(t, err)
681720

682721
mockSrv := httptest.NewUnstartedServer(b)
@@ -734,12 +773,16 @@ func TestFallthrough(t *testing.T) {
734773

735774
testCases := []struct {
736775
name string
776+
providerName string
737777
fixture []byte
778+
basePath string
738779
configureFunc func(string, aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge)
739780
}{
740781
{
741-
name: config.ProviderAnthropic,
742-
fixture: fixtures.AntFallthrough,
782+
name: "ant_empty_base_url_path",
783+
providerName: config.ProviderAnthropic,
784+
fixture: fixtures.AntFallthrough,
785+
basePath: "",
743786
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
744787
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
745788
provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)
@@ -749,8 +792,36 @@ func TestFallthrough(t *testing.T) {
749792
},
750793
},
751794
{
752-
name: config.ProviderOpenAI,
753-
fixture: fixtures.OaiChatFallthrough,
795+
name: "oai_empty_base_url_path",
796+
providerName: config.ProviderOpenAI,
797+
fixture: fixtures.OaiChatFallthrough,
798+
basePath: "",
799+
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
800+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
801+
provider := provider.NewOpenAI(openaiCfg(addr, apiKey))
802+
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
803+
require.NoError(t, err)
804+
return provider, bridge
805+
},
806+
},
807+
{
808+
name: "ant_some_base_url_path",
809+
providerName: config.ProviderAnthropic,
810+
fixture: fixtures.AntFallthrough,
811+
basePath: "/api",
812+
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
813+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
814+
provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)
815+
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
816+
require.NoError(t, err)
817+
return provider, bridge
818+
},
819+
},
820+
{
821+
name: "oai_some_base_url_path",
822+
providerName: config.ProviderOpenAI,
823+
fixture: fixtures.OaiChatFallthrough,
824+
basePath: "/api",
754825
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
755826
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
756827
provider := provider.NewOpenAI(openaiCfg(addr, apiKey))
@@ -770,11 +841,12 @@ func TestFallthrough(t *testing.T) {
770841

771842
files := filesMap(arc)
772843
require.Contains(t, files, fixtureResponse)
844+
expectedPath := tc.basePath + "/v1/models"
773845

774846
var receivedHeaders *http.Header
775847
respBody := files[fixtureResponse]
776848
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
777-
if r.URL.Path != "/v1/models" {
849+
if r.URL.Path != expectedPath {
778850
t.Errorf("unexpected request path: %q", r.URL.Path)
779851
t.FailNow()
780852
}
@@ -789,7 +861,8 @@ func TestFallthrough(t *testing.T) {
789861

790862
recorderClient := &testutil.MockRecorder{}
791863

792-
provider, bridge := tc.configureFunc(upstream.URL, recorderClient)
864+
upstreamURL := upstream.URL + tc.basePath
865+
provider, bridge := tc.configureFunc(upstreamURL, recorderClient)
793866

794867
bridgeSrv := httptest.NewUnstartedServer(bridge)
795868
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
@@ -798,7 +871,7 @@ func TestFallthrough(t *testing.T) {
798871
bridgeSrv.Start()
799872
t.Cleanup(bridgeSrv.Close)
800873

801-
req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.name), nil)
874+
req, err := http.NewRequestWithContext(t.Context(), "GET", fmt.Sprintf("%s/%s/v1/models", bridgeSrv.URL, tc.providerName), nil)
802875
require.NoError(t, err)
803876

804877
resp, err := http.DefaultClient.Do(req)
@@ -1074,7 +1147,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
10741147
t.Cleanup(cancel)
10751148

10761149
// Setup mock server with response mutator for multi-turn interaction.
1077-
mockSrv := newMockServer(ctx, t, files, func(reqCount uint32, resp []byte) []byte {
1150+
mockSrv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte {
10781151
if reqCount == 1 {
10791152
return resp // First request gets the normal response (with tool call).
10801153
}
@@ -1310,7 +1383,7 @@ func TestErrorHandling(t *testing.T) {
13101383
reqBody := files[fixtureRequest]
13111384

13121385
// Setup mock server.
1313-
mockSrv := newMockServer(ctx, t, files, nil)
1386+
mockSrv := newMockServer(ctx, t, files, nil, nil)
13141387
mockSrv.statusCode = http.StatusInternalServerError
13151388
t.Cleanup(mockSrv.Close)
13161389

@@ -1983,11 +2056,15 @@ type mockServer struct {
19832056
statusCode int
19842057
}
19852058

1986-
func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer {
2059+
func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requestValidatorFn func(*http.Request), responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer {
19872060
t.Helper()
19882061

19892062
ms := &mockServer{}
19902063
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2064+
if requestValidatorFn != nil {
2065+
requestValidatorFn(r)
2066+
}
2067+
19912068
statusCode := http.StatusOK
19922069
if ms.statusCode != 0 {
19932070
statusCode = ms.statusCode

0 commit comments

Comments
 (0)