Skip to content
Merged
18 changes: 16 additions & 2 deletions pkg/http/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
package oauth

import (
"context"
"fmt"
"net/http"
"strings"

"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/oauthex"
Expand Down Expand Up @@ -43,8 +45,13 @@ type Config struct {
// This is used to construct the OAuth resource URL.
BaseURL string

// APIHost is the GitHub API host resolver that provides OAuth URL.
// If set, this takes precedence over AuthorizationServer.
APIHost utils.APIHostResolver

// AuthorizationServer is the OAuth authorization server URL.
// Defaults to GitHub's OAuth server if not specified.
// This field is ignored if APIHost is set.
AuthorizationServer string

// ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp").
Expand All @@ -64,8 +71,15 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
cfg = &Config{}
}

// Default authorization server to GitHub
if cfg.AuthorizationServer == "" {
// Resolve authorization server from APIHost if provided
if cfg.APIHost != nil {
oauthURL, err := cfg.APIHost.OAuthURL(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to get OAuth URL from API host: %w", err)
}
cfg.AuthorizationServer = oauthURL.String()
} else if cfg.AuthorizationServer == "" {
// Default authorization server to GitHub if not provided
cfg.AuthorizationServer = DefaultAuthorizationServer
}

Expand Down
56 changes: 56 additions & 0 deletions pkg/http/oauth/oauth_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
package oauth

import (
"context"
"crypto/tls"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// mockAPIHostResolver is a test implementation of utils.APIHostResolver
type mockAPIHostResolver struct {
oauthURL string
}

func (m mockAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func (m mockAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func (m mockAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func (m mockAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func (m mockAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) {
return url.Parse(m.oauthURL)
}

// Ensure mockAPIHostResolver implements utils.APIHostResolver
var _ utils.APIHostResolver = mockAPIHostResolver{}

func TestNewAuthHandler(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -51,6 +82,31 @@ func TestNewAuthHandler(t *testing.T) {
expectedAuthServer: DefaultAuthorizationServer,
expectedResourcePath: "/mcp",
},
{
name: "APIHost with HTTPS GHES",
cfg: &Config{
APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"},
},
expectedAuthServer: "https://ghes.example.com/login/oauth",
expectedResourcePath: "",
},
{
name: "APIHost with HTTP GHES",
cfg: &Config{
APIHost: mockAPIHostResolver{oauthURL: "http://ghes.local/login/oauth"},
},
expectedAuthServer: "http://ghes.local/login/oauth",
expectedResourcePath: "",
},
{
name: "APIHost takes precedence over AuthorizationServer",
cfg: &Config{
APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"},
AuthorizationServer: "https://should-be-ignored.example.com/oauth",
},
expectedAuthServer: "https://ghes.example.com/login/oauth",
expectedResourcePath: "",
},
}

for _, tc := range tests {
Expand Down
1 change: 1 addition & 0 deletions pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func RunHTTPServer(cfg ServerConfig) error {
oauthCfg := &oauth.Config{
BaseURL: cfg.BaseURL,
ResourcePath: cfg.ResourcePath,
APIHost: apiHost,
}

serverOptions := []HandlerOption{}
Expand Down
3 changes: 3 additions & 0 deletions pkg/scopes/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
return nil, nil
}
func (t testAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func TestParseScopeHeader(t *testing.T) {
tests := []struct {
Expand Down
24 changes: 24 additions & 0 deletions pkg/utils/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ type APIHostResolver interface {
GraphqlURL(ctx context.Context) (*url.URL, error)
UploadURL(ctx context.Context) (*url.URL, error)
RawURL(ctx context.Context) (*url.URL, error)
OAuthURL(ctx context.Context) (*url.URL, error)
}

type APIHost struct {
restURL *url.URL
gqlURL *url.URL
uploadURL *url.URL
rawURL *url.URL
oauthURL *url.URL
}

var _ APIHostResolver = APIHost{}
Expand Down Expand Up @@ -52,6 +54,10 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) {
return a.rawURL, nil
}

func (a APIHost) OAuthURL(_ context.Context) (*url.URL, error) {
return a.oauthURL, nil
}

func newDotcomHost() (APIHost, error) {
baseRestURL, err := url.Parse("https://api.github.com/")
if err != nil {
Expand All @@ -73,11 +79,17 @@ func newDotcomHost() (APIHost, error) {
return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err)
}

oauthURL, err := url.Parse("https://github.com/login/oauth")
if err != nil {
return APIHost{}, fmt.Errorf("failed to parse dotcom OAuth URL: %w", err)
}

return APIHost{
restURL: baseRestURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
oauthURL: oauthURL,
}, nil
}

Expand Down Expand Up @@ -112,11 +124,17 @@ func newGHECHost(hostname string) (APIHost, error) {
return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err)
}

oauthURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname()))
if err != nil {
return APIHost{}, fmt.Errorf("failed to parse GHEC OAuth URL: %w", err)
}

return APIHost{
restURL: restURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
oauthURL: oauthURL,
}, nil
}

Expand Down Expand Up @@ -164,11 +182,17 @@ func newGHESHost(hostname string) (APIHost, error) {
return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err)
}

oauthURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname()))
if err != nil {
return APIHost{}, fmt.Errorf("failed to parse GHES OAuth URL: %w", err)
}

return APIHost{
restURL: restURL,
gqlURL: gqlURL,
uploadURL: uploadURL,
rawURL: rawURL,
oauthURL: oauthURL,
}, nil
}

Expand Down
140 changes: 140 additions & 0 deletions pkg/utils/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package utils //nolint:revive //TODO: figure out a better name for this package

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestOAuthURL(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
host string
expectedOAuth string
expectError bool
errorSubstring string
}{
{
name: "dotcom (empty host)",
host: "",
expectedOAuth: "https://github.com/login/oauth",
},
{
name: "dotcom (explicit github.com)",
host: "https://github.com",
expectedOAuth: "https://github.com/login/oauth",
},
{
name: "GHEC with HTTPS",
host: "https://acme.ghe.com",
expectedOAuth: "https://acme.ghe.com/login/oauth",
},
{
name: "GHEC with HTTP (should error)",
host: "http://acme.ghe.com",
expectError: true,
errorSubstring: "GHEC URL must be HTTPS",
},
{
name: "GHES with HTTPS",
host: "https://ghes.example.com",
expectedOAuth: "https://ghes.example.com/login/oauth",
},
{
name: "GHES with HTTP",
host: "http://ghes.example.com",
expectedOAuth: "http://ghes.example.com/login/oauth",
},
{
name: "GHES with HTTP and custom port (port stripped - not supported yet)",
host: "http://ghes.local:8080",
expectedOAuth: "http://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment
},
{
name: "GHES with HTTPS and custom port (port stripped - not supported yet)",
host: "https://ghes.local:8443",
expectedOAuth: "https://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment
},
{
name: "host without scheme (should error)",
host: "ghes.example.com",
expectError: true,
errorSubstring: "host must have a scheme",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
apiHost, err := NewAPIHost(tt.host)

if tt.expectError {
require.Error(t, err)
if tt.errorSubstring != "" {
assert.Contains(t, err.Error(), tt.errorSubstring)
}
return
}

require.NoError(t, err)
require.NotNil(t, apiHost)

oauthURL, err := apiHost.OAuthURL(ctx)
require.NoError(t, err)
require.NotNil(t, oauthURL)

assert.Equal(t, tt.expectedOAuth, oauthURL.String())
})
}
}

func TestAPIHost_AllURLsHaveConsistentScheme(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
host string
expectedScheme string
}{
{
name: "GHES with HTTPS",
host: "https://ghes.example.com",
expectedScheme: "https",
},
{
name: "GHES with HTTP",
host: "http://ghes.example.com",
expectedScheme: "http",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
apiHost, err := NewAPIHost(tt.host)
require.NoError(t, err)

restURL, err := apiHost.BaseRESTURL(ctx)
require.NoError(t, err)
assert.Equal(t, tt.expectedScheme, restURL.Scheme, "REST URL scheme should match")

gqlURL, err := apiHost.GraphqlURL(ctx)
require.NoError(t, err)
assert.Equal(t, tt.expectedScheme, gqlURL.Scheme, "GraphQL URL scheme should match")

uploadURL, err := apiHost.UploadURL(ctx)
require.NoError(t, err)
assert.Equal(t, tt.expectedScheme, uploadURL.Scheme, "Upload URL scheme should match")

rawURL, err := apiHost.RawURL(ctx)
require.NoError(t, err)
assert.Equal(t, tt.expectedScheme, rawURL.Scheme, "Raw URL scheme should match")

oauthURL, err := apiHost.OAuthURL(ctx)
require.NoError(t, err)
assert.Equal(t, tt.expectedScheme, oauthURL.Scheme, "OAuth URL scheme should match")
})
}
}
Loading