Skip to content

Commit 519b082

Browse files
authored
feat: validate duplicate provider names in NewRequestBridge (#240)
## Description Adds validation to `NewRequestBridge` to prevent registering multiple providers with the same name. Duplicate names would cause route prefix conflicts and ambiguous metrics, logs, and circuit breaker behavior. ## Changes * Add `validateProviders()` that rejects duplicate provider names at startup * Add table-driven tests covering valid configurations and duplicate name detection Related to: #152 _Disclaimer: initially produced by Claude Opus 4.6, heavily modified and reviewed by @ssncferreira ._
1 parent 97aea98 commit 519b082

2 files changed

Lines changed: 143 additions & 0 deletions

File tree

bridge.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"net/url"
8+
"regexp"
89
"strings"
910
"sync"
1011
"sync/atomic"
@@ -57,6 +58,25 @@ type RequestBridge struct {
5758

5859
var _ http.Handler = &RequestBridge{}
5960

61+
// validProviderName matches names containing only lowercase alphanumeric characters and hyphens.
62+
var validProviderName = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
63+
64+
// validateProviders checks that provider names are valid and unique.
65+
func validateProviders(providers []provider.Provider) error {
66+
names := make(map[string]bool, len(providers))
67+
for _, prov := range providers {
68+
name := prov.Name()
69+
if !validProviderName.MatchString(name) {
70+
return fmt.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name)
71+
}
72+
if names[name] {
73+
return fmt.Errorf("duplicate provider name: %q", name)
74+
}
75+
names[name] = true
76+
}
77+
return nil
78+
}
79+
6080
// NewRequestBridge creates a new *[RequestBridge] and registers the HTTP routes defined by the given providers.
6181
// Any routes which are requested but not registered will be reverse-proxied to the upstream service.
6282
//
@@ -67,6 +87,10 @@ var _ http.Handler = &RequestBridge{}
6787
// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method.
6888
// Providers returning nil will not have circuit breaker protection.
6989
func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) (*RequestBridge, error) {
90+
if err := validateProviders(providers); err != nil {
91+
return nil, err
92+
}
93+
7094
mux := http.NewServeMux()
7195

7296
for _, prov := range providers {

bridge_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,125 @@ import (
1313
"github.com/stretchr/testify/require"
1414
)
1515

16+
func TestValidateProvider_Names(t *testing.T) {
17+
t.Parallel()
18+
19+
tests := []struct {
20+
name string
21+
providers []provider.Provider
22+
expectErr string
23+
}{
24+
{
25+
name: "all_supported_providers",
26+
providers: []provider.Provider{
27+
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
28+
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
29+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
30+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
31+
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
32+
},
33+
},
34+
{
35+
name: "default_names_and_base_urls",
36+
providers: []provider.Provider{
37+
NewOpenAIProvider(config.OpenAI{}),
38+
NewAnthropicProvider(config.Anthropic{}, nil),
39+
NewCopilotProvider(config.Copilot{}),
40+
},
41+
},
42+
{
43+
name: "multiple_copilot_instances",
44+
providers: []provider.Provider{
45+
NewCopilotProvider(config.Copilot{}),
46+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
47+
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
48+
},
49+
},
50+
{
51+
name: "name_with_slashes",
52+
providers: []provider.Provider{
53+
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
54+
},
55+
expectErr: "invalid provider name",
56+
},
57+
{
58+
name: "name_with_spaces",
59+
providers: []provider.Provider{
60+
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
61+
},
62+
expectErr: "invalid provider name",
63+
},
64+
{
65+
name: "name_with_uppercase",
66+
providers: []provider.Provider{
67+
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
68+
},
69+
expectErr: "invalid provider name",
70+
},
71+
}
72+
73+
for _, tc := range tests {
74+
t.Run(tc.name, func(t *testing.T) {
75+
t.Parallel()
76+
77+
err := validateProviders(tc.providers)
78+
if tc.expectErr != "" {
79+
require.Error(t, err)
80+
assert.Contains(t, err.Error(), tc.expectErr)
81+
} else {
82+
require.NoError(t, err)
83+
}
84+
})
85+
}
86+
}
87+
88+
func TestValidateProvider_DuplicateNames(t *testing.T) {
89+
t.Parallel()
90+
91+
tests := []struct {
92+
name string
93+
providers []provider.Provider
94+
expectErr string
95+
}{
96+
{
97+
name: "unique_names",
98+
providers: []provider.Provider{
99+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
100+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
101+
},
102+
},
103+
{
104+
name: "duplicate_base_url_different_names",
105+
providers: []provider.Provider{
106+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
107+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
108+
},
109+
},
110+
{
111+
name: "duplicate_name",
112+
providers: []provider.Provider{
113+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
114+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
115+
},
116+
expectErr: "duplicate provider name",
117+
},
118+
}
119+
120+
for _, tc := range tests {
121+
t.Run(tc.name, func(t *testing.T) {
122+
t.Parallel()
123+
124+
err := validateProviders(tc.providers)
125+
if tc.expectErr != "" {
126+
require.Error(t, err)
127+
assert.Contains(t, err.Error(), tc.expectErr)
128+
} else {
129+
require.NoError(t, err)
130+
}
131+
})
132+
}
133+
}
134+
16135
func TestPassthroughRoutesForProviders(t *testing.T) {
17136
t.Parallel()
18137

0 commit comments

Comments
 (0)