Skip to content

Commit 814e416

Browse files
committed
feat: validate duplicate provider names in NewRequestBridge
1 parent 97aea98 commit 814e416

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

bridge.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ type RequestBridge struct {
5757

5858
var _ http.Handler = &RequestBridge{}
5959

60+
// validateProviders checks that no two providers share the same name.
61+
func validateProviders(providers []provider.Provider) error {
62+
names := make(map[string]bool, len(providers))
63+
for _, prov := range providers {
64+
if names[prov.Name()] {
65+
return fmt.Errorf("duplicate provider name: %q", prov.Name())
66+
}
67+
names[prov.Name()] = true
68+
}
69+
// TODO(ssncferreira): validate duplicate baseURLs as well
70+
return nil
71+
}
72+
6073
// NewRequestBridge creates a new *[RequestBridge] and registers the HTTP routes defined by the given providers.
6174
// Any routes which are requested but not registered will be reverse-proxied to the upstream service.
6275
//
@@ -67,6 +80,10 @@ var _ http.Handler = &RequestBridge{}
6780
// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method.
6881
// Providers returning nil will not have circuit breaker protection.
6982
func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) (*RequestBridge, error) {
83+
if err := validateProviders(providers); err != nil {
84+
return nil, err
85+
}
86+
7087
mux := http.NewServeMux()
7188

7289
for _, prov := range providers {

bridge_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,72 @@ import (
1313
"github.com/stretchr/testify/require"
1414
)
1515

16+
func TestValidateProviders(t *testing.T) {
17+
t.Parallel()
18+
19+
tests := []struct {
20+
name string
21+
providers []provider.Provider
22+
expectErr string
23+
}{
24+
{
25+
name: "all_supported_providers",
26+
providers: []provider.Provider{
27+
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
28+
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
29+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
30+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
31+
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
32+
},
33+
},
34+
{
35+
name: "default_names_and_base_urls",
36+
providers: []provider.Provider{
37+
NewOpenAIProvider(config.OpenAI{}),
38+
NewAnthropicProvider(config.Anthropic{}, nil),
39+
NewCopilotProvider(config.Copilot{}),
40+
},
41+
},
42+
{
43+
name: "multiple_copilot_instances",
44+
providers: []provider.Provider{
45+
NewCopilotProvider(config.Copilot{}),
46+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
47+
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
48+
},
49+
},
50+
{
51+
name: "duplicate_name",
52+
providers: []provider.Provider{
53+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
54+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
55+
},
56+
expectErr: "duplicate provider name",
57+
},
58+
{
59+
name: "duplicate_base_url_different_names",
60+
providers: []provider.Provider{
61+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
62+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
63+
},
64+
},
65+
}
66+
67+
for _, tc := range tests {
68+
t.Run(tc.name, func(t *testing.T) {
69+
t.Parallel()
70+
71+
err := validateProviders(tc.providers)
72+
if tc.expectErr != "" {
73+
require.Error(t, err)
74+
assert.Contains(t, err.Error(), tc.expectErr)
75+
} else {
76+
require.NoError(t, err)
77+
}
78+
})
79+
}
80+
}
81+
1682
func TestPassthroughRoutesForProviders(t *testing.T) {
1783
t.Parallel()
1884

0 commit comments

Comments
 (0)