Skip to content

Commit e3520e7

Browse files
committed
fix: always read fresh credentials from disk on every request
1 parent 9da6e41 commit e3520e7

6 files changed

Lines changed: 68 additions & 173 deletions

File tree

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -112,34 +112,20 @@ def create_aws_session(profile: Optional[str] = None) -> boto3.Session:
112112

113113

114114
class SessionHolder:
115-
"""Holds a boto3 session that can be refreshed on credential errors.
115+
"""Provides fresh boto3 sessions so every request reads current credentials from disk.
116116
117-
Wraps a boto3.Session so the signing hook always uses the current session,
118-
and can create a fresh session when the previous request got an auth error.
117+
Instead of caching a session and refreshing reactively on auth errors,
118+
this always creates a new session so account switches and credential
119+
refreshes take effect immediately.
119120
"""
120121

121-
def __init__(self, session: boto3.Session, profile: Optional[str] = None) -> None:
122-
"""Initialize SessionHolder with the given session and optional profile."""
123-
self.session = session
122+
def __init__(self, profile: Optional[str] = None) -> None:
123+
"""Initialize SessionHolder with the given profile."""
124124
self._profile = profile
125-
self._needs_refresh = False
126125

127-
def mark_needs_refresh(self) -> None:
128-
"""Mark that the next request should use a fresh session."""
129-
self._needs_refresh = True
130-
131-
def refresh_if_needed(self) -> None:
132-
"""Create a fresh session if a previous request got an auth error."""
133-
if not self._needs_refresh:
134-
return
135-
logger.info('Refreshing AWS session to pick up new credentials')
136-
try:
137-
self.session = create_aws_session(self._profile)
138-
self._needs_refresh = False
139-
except Exception:
140-
logger.warning(
141-
'Failed to create fresh AWS session, keeping current session', exc_info=True
142-
)
126+
def get_session(self) -> boto3.Session:
127+
"""Create and return a fresh boto3 session reading current credentials from disk."""
128+
return create_aws_session(self._profile)
143129

144130

145131
def create_sigv4_client(
@@ -219,17 +205,13 @@ async def _handle_error_response(session_holder: SessionHolder, response: httpx.
219205
and provide more detailed error information when requests fail.
220206
221207
Args:
222-
session_holder: SessionHolder to refresh on credential errors
208+
session_holder: SessionHolder (unused, kept for hook signature compatibility)
223209
response: The HTTP response object
224210
225211
Raises:
226212
No raises. let the mcp http client handle the errors.
227213
"""
228214
if response.is_error:
229-
# Mark session for refresh so the next request picks up new credentials
230-
if response.status_code in (401, 403):
231-
session_holder.mark_needs_refresh()
232-
233215
# warning only because the SDK logs error
234216
log_level = logging.WARNING
235217
if (
@@ -290,13 +272,10 @@ async def _sign_request_hook(
290272
# Set Content-Length for signing
291273
request.headers['Content-Length'] = str(len(request.content))
292274

293-
# Refresh session if a previous request got an auth error.
294-
# Done here (at signing time) so the new session reads credentials
295-
# that the user may have refreshed since the error occurred.
296-
session_holder.refresh_if_needed()
297-
298-
# Get AWS credentials from the session
299-
credentials = session_holder.session.get_credentials()
275+
# Always read fresh credentials from disk so account switches
276+
# and credential refreshes take effect immediately.
277+
session = session_holder.get_session()
278+
credentials = session.get_credentials()
300279

301280
if credentials is None:
302281
if skip_auth:

mcp_proxy_for_aws/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import os
2121
from fastmcp.client.transports import StreamableHttpTransport
22-
from mcp_proxy_for_aws.sigv4_helper import SessionHolder, create_aws_session, create_sigv4_client
22+
from mcp_proxy_for_aws.sigv4_helper import SessionHolder, create_sigv4_client
2323
from typing import Any, Dict, Optional, Tuple
2424
from urllib.parse import urlparse
2525

@@ -91,10 +91,9 @@ def create_transport_with_sigv4(
9191
Returns:
9292
StreamableHttpTransport instance with SigV4 authentication
9393
"""
94-
# Create AWS session with a holder that can refresh on credential errors
95-
logger.debug('Creating AWS session with profile: %s', profile)
96-
session = create_aws_session(profile)
97-
session_holder = SessionHolder(session, profile)
94+
# Create session holder that reads fresh credentials on every request
95+
logger.debug('Creating session holder with profile: %s', profile)
96+
session_holder = SessionHolder(profile=profile)
9897

9998
def client_factory(
10099
headers: Optional[Dict[str, str]] = None,

tests/unit/test_hooks.py

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,16 @@ def create_mock_session():
5656
def create_mock_session_holder():
5757
"""Helper to create a mocked SessionHolder."""
5858
holder = MagicMock(spec=SessionHolder)
59-
holder.session = create_mock_session()
59+
holder.get_session.return_value = create_mock_session()
6060
return holder
6161

6262

6363
class TestHandleErrorResponse:
6464
"""Test cases for the _handle_error_response function."""
6565

6666
@pytest.mark.asyncio
67-
async def test_handle_error_response_marks_refresh_on_401(self):
68-
"""401 response marks the session holder for refresh."""
67+
async def test_handle_error_response_logs_401(self):
68+
"""401 response is handled without error."""
6969
request = httpx.Request('POST', 'https://example.com/mcp')
7070
response = httpx.Response(
7171
status_code=401,
@@ -77,11 +77,9 @@ async def test_handle_error_response_marks_refresh_on_401(self):
7777

7878
await _handle_error_response(holder, response)
7979

80-
holder.mark_needs_refresh.assert_called_once()
81-
8280
@pytest.mark.asyncio
83-
async def test_handle_error_response_marks_refresh_on_403(self):
84-
"""403 response marks the session holder for refresh."""
81+
async def test_handle_error_response_logs_403(self):
82+
"""403 response is handled without error."""
8583
request = httpx.Request('POST', 'https://example.com/mcp')
8684
response = httpx.Response(
8785
status_code=403,
@@ -93,25 +91,6 @@ async def test_handle_error_response_marks_refresh_on_403(self):
9391

9492
await _handle_error_response(holder, response)
9593

96-
holder.mark_needs_refresh.assert_called_once()
97-
98-
@pytest.mark.asyncio
99-
async def test_handle_error_response_does_not_mark_refresh_on_other_errors(self):
100-
"""Non-auth error codes (400, 404, 500) do not mark refresh."""
101-
for status_code in (400, 404, 500):
102-
request = httpx.Request('POST', 'https://example.com/mcp')
103-
response = httpx.Response(
104-
status_code=status_code,
105-
headers={'content-type': 'text/plain'},
106-
content=b'Error',
107-
request=request,
108-
)
109-
holder = create_mock_session_holder()
110-
111-
await _handle_error_response(holder, response)
112-
113-
holder.mark_needs_refresh.assert_not_called()
114-
11594
@pytest.mark.asyncio
11695
async def test_handle_error_response_with_json_error(self):
11796
"""Test error handling with JSON error response."""
@@ -401,15 +380,15 @@ class TestSignRequestHook:
401380
"""Test cases for sign_request_hook function."""
402381

403382
@pytest.mark.asyncio
404-
async def test_sign_request_hook_calls_refresh_if_needed(self):
405-
"""Signing hook calls refresh_if_needed before signing."""
383+
async def test_sign_request_hook_calls_get_session(self):
384+
"""Signing hook calls get_session to read fresh credentials."""
406385
holder = create_mock_session_holder()
407386
request_body = b'{"test": "data"}'
408387
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
409388

410389
await _sign_request_hook('us-east-1', 'execute-api', holder, False, request)
411390

412-
holder.refresh_if_needed.assert_called_once()
391+
holder.get_session.assert_called_once()
413392

414393
@pytest.mark.asyncio
415394
async def test_sign_request_hook_signs_request(self):
@@ -482,7 +461,7 @@ async def test_sign_request_hook_with_partial_application(self):
482461
async def test_sign_request_hook_skips_signing_when_skip_auth(self):
483462
"""Request is sent unsigned when credentials are unavailable and skip_auth is True."""
484463
holder = create_mock_session_holder()
485-
holder.session.get_credentials.return_value = None
464+
holder.get_session.return_value.get_credentials.return_value = None
486465

487466
request_body = b'{"test": "data"}'
488467
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
@@ -497,7 +476,7 @@ async def test_sign_request_hook_skips_signing_when_skip_auth(self):
497476
async def test_sign_request_hook_raises_when_no_credentials_and_no_skip_auth(self):
498477
"""ValueError is raised when credentials are unavailable and skip_auth is False."""
499478
holder = create_mock_session_holder()
500-
holder.session.get_credentials.return_value = None
479+
holder.get_session.return_value.get_credentials.return_value = None
501480

502481
request_body = b'{"test": "data"}'
503482
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
@@ -506,25 +485,25 @@ async def test_sign_request_hook_raises_when_no_credentials_and_no_skip_auth(sel
506485
await _sign_request_hook('us-east-1', 'execute-api', holder, False, request)
507486

508487
@pytest.mark.asyncio
509-
async def test_sign_request_hook_no_credentials_still_refreshes(self):
510-
"""refresh_if_needed is called even when credentials end up None."""
488+
async def test_sign_request_hook_no_credentials_still_calls_get_session(self):
489+
"""get_session is called even when credentials end up None."""
511490
holder = create_mock_session_holder()
512-
holder.session.get_credentials.return_value = None
491+
holder.get_session.return_value.get_credentials.return_value = None
513492

514493
request_body = b'test'
515494
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
516495

517496
with pytest.raises(ValueError):
518497
await _sign_request_hook('us-east-1', 'execute-api', holder, False, request)
519498

520-
holder.refresh_if_needed.assert_called_once()
499+
holder.get_session.assert_called_once()
521500

522501
@pytest.mark.asyncio
523502
@patch('mcp_proxy_for_aws.sigv4_helper.SigV4HTTPXAuth')
524503
async def test_sign_request_hook_no_credentials_does_not_create_auth(self, mock_auth_class):
525504
"""SigV4HTTPXAuth is never instantiated when credentials are None and skip_auth is True."""
526505
holder = create_mock_session_holder()
527-
holder.session.get_credentials.return_value = None
506+
holder.get_session.return_value.get_credentials.return_value = None
528507

529508
request_body = b'test'
530509
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)

tests/unit/test_server.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,7 @@ def test_validate_service_name_service_parsing(self):
406406
@patch('mcp_proxy_for_aws.sigv4_helper.httpx.AsyncClient')
407407
def test_create_sigv4_client(self, mock_async_client):
408408
"""Test creating SigV4 authenticated client with request hooks."""
409-
mock_session = Mock()
410-
mock_session.get_credentials.return_value = Mock(access_key='test-key')
411-
session_holder = SessionHolder(mock_session, profile='test-profile')
409+
session_holder = SessionHolder(profile='test-profile')
412410

413411
create_sigv4_client(
414412
service='test-service', region='us-west-2', session_holder=session_holder
@@ -423,8 +421,7 @@ def test_create_sigv4_client(self, mock_async_client):
423421

424422
def test_create_sigv4_client_no_credentials(self):
425423
"""Test that credential check happens in sign_request_hook, not during client creation."""
426-
mock_session = Mock()
427-
session_holder = SessionHolder(mock_session)
424+
session_holder = SessionHolder()
428425

429426
client = create_sigv4_client(
430427
service='test-service', region='test-region', session_holder=session_holder

0 commit comments

Comments
 (0)