Skip to content

Commit 4bcd04d

Browse files
Merge branch 'main' into dependabot/github_actions/github/codeql-action-4.35.2
2 parents 43ae5a1 + 4f1b3e5 commit 4bcd04d

2 files changed

Lines changed: 75 additions & 4 deletions

File tree

mcp/streamable.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/modelcontextprotocol/go-sdk/internal/util"
3838
"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
3939
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
40+
"golang.org/x/oauth2"
4041
)
4142

4243
// A StreamableHTTPHandler is an http.Handler that serves streamable MCP
@@ -1803,6 +1804,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
18031804
}
18041805
req.Header.Set("Content-Type", "application/json")
18051806
req.Header.Set("Accept", "application/json, text/event-stream")
1807+
18061808
if err := c.setMCPHeaders(req); err != nil {
18071809
// Failure to set headers means that the request was not sent.
18081810
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
@@ -1934,9 +1936,20 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error {
19341936
if ts != nil {
19351937
token, err := ts.Token()
19361938
if err != nil {
1937-
return err
1938-
}
1939-
if token != nil {
1939+
// If the error is an invalid_grant oauth2.RetrieveError it indicates
1940+
// that the token source doesn't have valid authorization for the token
1941+
// endpoint, per RFC 6749 section 5.2. For example, the refresh token
1942+
// may be expired or invalid.
1943+
//
1944+
// In that case, ignore the error, skip setting the Authorization
1945+
// header, and proceed with the request. Callers that support
1946+
// authorization flows get a 401/403 response and trigger the
1947+
// Authorize() flow to refresh their token.
1948+
var retrieveErr *oauth2.RetrieveError
1949+
if !errors.As(err, &retrieveErr) || retrieveErr.ErrorCode != "invalid_grant" {
1950+
return err
1951+
}
1952+
} else if token != nil {
19401953
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
19411954
}
19421955
}
@@ -2236,8 +2249,10 @@ func (c *streamableClientConn) Close() error {
22362249
} else {
22372250
if err := c.setMCPHeaders(req); err != nil {
22382251
c.closeErr = err
2239-
} else if _, err := c.client.Do(req); err != nil {
2252+
} else if resp, err := c.client.Do(req); err != nil {
22402253
c.closeErr = err
2254+
} else {
2255+
resp.Body.Close()
22412256
}
22422257
}
22432258
}

mcp/streamable_client_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package mcp
66

77
import (
88
"context"
9+
"errors"
910
"fmt"
1011
"io"
1112
"net/http"
@@ -1156,3 +1157,58 @@ func TestTokenInfo(t *testing.T) {
11561157
t.Errorf("got %q, want %q", g, w)
11571158
}
11581159
}
1160+
1161+
// errTestAuthorizeFailed is a sentinel error returned by
1162+
// retrieveErrorOAuthHandler.Authorize().
1163+
var errTestAuthorizeFailed = errors.New("authorize intentionally failed for test")
1164+
1165+
// retrieveErrorOAuthHandler is a mock OAuthHandler that always returns
1166+
// an oauth2.RetrieveError from its TokenSource's Token() method.
1167+
type retrieveErrorOAuthHandler struct{}
1168+
1169+
func (h *retrieveErrorOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
1170+
return h, nil
1171+
}
1172+
1173+
func (h *retrieveErrorOAuthHandler) Token() (*oauth2.Token, error) {
1174+
return nil, &oauth2.RetrieveError{
1175+
Response: &http.Response{StatusCode: http.StatusBadRequest},
1176+
Body: []byte("test retrieve error"),
1177+
ErrorCode: "invalid_grant",
1178+
}
1179+
}
1180+
1181+
func (h *retrieveErrorOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error {
1182+
return errTestAuthorizeFailed
1183+
}
1184+
1185+
// TestStreamableClientOAuth_RetrieveError verifies that an invalid_grant RetrieveError
1186+
// from the OAuth token source correctly skips sending Authorization header and relies on
1187+
// the server's 401 response to trigger the Authorize fallback flow.
1188+
func TestStreamableClientOAuth_RetrieveError(t *testing.T) {
1189+
ctx := context.Background()
1190+
oauthHandler := &retrieveErrorOAuthHandler{}
1191+
1192+
// Mock MCP server returns 401 Unauthorized to simulate a server rejecting
1193+
// the request that omitted the Authorization header.
1194+
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1195+
w.WriteHeader(http.StatusUnauthorized)
1196+
}))
1197+
t.Cleanup(httpServer.Close)
1198+
1199+
transport := &StreamableClientTransport{
1200+
Endpoint: httpServer.URL,
1201+
OAuthHandler: oauthHandler,
1202+
}
1203+
client := NewClient(testImpl, nil)
1204+
1205+
// Attempt to connect. The Connect call will trigger the initialization request,
1206+
// which will fail to retrieve the token and proceed without auth header, receive 401,
1207+
// and invoke Authorize().
1208+
_, err := client.Connect(ctx, transport, nil)
1209+
1210+
// Expect the connection to fail with the sentinel error, not the RetrieveError.
1211+
if !errors.Is(err, errTestAuthorizeFailed) {
1212+
t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed)
1213+
}
1214+
}

0 commit comments

Comments
 (0)