-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_token_querystring.py
More file actions
155 lines (120 loc) · 5.55 KB
/
test_token_querystring.py
File metadata and controls
155 lines (120 loc) · 5.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Tests for TokenQuerystringMiddleware auth conflict handling.
These tests use a minimal ASGI setup to test the middleware in isolation,
avoiding the full FastAPI/MCP stack. Requires Python 3.13+ and project deps.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock
@pytest.fixture
def middleware():
    """Build a TokenQuerystringMiddleware wrapped around an AsyncMock inner app.

    The middleware class is imported inside the fixture body, so the import only
    runs when a test actually requests this fixture.

    Returns:
        A ``(middleware_instance, inner_app_mock)`` pair. Tests unpack both so
        they can drive the middleware and then inspect how, or whether, the
        wrapped app was invoked.
    """
    from openapi_mcp_sdk.main import TokenQuerystringMiddleware

    downstream = AsyncMock()
    wrapped = TokenQuerystringMiddleware(downstream)
    return wrapped, downstream
@pytest.fixture
def base_scope():
    """Provide a minimal ASGI HTTP scope for a ``GET /test`` request.

    The fixture uses the default (function) scope, so every test receives its own
    dict. Tests may therefore overwrite ``query_string`` and ``headers`` without
    affecting one another.
    """
    scope = dict(
        type="http",
        method="GET",
        path="/test",
    )
    # Start with no query string and no headers; individual tests fill these in.
    scope["query_string"] = b""
    scope["headers"] = []
    return scope
def _header_bytes(
    headers: list[tuple[str, str]],
    *,
    lowercase_names: bool = True,
) -> list[tuple[bytes, bytes]]:
    """Convert ``(name, value)`` text pairs into ASGI-style raw header bytes.

    The ASGI HTTP spec asks for lowercased header names, and HTTP header bytes
    are conventionally latin-1 (Starlette, for example, decodes them that way).
    Mirroring both here keeps the fake scopes shaped like what a real server
    would deliver. It also matches the lowercase ``"authorization"`` lookups
    that the tests perform.

    Args:
        headers: Header pairs as text, e.g. ``[("Authorization", "Bearer x")]``.
        lowercase_names: Lowercase header names (the ASGI convention). Pass
            ``False`` to keep names exactly as given, e.g. to probe how the
            middleware copes with mixed-case names.

    Returns:
        A list of ``(name_bytes, value_bytes)`` tuples in input order.

    Raises:
        UnicodeEncodeError: If a name or value contains characters outside
            latin-1.
    """
    return [
        (
            (name.lower() if lowercase_names else name).encode("latin-1"),
            value.encode("latin-1"),
        )
        for name, value in headers
    ]
class TestTokenQuerystringMiddleware:
    """Tests for TokenQuerystringMiddleware behavior.

    Each test awaits the middleware directly as an ASGI callable, using
    ``AsyncMock`` ``receive``/``send`` channels. It then inspects either the
    scope forwarded to the mocked inner app or the messages emitted on ``send``.
    """

    @staticmethod
    def _decode_headers(scope: dict) -> dict[str, str]:
        """Return ``scope["headers"]`` as a ``{name: value}`` dict of text.

        Names are kept exactly as they appear in the scope (no case folding).
        A lookup of ``"authorization"`` therefore also checks that the header
        uses the lowercase form. If a name repeats, the last value wins.
        """
        return {
            name.decode("latin-1"): value.decode("latin-1")
            for name, value in scope["headers"]
        }

    @pytest.mark.asyncio
    async def test_query_token_only_sets_auth_header(self, middleware, base_scope):
        """When only ?token= is provided, it should be promoted to Authorization header."""
        mw, inner_app = middleware
        base_scope["query_string"] = b"token=my-token"
        receive = AsyncMock()
        send = AsyncMock()

        await mw(base_scope, receive, send)

        # Should call inner app, not short-circuit with its own response.
        inner_app.assert_called_once()
        send.assert_not_called()
        call_scope = inner_app.call_args[0][0]
        headers = self._decode_headers(call_scope)
        assert headers.get("authorization") == "Bearer my-token"
        # The token must be stripped from the query string seen downstream.
        assert call_scope["query_string"] == b""

    @pytest.mark.asyncio
    async def test_header_only_passes_through(self, middleware, base_scope):
        """When only Authorization header is provided, middleware should pass it through unchanged."""
        mw, inner_app = middleware
        base_scope["headers"] = _header_bytes([("Authorization", "Bearer header-token")])
        receive = AsyncMock()
        send = AsyncMock()

        await mw(base_scope, receive, send)

        inner_app.assert_called_once()
        send.assert_not_called()
        call_scope = inner_app.call_args[0][0]
        headers = self._decode_headers(call_scope)
        assert headers.get("authorization") == "Bearer header-token"

    @pytest.mark.asyncio
    async def test_conflicting_auth_returns_400(self, middleware, base_scope):
        """When both ?token= and Authorization header are present, return 400."""
        mw, inner_app = middleware
        base_scope["query_string"] = b"token=query-token"
        base_scope["headers"] = _header_bytes([("Authorization", "Bearer header-token")])
        receive = AsyncMock()
        send = AsyncMock()

        await mw(base_scope, receive, send)

        # Should NOT call inner app — middleware short-circuits
        inner_app.assert_not_called()

        # Verify a complete 400 response (start message + body message) was sent.
        calls = send.call_args_list
        assert len(calls) == 2
        start_msg = calls[0][0][0]
        assert start_msg["type"] == "http.response.start"
        assert start_msg["status"] == 400
        body_msg = calls[1][0][0]
        assert body_msg["type"] == "http.response.body"
        body = json.loads(body_msg["body"])
        assert body["error"] == "conflicting_auth"
        assert "Authorization header" in body["message"]
        assert "?token=" in body["message"]

    @pytest.mark.asyncio
    async def test_conflicting_auth_response_is_json(self, middleware, base_scope):
        """The 400 response for conflicting auth should have application/json content-type."""
        mw, inner_app = middleware
        base_scope["query_string"] = b"token=a"
        base_scope["headers"] = _header_bytes([("Authorization", "Bearer b")])
        receive = AsyncMock()
        send = AsyncMock()

        await mw(base_scope, receive, send)

        inner_app.assert_not_called()
        assert send.call_args_list, "middleware sent no response"
        start_msg = send.call_args_list[0][0][0]
        content_type = dict(start_msg["headers"]).get(b"content-type", b"")
        # Compare the media type only, so an added "; charset=..." parameter
        # does not make the test fail.
        assert content_type.split(b";")[0].strip() == b"application/json"

    @pytest.mark.asyncio
    async def test_no_auth_passes_through(self, middleware, base_scope):
        """When no auth is provided, middleware should pass through without error."""
        mw, inner_app = middleware
        receive = AsyncMock()
        send = AsyncMock()

        await mw(base_scope, receive, send)

        inner_app.assert_called_once()
        send.assert_not_called()
        call_scope = inner_app.call_args[0][0]
        # With nothing to promote, no Authorization header may be invented.
        assert "authorization" not in self._decode_headers(call_scope)

    @pytest.mark.asyncio
    async def test_non_http_scope_passes_through(self, middleware):
        """Non-HTTP scopes (e.g., websocket) should pass through unchanged."""
        mw, inner_app = middleware
        scope = {"type": "websocket"}
        receive = AsyncMock()
        send = AsyncMock()

        await mw(scope, receive, send)

        # scope, receive and send must be forwarded as-is (compared by equality).
        inner_app.assert_called_once_with(scope, receive, send)

    @pytest.mark.asyncio
    async def test_query_token_with_other_params_preserved(self, middleware, base_scope):
        """When ?token= is present with other params, only token should be removed."""
        mw, inner_app = middleware
        base_scope["query_string"] = b"token=my-token&foo=bar&baz=qux"
        receive = AsyncMock()
        send = AsyncMock()

        await mw(base_scope, receive, send)

        # Assert the call first: if the inner app was never reached, call_args is
        # None and indexing it would raise an unhelpful TypeError.
        inner_app.assert_called_once()
        call_scope = inner_app.call_args[0][0]
        headers = self._decode_headers(call_scope)
        assert headers.get("authorization") == "Bearer my-token"
        qs = call_scope["query_string"].decode()
        assert "token=" not in qs
        assert "foo=bar" in qs
        assert "baz=qux" in qs