@@ -109,7 +109,7 @@ func TestAnthropicMessages(t *testing.T) {
109109
110110 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
111111 t .Cleanup (cancel )
112- srv := newMockServer (ctx , t , files , nil )
112+ srv := newMockServer (ctx , t , files , nil , nil )
113113 t .Cleanup (srv .Close )
114114
115115 recorderClient := & testutil.MockRecorder {}
@@ -379,7 +379,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
379379
380380 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
381381 t .Cleanup (cancel )
382- srv := newMockServer (ctx , t , files , nil )
382+ srv := newMockServer (ctx , t , files , nil , nil )
383383 t .Cleanup (srv .Close )
384384
385385 recorderClient := & testutil.MockRecorder {}
@@ -483,7 +483,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
483483 t .Cleanup (cancel )
484484
485485 // Setup mock server with response mutator for multi-turn interaction.
486- srv := newMockServer (ctx , t , files , func (reqCount uint32 , resp []byte ) []byte {
486+ srv := newMockServer (ctx , t , files , nil , func (reqCount uint32 , resp []byte ) []byte {
487487 if reqCount == 1 {
488488 // First request gets the tool call response
489489 return resp
@@ -556,91 +556,128 @@ func TestOpenAIChatCompletions(t *testing.T) {
556556func TestSimple (t * testing.T ) {
557557 t .Parallel ()
558558
559+ getAnthropicResponseID := func (streaming bool , resp * http.Response ) (string , error ) {
560+ if streaming {
561+ decoder := ssestream .NewDecoder (resp )
562+ stream := ssestream .NewStream [anthropic.MessageStreamEventUnion ](decoder , nil )
563+ var message anthropic.Message
564+ for stream .Next () {
565+ event := stream .Current ()
566+ if err := message .Accumulate (event ); err != nil {
567+ return "" , fmt .Errorf ("accumulate event: %w" , err )
568+ }
569+ }
570+ if stream .Err () != nil {
571+ return "" , fmt .Errorf ("stream error: %w" , stream .Err ())
572+ }
573+ return message .ID , nil
574+ }
575+
576+ body , err := io .ReadAll (resp .Body )
577+ if err != nil {
578+ return "" , fmt .Errorf ("read body: %w" , err )
579+ }
580+
581+ var message anthropic.Message
582+ if err := json .Unmarshal (body , & message ); err != nil {
583+ return "" , fmt .Errorf ("unmarshal response: %w" , err )
584+ }
585+ return message .ID , nil
586+ }
587+
588+ getOpenAIResponseID := func (streaming bool , resp * http.Response ) (string , error ) {
589+ if streaming {
590+ // Parse the response stream.
591+ decoder := oaissestream .NewDecoder (resp )
592+ stream := oaissestream .NewStream [openai.ChatCompletionChunk ](decoder , nil )
593+ var message openai.ChatCompletionAccumulator
594+ for stream .Next () {
595+ chunk := stream .Current ()
596+ message .AddChunk (chunk )
597+ }
598+ if stream .Err () != nil {
599+ return "" , fmt .Errorf ("stream error: %w" , stream .Err ())
600+ }
601+ return message .ID , nil
602+ }
603+
604+ // Parse & unmarshal the response.
605+ body , err := io .ReadAll (resp .Body )
606+ if err != nil {
607+ return "" , fmt .Errorf ("read body: %w" , err )
608+ }
609+
610+ var message openai.ChatCompletion
611+ if err := json .Unmarshal (body , & message ); err != nil {
612+ return "" , fmt .Errorf ("unmarshal response: %w" , err )
613+ }
614+ return message .ID , nil
615+ }
616+
617+ // Common configuration functions for each provider type.
618+ configureAnthropic := func (t * testing.T , addr string , client aibridge.Recorder ) (* aibridge.RequestBridge , error ) {
619+ t .Helper ()
620+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
621+ providers := []aibridge.Provider {provider .NewAnthropic (anthropicCfg (addr , apiKey ), nil )}
622+ return aibridge .NewRequestBridge (t .Context (), providers , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
623+ }
624+
625+ configureOpenAI := func (t * testing.T , addr string , client aibridge.Recorder ) (* aibridge.RequestBridge , error ) {
626+ t .Helper ()
627+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
628+ providers := []aibridge.Provider {provider .NewOpenAI (openaiCfg (addr , apiKey ))}
629+ return aibridge .NewRequestBridge (t .Context (), providers , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
630+ }
631+
559632 testCases := []struct {
560633 name string
561634 fixture []byte
562- configureFunc func (string , aibridge.Recorder ) (* aibridge.RequestBridge , error )
563- getResponseIDFunc func (bool , * http.Response ) (string , error )
635+ basePath string
636+ expectedPath string
637+ configureFunc func (* testing.T , string , aibridge.Recorder ) (* aibridge.RequestBridge , error )
638+ getResponseIDFunc func (streaming bool , resp * http.Response ) (string , error )
564639 createRequest func (* testing.T , string , []byte ) * http.Request
565640 expectedMsgID string
566641 }{
567642 {
568- name : config .ProviderAnthropic ,
569- fixture : fixtures .AntSimple ,
570- configureFunc : func (addr string , client aibridge.Recorder ) (* aibridge.RequestBridge , error ) {
571- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
572- provider := []aibridge.Provider {provider .NewAnthropic (anthropicCfg (addr , apiKey ), nil )}
573- return aibridge .NewRequestBridge (t .Context (), provider , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
574- },
575- getResponseIDFunc : func (streaming bool , resp * http.Response ) (string , error ) {
576- if streaming {
577- decoder := ssestream .NewDecoder (resp )
578- stream := ssestream .NewStream [anthropic.MessageStreamEventUnion ](decoder , nil )
579- var message anthropic.Message
580- for stream .Next () {
581- event := stream .Current ()
582- if err := message .Accumulate (event ); err != nil {
583- return "" , fmt .Errorf ("accumulate event: %w" , err )
584- }
585- }
586- if stream .Err () != nil {
587- return "" , fmt .Errorf ("stream error: %w" , stream .Err ())
588- }
589- return message .ID , nil
590- }
591-
592- body , err := io .ReadAll (resp .Body )
593- if err != nil {
594- return "" , fmt .Errorf ("read body: %w" , err )
595- }
596-
597- var message anthropic.Message
598- if err := json .Unmarshal (body , & message ); err != nil {
599- return "" , fmt .Errorf ("unmarshal response: %w" , err )
600- }
601- return message .ID , nil
602- },
603- createRequest : createAnthropicMessagesReq ,
604- expectedMsgID : "msg_01Pvyf26bY17RcjmWfJsXGBn" ,
643+ name : config .ProviderAnthropic ,
644+ fixture : fixtures .AntSimple ,
645+ basePath : "" ,
646+ expectedPath : "/v1/messages" ,
647+ configureFunc : configureAnthropic ,
648+ getResponseIDFunc : getAnthropicResponseID ,
649+ createRequest : createAnthropicMessagesReq ,
650+ expectedMsgID : "msg_01Pvyf26bY17RcjmWfJsXGBn" ,
605651 },
606652 {
607- name : config .ProviderOpenAI ,
608- fixture : fixtures .OaiChatSimple ,
609- configureFunc : func (addr string , client aibridge.Recorder ) (* aibridge.RequestBridge , error ) {
610- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
611- providers := []aibridge.Provider {provider .NewOpenAI (openaiCfg (addr , apiKey ))}
612- return aibridge .NewRequestBridge (t .Context (), providers , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
613- },
614- getResponseIDFunc : func (streaming bool , resp * http.Response ) (string , error ) {
615- if streaming {
616- // Parse the response stream.
617- decoder := oaissestream .NewDecoder (resp )
618- stream := oaissestream .NewStream [openai.ChatCompletionChunk ](decoder , nil )
619- var message openai.ChatCompletionAccumulator
620- for stream .Next () {
621- chunk := stream .Current ()
622- message .AddChunk (chunk )
623- }
624- if stream .Err () != nil {
625- return "" , fmt .Errorf ("stream error: %w" , stream .Err ())
626- }
627- return message .ID , nil
628- }
629-
630- // Parse & unmarshal the response.
631- body , err := io .ReadAll (resp .Body )
632- if err != nil {
633- return "" , fmt .Errorf ("read body: %w" , err )
634- }
635-
636- var message openai.ChatCompletion
637- if err := json .Unmarshal (body , & message ); err != nil {
638- return "" , fmt .Errorf ("unmarshal response: %w" , err )
639- }
640- return message .ID , nil
641- },
642- createRequest : createOpenAIChatCompletionsReq ,
643- expectedMsgID : "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N" ,
653+ name : config .ProviderOpenAI ,
654+ fixture : fixtures .OaiChatSimple ,
655+ basePath : "" ,
656+ expectedPath : "/chat/completions" ,
657+ configureFunc : configureOpenAI ,
658+ getResponseIDFunc : getOpenAIResponseID ,
659+ createRequest : createOpenAIChatCompletionsReq ,
660+ expectedMsgID : "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N" ,
661+ },
662+ {
663+ name : config .ProviderAnthropic + "_baseURL_path" ,
664+ fixture : fixtures .AntSimple ,
665+ basePath : "/api" ,
666+ expectedPath : "/api/v1/messages" ,
667+ configureFunc : configureAnthropic ,
668+ getResponseIDFunc : getAnthropicResponseID ,
669+ createRequest : createAnthropicMessagesReq ,
670+ expectedMsgID : "msg_01Pvyf26bY17RcjmWfJsXGBn" ,
671+ },
672+ {
673+ name : config .ProviderOpenAI + "_baseURL_path" ,
674+ fixture : fixtures .OaiChatSimple ,
675+ basePath : "/api" ,
676+ expectedPath : "/api/chat/completions" ,
677+ configureFunc : configureOpenAI ,
678+ getResponseIDFunc : getOpenAIResponseID ,
679+ createRequest : createOpenAIChatCompletionsReq ,
680+ expectedMsgID : "chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N" ,
644681 },
645682 }
646683
@@ -671,12 +708,14 @@ func TestSimple(t *testing.T) {
671708 // Given: a mock API server and a Bridge through which the requests will flow.
672709 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
673710 t .Cleanup (cancel )
674- srv := newMockServer (ctx , t , files , nil )
711+ srv := newMockServer (ctx , t , files , func (r * http.Request ) {
712+ require .Equal (t , tc .expectedPath , r .URL .Path )
713+ }, nil )
675714 t .Cleanup (srv .Close )
676715
677716 recorderClient := & testutil.MockRecorder {}
678717
679- b , err := tc .configureFunc (srv .URL , recorderClient )
718+ b , err := tc .configureFunc (t , srv .URL + tc . basePath , recorderClient )
680719 require .NoError (t , err )
681720
682721 mockSrv := httptest .NewUnstartedServer (b )
@@ -734,12 +773,16 @@ func TestFallthrough(t *testing.T) {
734773
735774 testCases := []struct {
736775 name string
776+ providerName string
737777 fixture []byte
778+ basePath string
738779 configureFunc func (string , aibridge.Recorder ) (aibridge.Provider , * aibridge.RequestBridge )
739780 }{
740781 {
741- name : config .ProviderAnthropic ,
742- fixture : fixtures .AntFallthrough ,
782+ name : "ant_empty_base_url_path" ,
783+ providerName : config .ProviderAnthropic ,
784+ fixture : fixtures .AntFallthrough ,
785+ basePath : "" ,
743786 configureFunc : func (addr string , client aibridge.Recorder ) (aibridge.Provider , * aibridge.RequestBridge ) {
744787 logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
745788 provider := provider .NewAnthropic (anthropicCfg (addr , apiKey ), nil )
@@ -749,8 +792,36 @@ func TestFallthrough(t *testing.T) {
749792 },
750793 },
751794 {
752- name : config .ProviderOpenAI ,
753- fixture : fixtures .OaiChatFallthrough ,
795+ name : "oai_empty_base_url_path" ,
796+ providerName : config .ProviderOpenAI ,
797+ fixture : fixtures .OaiChatFallthrough ,
798+ basePath : "" ,
799+ configureFunc : func (addr string , client aibridge.Recorder ) (aibridge.Provider , * aibridge.RequestBridge ) {
800+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
801+ provider := provider .NewOpenAI (openaiCfg (addr , apiKey ))
802+ bridge , err := aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {provider }, client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
803+ require .NoError (t , err )
804+ return provider , bridge
805+ },
806+ },
807+ {
808+ name : "ant_some_base_url_path" ,
809+ providerName : config .ProviderAnthropic ,
810+ fixture : fixtures .AntFallthrough ,
811+ basePath : "/api" ,
812+ configureFunc : func (addr string , client aibridge.Recorder ) (aibridge.Provider , * aibridge.RequestBridge ) {
813+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
814+ provider := provider .NewAnthropic (anthropicCfg (addr , apiKey ), nil )
815+ bridge , err := aibridge .NewRequestBridge (t .Context (), []aibridge.Provider {provider }, client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
816+ require .NoError (t , err )
817+ return provider , bridge
818+ },
819+ },
820+ {
821+ name : "oai_some_base_url_path" ,
822+ providerName : config .ProviderOpenAI ,
823+ fixture : fixtures .OaiChatFallthrough ,
824+ basePath : "/api" ,
754825 configureFunc : func (addr string , client aibridge.Recorder ) (aibridge.Provider , * aibridge.RequestBridge ) {
755826 logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
756827 provider := provider .NewOpenAI (openaiCfg (addr , apiKey ))
@@ -770,11 +841,12 @@ func TestFallthrough(t *testing.T) {
770841
771842 files := filesMap (arc )
772843 require .Contains (t , files , fixtureResponse )
844+ expectedPath := tc .basePath + "/v1/models"
773845
774846 var receivedHeaders * http.Header
775847 respBody := files [fixtureResponse ]
776848 upstream := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
777- if r .URL .Path != "/v1/models" {
849+ if r .URL .Path != expectedPath {
778850 t .Errorf ("unexpected request path: %q" , r .URL .Path )
779851 t .FailNow ()
780852 }
@@ -789,7 +861,8 @@ func TestFallthrough(t *testing.T) {
789861
790862 recorderClient := & testutil.MockRecorder {}
791863
792- provider , bridge := tc .configureFunc (upstream .URL , recorderClient )
864+ upstreamURL := upstream .URL + tc .basePath
865+ provider , bridge := tc .configureFunc (upstreamURL , recorderClient )
793866
794867 bridgeSrv := httptest .NewUnstartedServer (bridge )
795868 bridgeSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
@@ -798,7 +871,7 @@ func TestFallthrough(t *testing.T) {
798871 bridgeSrv .Start ()
799872 t .Cleanup (bridgeSrv .Close )
800873
801- req , err := http .NewRequestWithContext (t .Context (), "GET" , fmt .Sprintf ("%s/%s/v1/models" , bridgeSrv .URL , tc .name ), nil )
874+ req , err := http .NewRequestWithContext (t .Context (), "GET" , fmt .Sprintf ("%s/%s/v1/models" , bridgeSrv .URL , tc .providerName ), nil )
802875 require .NoError (t , err )
803876
804877 resp , err := http .DefaultClient .Do (req )
@@ -1074,7 +1147,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
10741147 t .Cleanup (cancel )
10751148
10761149 // Setup mock server with response mutator for multi-turn interaction.
1077- mockSrv := newMockServer (ctx , t , files , func (reqCount uint32 , resp []byte ) []byte {
1150+ mockSrv := newMockServer (ctx , t , files , nil , func (reqCount uint32 , resp []byte ) []byte {
10781151 if reqCount == 1 {
10791152 return resp // First request gets the normal response (with tool call).
10801153 }
@@ -1310,7 +1383,7 @@ func TestErrorHandling(t *testing.T) {
13101383 reqBody := files [fixtureRequest ]
13111384
13121385 // Setup mock server.
1313- mockSrv := newMockServer (ctx , t , files , nil )
1386+ mockSrv := newMockServer (ctx , t , files , nil , nil )
13141387 mockSrv .statusCode = http .StatusInternalServerError
13151388 t .Cleanup (mockSrv .Close )
13161389
@@ -1983,11 +2056,15 @@ type mockServer struct {
19832056 statusCode int
19842057}
19852058
1986- func newMockServer (ctx context.Context , t * testing.T , files archiveFileMap , responseMutatorFn func (reqCount uint32 , resp []byte ) []byte ) * mockServer {
2059+ func newMockServer (ctx context.Context , t * testing.T , files archiveFileMap , requestValidatorFn func ( * http. Request ), responseMutatorFn func (reqCount uint32 , resp []byte ) []byte ) * mockServer {
19872060 t .Helper ()
19882061
19892062 ms := & mockServer {}
19902063 srv := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
2064+ if requestValidatorFn != nil {
2065+ requestValidatorFn (r )
2066+ }
2067+
19912068 statusCode := http .StatusOK
19922069 if ms .statusCode != 0 {
19932070 statusCode = ms .statusCode
0 commit comments