@@ -6,6 +6,7 @@ package mcp
66
77import (
88 "context"
9+ "errors"
910 "fmt"
1011 "io"
1112 "net/http"
@@ -1156,3 +1157,58 @@ func TestTokenInfo(t *testing.T) {
11561157 t .Errorf ("got %q, want %q" , g , w )
11571158 }
11581159}
1160+
1161+ // errTestAuthorizeFailed is a sentinel error returned by
1162+ // retrieveErrorOAuthHandler.Authorize().
1163+ var errTestAuthorizeFailed = errors .New ("authorize intentionally failed for test" )
1164+
1165+ // retrieveErrorOAuthHandler is a mock OAuthHandler that always returns
1166+ // an oauth2.RetrieveError from its TokenSource's Token() method.
1167+ type retrieveErrorOAuthHandler struct {}
1168+
1169+ func (h * retrieveErrorOAuthHandler ) TokenSource (ctx context.Context ) (oauth2.TokenSource , error ) {
1170+ return h , nil
1171+ }
1172+
1173+ func (h * retrieveErrorOAuthHandler ) Token () (* oauth2.Token , error ) {
1174+ return nil , & oauth2.RetrieveError {
1175+ Response : & http.Response {StatusCode : http .StatusBadRequest },
1176+ Body : []byte ("test retrieve error" ),
1177+ ErrorCode : "invalid_grant" ,
1178+ }
1179+ }
1180+
1181+ func (h * retrieveErrorOAuthHandler ) Authorize (ctx context.Context , req * http.Request , resp * http.Response ) error {
1182+ return errTestAuthorizeFailed
1183+ }
1184+
1185+ // TestStreamableClientOAuth_RetrieveError verifies that an invalid_grant RetrieveError
1186+ // from the OAuth token source correctly skips sending Authorization header and relies on
1187+ // the server's 401 response to trigger the Authorize fallback flow.
1188+ func TestStreamableClientOAuth_RetrieveError (t * testing.T ) {
1189+ ctx := context .Background ()
1190+ oauthHandler := & retrieveErrorOAuthHandler {}
1191+
1192+ // Mock MCP server returns 401 Unauthorized to simulate a server rejecting
1193+ // the request that omitted the Authorization header.
1194+ httpServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1195+ w .WriteHeader (http .StatusUnauthorized )
1196+ }))
1197+ t .Cleanup (httpServer .Close )
1198+
1199+ transport := & StreamableClientTransport {
1200+ Endpoint : httpServer .URL ,
1201+ OAuthHandler : oauthHandler ,
1202+ }
1203+ client := NewClient (testImpl , nil )
1204+
1205+ // Attempt to connect. The Connect call will trigger the initialization request,
1206+ // which will fail to retrieve the token and proceed without auth header, receive 401,
1207+ // and invoke Authorize().
1208+ _ , err := client .Connect (ctx , transport , nil )
1209+
1210+ // Expect the connection to fail with the sentinel error, not the RetrieveError.
1211+ if ! errors .Is (err , errTestAuthorizeFailed ) {
1212+ t .Fatalf ("client.Connect() error = %v, want %v" , err , errTestAuthorizeFailed )
1213+ }
1214+ }
0 commit comments