@@ -8,14 +8,22 @@ import (
88 "testing"
99
1010 "github.com/github/github-mcp-server/pkg/http/headers"
11+ "github.com/github/github-mcp-server/pkg/utils"
1112 "github.com/go-chi/chi/v5"
1213 "github.com/stretchr/testify/assert"
1314 "github.com/stretchr/testify/require"
1415)
1516
17+ var (
18+ defaultAuthorizationServer = "https://github.com/login/oauth"
19+ )
20+
1621func TestNewAuthHandler (t * testing.T ) {
1722 t .Parallel ()
1823
24+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
25+ require .NoError (t , err )
26+
1927 tests := []struct {
2028 name string
2129 cfg * Config
@@ -25,13 +33,13 @@ func TestNewAuthHandler(t *testing.T) {
2533 {
2634 name : "nil config uses defaults" ,
2735 cfg : nil ,
28- expectedAuthServer : DefaultAuthorizationServer ,
36+ expectedAuthServer : defaultAuthorizationServer ,
2937 expectedResourcePath : "" ,
3038 },
3139 {
3240 name : "empty config uses defaults" ,
3341 cfg : & Config {},
34- expectedAuthServer : DefaultAuthorizationServer ,
42+ expectedAuthServer : defaultAuthorizationServer ,
3543 expectedResourcePath : "" ,
3644 },
3745 {
@@ -48,7 +56,7 @@ func TestNewAuthHandler(t *testing.T) {
4856 BaseURL : "https://example.com" ,
4957 ResourcePath : "/mcp" ,
5058 },
51- expectedAuthServer : DefaultAuthorizationServer ,
59+ expectedAuthServer : defaultAuthorizationServer ,
5260 expectedResourcePath : "/mcp" ,
5361 },
5462 }
@@ -57,11 +65,12 @@ func TestNewAuthHandler(t *testing.T) {
5765 t .Run (tc .name , func (t * testing.T ) {
5866 t .Parallel ()
5967
60- handler , err := NewAuthHandler (tc .cfg )
68+ handler , err := NewAuthHandler (t . Context (), tc .cfg , dotcomHost )
6169 require .NoError (t , err )
6270 require .NotNil (t , handler )
6371
6472 assert .Equal (t , tc .expectedAuthServer , handler .cfg .AuthorizationServer )
73+ assert .Equal (t , tc .expectedResourcePath , handler .cfg .ResourcePath )
6574 })
6675 }
6776}
@@ -372,7 +381,7 @@ func TestHandleProtectedResource(t *testing.T) {
372381 authServers , ok := body ["authorization_servers" ].([]any )
373382 require .True (t , ok )
374383 require .Len (t , authServers , 1 )
375- assert .Equal (t , DefaultAuthorizationServer , authServers [0 ])
384+ assert .Equal (t , defaultAuthorizationServer , authServers [0 ])
376385 },
377386 },
378387 {
@@ -451,7 +460,10 @@ func TestHandleProtectedResource(t *testing.T) {
451460 t .Run (tc .name , func (t * testing.T ) {
452461 t .Parallel ()
453462
454- handler , err := NewAuthHandler (tc .cfg )
463+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
464+ require .NoError (t , err )
465+
466+ handler , err := NewAuthHandler (t .Context (), tc .cfg , dotcomHost )
455467 require .NoError (t , err )
456468
457469 router := chi .NewRouter ()
@@ -493,9 +505,12 @@ func TestHandleProtectedResource(t *testing.T) {
493505func TestRegisterRoutes (t * testing.T ) {
494506 t .Parallel ()
495507
496- handler , err := NewAuthHandler (& Config {
508+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
509+ require .NoError (t , err )
510+
511+ handler , err := NewAuthHandler (t .Context (), & Config {
497512 BaseURL : "https://api.example.com" ,
498- })
513+ }, dotcomHost )
499514 require .NoError (t , err )
500515
501516 router := chi .NewRouter ()
@@ -559,9 +574,12 @@ func TestSupportedScopes(t *testing.T) {
559574func TestProtectedResourceResponseFormat (t * testing.T ) {
560575 t .Parallel ()
561576
562- handler , err := NewAuthHandler (& Config {
577+ dotcomHost , err := utils .NewAPIHost ("https://api.github.com" )
578+ require .NoError (t , err )
579+
580+ handler , err := NewAuthHandler (t .Context (), & Config {
563581 BaseURL : "https://api.example.com" ,
564- })
582+ }, dotcomHost )
565583 require .NoError (t , err )
566584
567585 router := chi .NewRouter ()
@@ -598,7 +616,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) {
598616 authServers , ok := response ["authorization_servers" ].([]any )
599617 require .True (t , ok )
600618 assert .Len (t , authServers , 1 )
601- assert .Equal (t , DefaultAuthorizationServer , authServers [0 ])
619+ assert .Equal (t , defaultAuthorizationServer , authServers [0 ])
602620}
603621
604622func TestOAuthProtectedResourcePrefix (t * testing.T ) {
@@ -611,5 +629,70 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) {
611629func TestDefaultAuthorizationServer (t * testing.T ) {
612630 t .Parallel ()
613631
614- assert .Equal (t , "https://github.com/login/oauth" , DefaultAuthorizationServer )
632+ assert .Equal (t , "https://github.com/login/oauth" , defaultAuthorizationServer )
633+ }
634+
635+ func TestAPIHostResolver_AuthorizationServerURL (t * testing.T ) {
636+ t .Parallel ()
637+
638+ tests := []struct {
639+ name string
640+ host string
641+ expectedURL string
642+ expectError bool
643+ errorContains string
644+ }{
645+ {
646+ name : "valid host returns authorization server URL" ,
647+ host : "http://api.github.com" ,
648+ expectedURL : "https://github.com/login/oauth" ,
649+ expectError : false ,
650+ },
651+ {
652+ name : "invalid host returns error" ,
653+ host : "://invalid-url" ,
654+ expectedURL : "" ,
655+ expectError : true ,
656+ errorContains : "could not parse host as URL" ,
657+ },
658+ {
659+ name : "host without scheme returns error" ,
660+ host : "api.github.com" ,
661+ expectedURL : "" ,
662+ expectError : true ,
663+ errorContains : "host must have a scheme" ,
664+ },
665+ {
666+ name : "GHES host returns correct authorization server URL with subdomain isolation" ,
667+ host : "https://api.ghe.example.com" ,
668+ expectedURL : "https://ghe.example.com/login/oauth" ,
669+ expectError : false ,
670+ },
671+ {
672+ name : "GHES host returns correct authorization server URL without subdomain isolation" ,
673+ host : "https://ghe-nosubdomain.example.com/api/v3" ,
674+ expectedURL : "https://ghe-nosubdomain.example.com/login/oauth" ,
675+ expectError : false ,
676+ },
677+ }
678+
679+ for _ , tc := range tests {
680+ t .Run (tc .name , func (t * testing.T ) {
681+ t .Parallel ()
682+
683+ apiHost , err := utils .NewAPIHost (tc .host )
684+ if tc .expectError {
685+ require .Error (t , err )
686+ if tc .errorContains != "" {
687+ assert .Contains (t , err .Error (), tc .errorContains )
688+ }
689+ return
690+ }
691+ require .NoError (t , err )
692+
693+ url , err := apiHost .AuthorizationServerURL (t .Context ())
694+ require .NoError (t , err )
695+ assert .Equal (t , tc .expectedURL , url .String ())
696+ })
697+ }
615698}
0 commit comments