Skip to content

Commit d0b1986

Browse files
committed
Generate authorization server urls during API resolution.
Then we can pass this into the OAuth implementation.
1 parent 16ff74a commit d0b1986

File tree

5 files changed

+160
-34
lines changed

5 files changed

+160
-34
lines changed

pkg/http/oauth/oauth.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
package oauth
44

55
import (
6+
"context"
67
"fmt"
78
"net/http"
89
"strings"
910

1011
"github.com/github/github-mcp-server/pkg/http/headers"
12+
"github.com/github/github-mcp-server/pkg/utils"
1113
"github.com/go-chi/chi/v5"
1214
"github.com/modelcontextprotocol/go-sdk/auth"
1315
"github.com/modelcontextprotocol/go-sdk/oauthex"
@@ -16,9 +18,6 @@ import (
1618
const (
1719
// OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata.
1820
OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource"
19-
20-
// DefaultAuthorizationServer is GitHub's OAuth authorization server.
21-
DefaultAuthorizationServer = "https://github.com/login/oauth"
2221
)
2322

2423
// SupportedScopes lists all OAuth scopes that may be required by MCP tools.
@@ -59,14 +58,19 @@ type AuthHandler struct {
5958
}
6059

6160
// NewAuthHandler creates a new OAuth auth handler.
62-
func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
61+
func NewAuthHandler(ctx context.Context, cfg *Config, apiHost utils.APIHostResolver) (*AuthHandler, error) {
6362
if cfg == nil {
6463
cfg = &Config{}
6564
}
6665

6766
// Default authorization server to GitHub
6867
if cfg.AuthorizationServer == "" {
69-
cfg.AuthorizationServer = DefaultAuthorizationServer
68+
url, err := apiHost.AuthorizationServerURL(ctx)
69+
if err != nil {
70+
return nil, fmt.Errorf("failed to get authorization server URL from API host: %w", err)
71+
}
72+
73+
cfg.AuthorizationServer = url.String()
7074
}
7175

7276
return &AuthHandler{

pkg/http/oauth/oauth_test.go

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,22 @@ import (
88
"testing"
99

1010
"github.com/github/github-mcp-server/pkg/http/headers"
11+
"github.com/github/github-mcp-server/pkg/utils"
1112
"github.com/go-chi/chi/v5"
1213
"github.com/stretchr/testify/assert"
1314
"github.com/stretchr/testify/require"
1415
)
1516

17+
var (
18+
defaultAuthorizationServer = "https://github.com/login/oauth"
19+
)
20+
1621
func TestNewAuthHandler(t *testing.T) {
1722
t.Parallel()
1823

24+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
25+
require.NoError(t, err)
26+
1927
tests := []struct {
2028
name string
2129
cfg *Config
@@ -25,13 +33,13 @@ func TestNewAuthHandler(t *testing.T) {
2533
{
2634
name: "nil config uses defaults",
2735
cfg: nil,
28-
expectedAuthServer: DefaultAuthorizationServer,
36+
expectedAuthServer: defaultAuthorizationServer,
2937
expectedResourcePath: "",
3038
},
3139
{
3240
name: "empty config uses defaults",
3341
cfg: &Config{},
34-
expectedAuthServer: DefaultAuthorizationServer,
42+
expectedAuthServer: defaultAuthorizationServer,
3543
expectedResourcePath: "",
3644
},
3745
{
@@ -48,7 +56,7 @@ func TestNewAuthHandler(t *testing.T) {
4856
BaseURL: "https://example.com",
4957
ResourcePath: "/mcp",
5058
},
51-
expectedAuthServer: DefaultAuthorizationServer,
59+
expectedAuthServer: defaultAuthorizationServer,
5260
expectedResourcePath: "/mcp",
5361
},
5462
}
@@ -57,11 +65,12 @@ func TestNewAuthHandler(t *testing.T) {
5765
t.Run(tc.name, func(t *testing.T) {
5866
t.Parallel()
5967

60-
handler, err := NewAuthHandler(tc.cfg)
68+
handler, err := NewAuthHandler(t.Context(), tc.cfg, dotcomHost)
6169
require.NoError(t, err)
6270
require.NotNil(t, handler)
6371

6472
assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer)
73+
assert.Equal(t, tc.expectedResourcePath, handler.cfg.ResourcePath)
6574
})
6675
}
6776
}
@@ -372,7 +381,7 @@ func TestHandleProtectedResource(t *testing.T) {
372381
authServers, ok := body["authorization_servers"].([]any)
373382
require.True(t, ok)
374383
require.Len(t, authServers, 1)
375-
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
384+
assert.Equal(t, defaultAuthorizationServer, authServers[0])
376385
},
377386
},
378387
{
@@ -451,7 +460,10 @@ func TestHandleProtectedResource(t *testing.T) {
451460
t.Run(tc.name, func(t *testing.T) {
452461
t.Parallel()
453462

454-
handler, err := NewAuthHandler(tc.cfg)
463+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
464+
require.NoError(t, err)
465+
466+
handler, err := NewAuthHandler(t.Context(), tc.cfg, dotcomHost)
455467
require.NoError(t, err)
456468

457469
router := chi.NewRouter()
@@ -493,9 +505,12 @@ func TestHandleProtectedResource(t *testing.T) {
493505
func TestRegisterRoutes(t *testing.T) {
494506
t.Parallel()
495507

496-
handler, err := NewAuthHandler(&Config{
508+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
509+
require.NoError(t, err)
510+
511+
handler, err := NewAuthHandler(t.Context(), &Config{
497512
BaseURL: "https://api.example.com",
498-
})
513+
}, dotcomHost)
499514
require.NoError(t, err)
500515

501516
router := chi.NewRouter()
@@ -559,9 +574,12 @@ func TestSupportedScopes(t *testing.T) {
559574
func TestProtectedResourceResponseFormat(t *testing.T) {
560575
t.Parallel()
561576

562-
handler, err := NewAuthHandler(&Config{
577+
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
578+
require.NoError(t, err)
579+
580+
handler, err := NewAuthHandler(t.Context(), &Config{
563581
BaseURL: "https://api.example.com",
564-
})
582+
}, dotcomHost)
565583
require.NoError(t, err)
566584

567585
router := chi.NewRouter()
@@ -598,7 +616,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) {
598616
authServers, ok := response["authorization_servers"].([]any)
599617
require.True(t, ok)
600618
assert.Len(t, authServers, 1)
601-
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
619+
assert.Equal(t, defaultAuthorizationServer, authServers[0])
602620
}
603621

604622
func TestOAuthProtectedResourcePrefix(t *testing.T) {
@@ -611,5 +629,70 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) {
611629
func TestDefaultAuthorizationServer(t *testing.T) {
612630
t.Parallel()
613631

614-
assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer)
632+
assert.Equal(t, "https://github.com/login/oauth", defaultAuthorizationServer)
633+
}
634+
635+
func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
636+
t.Parallel()
637+
638+
tests := []struct {
639+
name string
640+
host string
641+
expectedURL string
642+
expectError bool
643+
errorContains string
644+
}{
645+
{
646+
name: "valid host returns authorization server URL",
647+
host: "http://api.github.com",
648+
expectedURL: "https://github.com/login/oauth",
649+
expectError: false,
650+
},
651+
{
652+
name: "invalid host returns error",
653+
host: "://invalid-url",
654+
expectedURL: "",
655+
expectError: true,
656+
errorContains: "could not parse host as URL",
657+
},
658+
{
659+
name: "host without scheme returns error",
660+
host: "api.github.com",
661+
expectedURL: "",
662+
expectError: true,
663+
errorContains: "host must have a scheme",
664+
},
665+
{
666+
name: "GHES host returns correct authorization server URL with subdomain isolation",
667+
host: "https://api.ghe.example.com",
668+
expectedURL: "https://ghe.example.com/login/oauth",
669+
expectError: false,
670+
},
671+
{
672+
name: "GHES host returns correct authorization server URL without subdomain isolation",
673+
host: "https://ghe-nosubdomain.example.com/api/v3",
674+
expectedURL: "https://ghe-nosubdomain.example.com/login/oauth",
675+
expectError: false,
676+
},
677+
}
678+
679+
for _, tc := range tests {
680+
t.Run(tc.name, func(t *testing.T) {
681+
t.Parallel()
682+
683+
apiHost, err := utils.NewAPIHost(tc.host)
684+
if tc.expectError {
685+
require.Error(t, err)
686+
if tc.errorContains != "" {
687+
assert.Contains(t, err.Error(), tc.errorContains)
688+
}
689+
return
690+
}
691+
require.NoError(t, err)
692+
693+
url, err := apiHost.AuthorizationServerURL(t.Context())
694+
require.NoError(t, err)
695+
assert.Equal(t, tc.expectedURL, url.String())
696+
})
697+
}
615698
}

pkg/http/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func RunHTTPServer(cfg ServerConfig) error {
136136

137137
r := chi.NewRouter()
138138
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...)
139-
oauthHandler, err := oauth.NewAuthHandler(oauthCfg)
139+
oauthHandler, err := oauth.NewAuthHandler(ctx, oauthCfg, apiHost)
140140
if err != nil {
141141
return fmt.Errorf("failed to create OAuth handler: %w", err)
142142
}

pkg/scopes/fetcher_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
2828
func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
2929
return nil, nil
3030
}
31+
func (t testAPIHostResolver) AuthorizationServerURL(_ context.Context) (*url.URL, error) {
32+
return nil, nil
33+
}
3134

3235
func TestParseScopeHeader(t *testing.T) {
3336
tests := []struct {

pkg/utils/api.go

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ type APIHostResolver interface {
1414
GraphqlURL(ctx context.Context) (*url.URL, error)
1515
UploadURL(ctx context.Context) (*url.URL, error)
1616
RawURL(ctx context.Context) (*url.URL, error)
17+
AuthorizationServerURL(ctx context.Context) (*url.URL, error)
1718
}
1819

1920
type APIHost struct {
20-
restURL *url.URL
21-
gqlURL *url.URL
22-
uploadURL *url.URL
23-
rawURL *url.URL
21+
restURL *url.URL
22+
gqlURL *url.URL
23+
uploadURL *url.URL
24+
rawURL *url.URL
25+
authorizationServerURL *url.URL
2426
}
2527

2628
var _ APIHostResolver = APIHost{}
@@ -52,6 +54,10 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) {
5254
return a.rawURL, nil
5355
}
5456

57+
func (a APIHost) AuthorizationServerURL(_ context.Context) (*url.URL, error) {
58+
return a.authorizationServerURL, nil
59+
}
60+
5561
func newDotcomHost() (APIHost, error) {
5662
baseRestURL, err := url.Parse("https://api.github.com/")
5763
if err != nil {
@@ -73,11 +79,18 @@ func newDotcomHost() (APIHost, error) {
7379
return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err)
7480
}
7581

82+
// The authorization server for GitHub.com is at github.com/login/oauth, not api.github.com
83+
authorizationServerURL, err := url.Parse("https://github.com/login/oauth")
84+
if err != nil {
85+
return APIHost{}, fmt.Errorf("failed to parse dotcom Authorization Server URL: %w", err)
86+
}
87+
7688
return APIHost{
77-
restURL: baseRestURL,
78-
gqlURL: gqlURL,
79-
uploadURL: uploadURL,
80-
rawURL: rawURL,
89+
restURL: baseRestURL,
90+
gqlURL: gqlURL,
91+
uploadURL: uploadURL,
92+
rawURL: rawURL,
93+
authorizationServerURL: authorizationServerURL,
8194
}, nil
8295
}
8396

@@ -112,11 +125,19 @@ func newGHECHost(hostname string) (APIHost, error) {
112125
return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err)
113126
}
114127

128+
// The authorization server for GHEC is still on the root domain, not the api subdomain
129+
rootHost := strings.TrimPrefix(u.Hostname(), "api.")
130+
authorizationServerURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", rootHost))
131+
if err != nil {
132+
return APIHost{}, fmt.Errorf("failed to parse GHEC Authorization Server URL: %w", err)
133+
}
134+
115135
return APIHost{
116-
restURL: restURL,
117-
gqlURL: gqlURL,
118-
uploadURL: uploadURL,
119-
rawURL: rawURL,
136+
restURL: restURL,
137+
gqlURL: gqlURL,
138+
uploadURL: uploadURL,
139+
rawURL: rawURL,
140+
authorizationServerURL: authorizationServerURL,
120141
}, nil
121142
}
122143

@@ -164,11 +185,26 @@ func newGHESHost(hostname string) (APIHost, error) {
164185
return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err)
165186
}
166187

188+
// If subdomain isolation is enabled, the hostname will be api.hostname, but the authorization server is still on the root domain at hostname/login/oauth
189+
// If subdomain isolation is not enabled, the hostname is still hostname and the authorization server is at hostname/login/oauth
190+
var rootHost string
191+
if hasSubdomainIsolation {
192+
rootHost = strings.TrimPrefix(u.Hostname(), "api.")
193+
} else {
194+
rootHost = u.Hostname()
195+
}
196+
authorizationServerURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, rootHost))
197+
198+
if err != nil {
199+
return APIHost{}, fmt.Errorf("failed to parse GHES Authorization Server URL: %w", err)
200+
}
201+
167202
return APIHost{
168-
restURL: restURL,
169-
gqlURL: gqlURL,
170-
uploadURL: uploadURL,
171-
rawURL: rawURL,
203+
restURL: restURL,
204+
gqlURL: gqlURL,
205+
uploadURL: uploadURL,
206+
rawURL: rawURL,
207+
authorizationServerURL: authorizationServerURL,
172208
}, nil
173209
}
174210

0 commit comments

Comments
 (0)