Skip to content

Commit 1657034

Browse files
authored
Merge branch 'main' into feat/store-server-info-on-client-session
2 parents a3d5a83 + b33c811 commit 1657034

File tree

4 files changed

+76
-7
lines changed

4 files changed

+76
-7
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ def prepare_token_auth(
205205
headers["Authorization"] = f"Basic {encoded_credentials}"
206206
# Don't include client_secret in body for basic auth
207207
data = {k: v for k, v in data.items() if k != "client_secret"}
208-
elif auth_method == "client_secret_post" and self.client_info.client_secret:
209-
# Include client_secret in request body
208+
elif auth_method == "client_secret_post" and self.client_info.client_id and self.client_info.client_secret:
209+
# Include client_id and client_secret in request body (RFC 6749 §2.3.1)
210+
data["client_id"] = self.client_info.client_id
210211
data["client_secret"] = self.client_info.client_secret
211212
# For auth_method == "none", don't add any client_secret
212213

src/mcp/shared/experimental/tasks/message_queue.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"""
1313

1414
from abc import ABC, abstractmethod
15+
from collections import deque
1516
from dataclasses import dataclass, field
1617
from datetime import datetime, timezone
1718
from typing import Any, Literal
@@ -151,13 +152,13 @@ class InMemoryTaskMessageQueue(TaskMessageQueue):
151152
"""
152153

153154
def __init__(self) -> None:
154-
self._queues: dict[str, list[QueuedMessage]] = {}
155+
self._queues: dict[str, deque[QueuedMessage]] = {}
155156
self._events: dict[str, anyio.Event] = {}
156157

157-
def _get_queue(self, task_id: str) -> list[QueuedMessage]:
158+
def _get_queue(self, task_id: str) -> deque[QueuedMessage]:
158159
"""Get or create the queue for a task."""
159160
if task_id not in self._queues:
160-
self._queues[task_id] = []
161+
self._queues[task_id] = deque()
161162
return self._queues[task_id]
162163

163164
async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
@@ -172,7 +173,7 @@ async def dequeue(self, task_id: str) -> QueuedMessage | None:
172173
queue = self._get_queue(task_id)
173174
if not queue:
174175
return None
175-
return queue.pop(0)
176+
return queue.popleft()
176177

177178
async def peek(self, task_id: str) -> QueuedMessage | None:
178179
"""Return the next message without removing it."""

tests/client/auth/extensions/test_client_credentials.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,72 @@ async def test_exchange_token_client_credentials(self, mock_storage: MockTokenSt
252252
assert "scope=read write" in content
253253
assert "resource=https://api.example.com/v1/mcp" in content
254254

255+
@pytest.mark.anyio
256+
async def test_exchange_token_client_secret_post_includes_client_id(self, mock_storage: MockTokenStorage):
257+
"""Test that client_secret_post includes both client_id and client_secret in body (RFC 6749 §2.3.1)."""
258+
provider = ClientCredentialsOAuthProvider(
259+
server_url="https://api.example.com/v1/mcp",
260+
storage=mock_storage,
261+
client_id="test-client-id",
262+
client_secret="test-client-secret",
263+
token_endpoint_auth_method="client_secret_post",
264+
scopes="read write",
265+
)
266+
await provider._initialize()
267+
provider.context.oauth_metadata = OAuthMetadata(
268+
issuer=AnyHttpUrl("https://api.example.com"),
269+
authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"),
270+
token_endpoint=AnyHttpUrl("https://api.example.com/token"),
271+
)
272+
provider.context.protocol_version = "2025-06-18"
273+
274+
request = await provider._perform_authorization()
275+
276+
content = urllib.parse.unquote_plus(request.content.decode())
277+
assert "grant_type=client_credentials" in content
278+
assert "client_id=test-client-id" in content
279+
assert "client_secret=test-client-secret" in content
280+
# Should NOT have Basic auth header
281+
assert "Authorization" not in request.headers
282+
283+
@pytest.mark.anyio
284+
async def test_exchange_token_client_secret_post_without_client_id(self, mock_storage: MockTokenStorage):
285+
"""Test client_secret_post skips body credentials when client_id is None."""
286+
provider = ClientCredentialsOAuthProvider(
287+
server_url="https://api.example.com/v1/mcp",
288+
storage=mock_storage,
289+
client_id="placeholder",
290+
client_secret="test-client-secret",
291+
token_endpoint_auth_method="client_secret_post",
292+
scopes="read write",
293+
)
294+
await provider._initialize()
295+
provider.context.oauth_metadata = OAuthMetadata(
296+
issuer=AnyHttpUrl("https://api.example.com"),
297+
authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"),
298+
token_endpoint=AnyHttpUrl("https://api.example.com/token"),
299+
)
300+
provider.context.protocol_version = "2025-06-18"
301+
# Override client_info to have client_id=None (edge case)
302+
provider.context.client_info = OAuthClientInformationFull(
303+
redirect_uris=None,
304+
client_id=None,
305+
client_secret="test-client-secret",
306+
grant_types=["client_credentials"],
307+
token_endpoint_auth_method="client_secret_post",
308+
scope="read write",
309+
)
310+
311+
request = await provider._perform_authorization()
312+
313+
content = urllib.parse.unquote_plus(request.content.decode())
314+
assert "grant_type=client_credentials" in content
315+
# Neither client_id nor client_secret should be in body since client_id is None
316+
# (RFC 6749 §2.3.1 requires both for client_secret_post)
317+
assert "client_id=" not in content
318+
assert "client_secret=" not in content
319+
assert "Authorization" not in request.headers
320+
255321
@pytest.mark.anyio
256322
async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage):
257323
"""Test token exchange without scopes."""

tests/experimental/tasks/test_message_queue.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for TaskMessageQueue and InMemoryTaskMessageQueue."""
22

3+
from collections import deque
34
from datetime import datetime, timezone
45

56
import anyio
@@ -270,7 +271,7 @@ async def is_empty_with_injection(tid: str) -> bool:
270271
if call_count == 2 and tid == task_id:
271272
# Before second check, inject a message - this simulates a message
272273
# arriving between event creation and the double-check
273-
queue._queues[task_id] = [QueuedMessage(type="request", message=make_request())]
274+
queue._queues[task_id] = deque([QueuedMessage(type="request", message=make_request())])
274275
return await original_is_empty(tid)
275276

276277
queue.is_empty = is_empty_with_injection # type: ignore[method-assign]

0 commit comments

Comments
 (0)