Skip to content

Commit b1c6f3f

Browse files
committed
fix: validate provider name
1 parent 411bc41 commit b1c6f3f

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

bridge.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net/http"
77
"net/url"
8+
"regexp"
89
"strings"
910
"sync"
1011
"sync/atomic"
@@ -57,14 +58,21 @@ type RequestBridge struct {
5758

5859
var _ http.Handler = &RequestBridge{}
5960

60-
// validateProviders checks that no two providers share the same name.
61+
// validProviderName matches names containing only lowercase alphanumeric characters and hyphens.
62+
var validProviderName = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)
63+
64+
// validateProviders checks that provider names are valid and unique.
6165
func validateProviders(providers []provider.Provider) error {
6266
names := make(map[string]bool, len(providers))
6367
for _, prov := range providers {
64-
if names[prov.Name()] {
65-
return fmt.Errorf("duplicate provider name: %q", prov.Name())
68+
name := prov.Name()
69+
if !validProviderName.MatchString(name) {
70+
return fmt.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name)
71+
}
72+
if names[name] {
73+
return fmt.Errorf("duplicate provider name: %q", name)
6674
}
67-
names[prov.Name()] = true
75+
names[name] = true
6876
}
6977
return nil
7078
}

bridge_test.go

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"github.com/stretchr/testify/require"
1414
)
1515

16-
func TestValidateProviders(t *testing.T) {
16+
func TestValidateProvider_Names(t *testing.T) {
1717
t.Parallel()
1818

1919
tests := []struct {
@@ -48,12 +48,57 @@ func TestValidateProviders(t *testing.T) {
4848
},
4949
},
5050
{
51-
name: "duplicate_name",
51+
name: "name_with_slashes",
52+
providers: []provider.Provider{
53+
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
54+
},
55+
expectErr: "invalid provider name",
56+
},
57+
{
58+
name: "name_with_spaces",
59+
providers: []provider.Provider{
60+
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
61+
},
62+
expectErr: "invalid provider name",
63+
},
64+
{
65+
name: "name_with_uppercase",
66+
providers: []provider.Provider{
67+
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
68+
},
69+
expectErr: "invalid provider name",
70+
},
71+
}
72+
73+
for _, tc := range tests {
74+
t.Run(tc.name, func(t *testing.T) {
75+
t.Parallel()
76+
77+
err := validateProviders(tc.providers)
78+
if tc.expectErr != "" {
79+
require.Error(t, err)
80+
assert.Contains(t, err.Error(), tc.expectErr)
81+
} else {
82+
require.NoError(t, err)
83+
}
84+
})
85+
}
86+
}
87+
88+
func TestValidateProvider_DuplicateNames(t *testing.T) {
89+
t.Parallel()
90+
91+
tests := []struct {
92+
name string
93+
providers []provider.Provider
94+
expectErr string
95+
}{
96+
{
97+
name: "unique_names",
5298
providers: []provider.Provider{
5399
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
54-
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
100+
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
55101
},
56-
expectErr: "duplicate provider name",
57102
},
58103
{
59104
name: "duplicate_base_url_different_names",
@@ -62,6 +107,14 @@ func TestValidateProviders(t *testing.T) {
62107
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
63108
},
64109
},
110+
{
111+
name: "duplicate_name",
112+
providers: []provider.Provider{
113+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
114+
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
115+
},
116+
expectErr: "duplicate provider name",
117+
},
65118
}
66119

67120
for _, tc := range tests {

0 commit comments

Comments
 (0)