From ae4855f6603c6a779af5987f3443c15e30b54b67 Mon Sep 17 00:00:00 2001 From: nix-tkobayashi Date: Fri, 8 May 2026 19:29:09 +0900 Subject: [PATCH 1/4] fix(auth): use asyncio.to_thread for credential resolution in signing hook `_sign_request_hook` is an async httpx event hook, but it calls `session_holder.session.get_credentials()` and `session_holder.refresh_if_needed()` synchronously. For profiles that use assumed IAM roles (chained credentials), `get_credentials()` triggers a blocking STS `AssumeRole` call that stalls the event loop and causes connection timeouts (~60 s). Add `SessionHolder.async_get_credentials()` and `SessionHolder.async_refresh_if_needed()` which delegate to `asyncio.to_thread`, keeping the event loop responsive. Update `_sign_request_hook` to call the async variants. Fixes #176 Co-Authored-By: Claude Opus 4.6 --- mcp_proxy_for_aws/sigv4_helper.py | 34 ++++++++-- tests/unit/test_hooks.py | 13 +++- tests/unit/test_sigv4_helper.py | 108 +++++++++++++++++++++++++++++- 3 files changed, 147 insertions(+), 8 deletions(-) diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index 24f8169..d6c01cf 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -14,6 +14,7 @@ """SigV4 Helper for AWS request signing functionality.""" +import asyncio import boto3 import httpx import json @@ -141,6 +142,27 @@ def refresh_if_needed(self) -> None: 'Failed to create fresh AWS session, keeping current session', exc_info=True ) + async def async_refresh_if_needed(self) -> None: + """Async version of refresh_if_needed that offloads blocking I/O to a thread. + + For profiles that use assumed IAM roles (chained credentials), session + creation involves a blocking STS AssumeRole call. Running that on the + event loop blocks all other coroutines and causes timeouts. + """ + if not self._needs_refresh: + return + await asyncio.to_thread(self.refresh_if_needed) + + async def async_get_credentials(self): + """Resolve credentials without blocking the event loop. + + ``boto3.Session.get_credentials()`` may trigger a synchronous STS + ``AssumeRole`` call when the profile uses chained credentials. + Offloading the call to a worker thread prevents the event loop from + stalling. + """ + return await asyncio.to_thread(self.session.get_credentials) + def create_sigv4_client( service: str, @@ -289,10 +311,14 @@ async def _sign_request_hook( # Refresh session if a previous request got an auth error. # Done here (at signing time) so the new session reads credentials # that the user may have refreshed since the error occurred. - session_holder.refresh_if_needed() - - # Get AWS credentials from the session - credentials = session_holder.session.get_credentials() + # Use async variant to avoid blocking the event loop when the profile + # requires an STS AssumeRole call (chained credentials). + await session_holder.async_refresh_if_needed() + + # Get AWS credentials from the session. + # Offloaded to a thread because get_credentials() may trigger a + # synchronous STS call for assumed-role profiles. + credentials = await session_holder.async_get_credentials() # Create SigV4 auth and use its signing logic auth = SigV4HTTPXAuth(credentials, service, region) diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index efdc4a6..3a54750 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -24,7 +24,7 @@ _inject_metadata_hook, _sign_request_hook, ) -from unittest.mock import MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock def create_request_with_sigv4_headers( @@ -55,8 +55,15 @@ def create_mock_session(): def create_mock_session_holder(): """Helper to create a mocked SessionHolder.""" + mock_credentials = MagicMock() + mock_credentials.access_key = 'test-access-key' + mock_credentials.secret_key = 'test-secret-key' + mock_credentials.token = 'test-token' + holder = MagicMock(spec=SessionHolder) holder.session = create_mock_session() + holder.async_refresh_if_needed = AsyncMock() + holder.async_get_credentials = AsyncMock(return_value=mock_credentials) return holder @@ -402,14 +409,14 @@ class TestSignRequestHook: @pytest.mark.asyncio async def test_sign_request_hook_calls_refresh_if_needed(self): - """Signing hook calls refresh_if_needed before signing.""" + """Signing hook calls async_refresh_if_needed before signing.""" holder = create_mock_session_holder() request_body = b'{"test": "data"}' request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) await _sign_request_hook('us-east-1', 'execute-api', holder, request) - holder.refresh_if_needed.assert_called_once() + holder.async_refresh_if_needed.assert_awaited_once() @pytest.mark.asyncio async def test_sign_request_hook_signs_request(self): diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index c83bb13..60d542b 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -25,10 +25,11 @@ SessionHolder, SigV4HTTPXAuth, _sanitize_headers, + _sign_request_hook, create_aws_session, create_sigv4_client, ) -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch class TestSigV4HTTPXAuth: @@ -369,6 +370,44 @@ def test_refresh_logs_original_error_on_failure(self, mock_create, caplog): assert 'Failed to create fresh AWS session' in caplog.text assert 'session creation boom' in caplog.text + @pytest.mark.asyncio + async def test_async_refresh_if_needed_noop_when_not_marked(self): + """async_refresh_if_needed does nothing when not marked.""" + mock_session = Mock() + holder = SessionHolder(mock_session) + + await holder.async_refresh_if_needed() + + assert holder.session is mock_session + + @pytest.mark.asyncio + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + async def test_async_refresh_if_needed_creates_new_session_when_marked(self, mock_create): + """async_refresh_if_needed replaces session after mark_needs_refresh.""" + old_session = Mock() + new_session = Mock() + mock_create.return_value = new_session + holder = SessionHolder(old_session, profile='my-profile') + + holder.mark_needs_refresh() + await holder.async_refresh_if_needed() + + mock_create.assert_called_once_with('my-profile') + assert holder.session is new_session + + @pytest.mark.asyncio + async def test_async_get_credentials_delegates_to_session(self): + """async_get_credentials calls session.get_credentials in a thread.""" + mock_creds = Mock() + mock_session = Mock() + mock_session.get_credentials.return_value = mock_creds + holder = SessionHolder(mock_session) + + result = await holder.async_get_credentials() + + mock_session.get_credentials.assert_called_once() + assert result is mock_creds + class TestSanitizeHeaders: """Test cases for the _sanitize_headers function.""" @@ -442,3 +481,70 @@ def test_sensitive_headers_constant_is_frozen(self): assert 'authorization' in SENSITIVE_HEADERS assert 'x-amz-security-token' in SENSITIVE_HEADERS assert 'x-amz-date' in SENSITIVE_HEADERS + + +class TestSignRequestHook: + """Test cases for the _sign_request_hook function.""" + + @pytest.mark.asyncio + async def test_sign_request_hook_uses_async_credential_resolution(self): + """Test that _sign_request_hook resolves credentials asynchronously.""" + mock_credentials = Mock() + mock_credentials.access_key = 'test_access_key' + mock_credentials.secret_key = 'test_secret_key' + mock_credentials.token = 'test_token' + + mock_session = Mock() + mock_session.get_credentials.return_value = mock_credentials + holder = SessionHolder(mock_session) + + # Spy on the async methods to verify they are called + holder.async_refresh_if_needed = AsyncMock() + holder.async_get_credentials = AsyncMock(return_value=mock_credentials) + + request = httpx.Request( + 'POST', + 'https://example.com/mcp', + content=b'{"jsonrpc":"2.0"}', + headers={'Host': 'example.com'}, + ) + + await _sign_request_hook('us-east-1', 'aws-mcp', holder, request) + + holder.async_refresh_if_needed.assert_awaited_once() + holder.async_get_credentials.assert_awaited_once() + assert 'Authorization' in request.headers + + @pytest.mark.asyncio + async def test_sign_request_hook_does_not_block_event_loop(self): + """Test that blocking credential resolution runs in a thread.""" + import asyncio + import time + + mock_credentials = Mock() + mock_credentials.access_key = 'test_access_key' + mock_credentials.secret_key = 'test_secret_key' + mock_credentials.token = None + + def slow_get_credentials(): + time.sleep(0.1) + return mock_credentials + + mock_session = Mock() + mock_session.get_credentials.side_effect = slow_get_credentials + holder = SessionHolder(mock_session) + + request = httpx.Request( + 'POST', + 'https://example.com/mcp', + content=b'{"jsonrpc":"2.0"}', + headers={'Host': 'example.com'}, + ) + + # Should complete without blocking — the event loop stays responsive + await asyncio.wait_for( + _sign_request_hook('us-east-1', 'aws-mcp', holder, request), + timeout=5.0, + ) + + assert 'Authorization' in request.headers From b4dbb30985835bbacb6b183aead7266879abb394 Mon Sep 17 00:00:00 2001 From: nix-tkobayashi Date: Fri, 8 May 2026 19:35:32 +0900 Subject: [PATCH 2/4] fix(auth): improve docstring accuracy and strengthen non-blocking test Address review feedback: - Clarify that get_credentials() is the primary blocking path; async_refresh_if_needed() is a defensive measure - Replace timeout-based non-blocking test with ticker coroutine that proves the event loop stays responsive during credential resolution Co-Authored-By: Claude Opus 4.6 --- mcp_proxy_for_aws/sigv4_helper.py | 8 +++++--- tests/unit/test_sigv4_helper.py | 30 +++++++++++++++++++++++------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index d6c01cf..209d637 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -145,9 +145,11 @@ def refresh_if_needed(self) -> None: async def async_refresh_if_needed(self) -> None: """Async version of refresh_if_needed that offloads blocking I/O to a thread. - For profiles that use assumed IAM roles (chained credentials), session - creation involves a blocking STS AssumeRole call. Running that on the - event loop blocks all other coroutines and causes timeouts. + Defensive counterpart to :meth:`async_get_credentials`. While the + primary blocking path is ``get_credentials()`` (which triggers a + synchronous STS call for assumed-role profiles), session recreation + via ``create_aws_session()`` may also perform blocking I/O in the + future. Offloading it to a thread keeps the event loop safe. """ if not self._needs_refresh: return diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index 60d542b..5316e8d 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -517,7 +517,11 @@ async def test_sign_request_hook_uses_async_credential_resolution(self): @pytest.mark.asyncio async def test_sign_request_hook_does_not_block_event_loop(self): - """Test that blocking credential resolution runs in a thread.""" + """Test that the event loop stays responsive while credentials resolve. + + A ticker coroutine runs alongside the signing hook. If get_credentials() + blocked the loop, the ticker would not advance during the sleep. + """ import asyncio import time @@ -527,7 +531,7 @@ async def test_sign_request_hook_does_not_block_event_loop(self): mock_credentials.token = None def slow_get_credentials(): - time.sleep(0.1) + time.sleep(0.3) return mock_credentials mock_session = Mock() @@ -541,10 +545,22 @@ def slow_get_credentials(): headers={'Host': 'example.com'}, ) - # Should complete without blocking — the event loop stays responsive - await asyncio.wait_for( - _sign_request_hook('us-east-1', 'aws-mcp', holder, request), - timeout=5.0, - ) + tick_count = 0 + + async def ticker(): + nonlocal tick_count + while True: + await asyncio.sleep(0.05) + tick_count += 1 + ticker_task = asyncio.create_task(ticker()) + try: + await _sign_request_hook('us-east-1', 'aws-mcp', holder, request) + finally: + ticker_task.cancel() + + # If get_credentials() blocked the loop, ticker would not have advanced + assert tick_count >= 2, ( + f'ticker only ticked {tick_count} times — event loop was likely blocked' + ) assert 'Authorization' in request.headers From b673a69957ed12088c1fcd9f04bb0b9ca1c8ff57 Mon Sep 17 00:00:00 2001 From: nix-tkobayashi Date: Wed, 20 May 2026 21:30:05 +0900 Subject: [PATCH 3/4] fix: add asyncio.Lock to SessionHolder.async_refresh_if_needed Prevent concurrent refresh races when multiple in-flight requests trigger credential refresh simultaneously. Uses a double-check pattern: the outer check avoids lock overhead on the fast path, the inner check ensures only one caller performs the refresh. Co-Authored-By: Claude Opus 4.6 --- mcp_proxy_for_aws/sigv4_helper.py | 15 +++++---- tests/unit/test_sigv4_helper.py | 54 +++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index 209d637..187abe1 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -124,6 +124,7 @@ def __init__(self, session: boto3.Session, profile: Optional[str] = None) -> Non self.session = session self._profile = profile self._needs_refresh = False + self._refresh_lock = asyncio.Lock() def mark_needs_refresh(self) -> None: """Mark that the next request should use a fresh session.""" @@ -145,15 +146,17 @@ def refresh_if_needed(self) -> None: async def async_refresh_if_needed(self) -> None: """Async version of refresh_if_needed that offloads blocking I/O to a thread. - Defensive counterpart to :meth:`async_get_credentials`. While the - primary blocking path is ``get_credentials()`` (which triggers a - synchronous STS call for assumed-role profiles), session recreation - via ``create_aws_session()`` may also perform blocking I/O in the - future. Offloading it to a thread keeps the event loop safe. + Uses a lock with double-check to prevent concurrent refresh races + when multiple requests are in flight. The outer check avoids + acquiring the lock on the fast path; the inner check ensures only + one caller actually performs the refresh. """ if not self._needs_refresh: return - await asyncio.to_thread(self.refresh_if_needed) + async with self._refresh_lock: + if not self._needs_refresh: + return + await asyncio.to_thread(self.refresh_if_needed) async def async_get_credentials(self): """Resolve credentials without blocking the event loop. diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index 5316e8d..e44caec 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -14,6 +14,7 @@ """Unit tests for sigv4_helper module.""" +import asyncio import httpx import logging import pytest @@ -395,6 +396,59 @@ async def test_async_refresh_if_needed_creates_new_session_when_marked(self, moc mock_create.assert_called_once_with('my-profile') assert holder.session is new_session + @pytest.mark.asyncio + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + async def test_async_refresh_concurrent_calls_refresh_once(self, mock_create): + """Concurrent async_refresh_if_needed calls only refresh once. + + Uses threading.Event to hold the first refresh inside create_aws_session + while the other callers enter async_refresh_if_needed, ensuring the + lock is actually contended. + """ + import threading + + old_session = Mock() + new_session = Mock() + entered_create = threading.Event() + proceed = threading.Event() + + def blocking_create(profile): + entered_create.set() + assert proceed.wait(timeout=5), 'proceed event was never set' + return new_session + + mock_create.side_effect = blocking_create + holder = SessionHolder(old_session, profile='my-profile') + holder.mark_needs_refresh() + + loop = asyncio.get_running_loop() + + async def wait_then_refresh(): + assert await loop.run_in_executor(None, entered_create.wait, 5), ( + 'entered_create event was never set' + ) + await holder.async_refresh_if_needed() + + async def first_refresh_and_unblock(): + task = asyncio.create_task(holder.async_refresh_if_needed()) + assert await loop.run_in_executor(None, entered_create.wait, 5), ( + 'entered_create event was never set' + ) + # Yield to give wait_then_refresh coroutines a chance to + # enter async_refresh_if_needed before the first refresh completes. + await asyncio.sleep(0) + proceed.set() + await task + + await asyncio.gather( + first_refresh_and_unblock(), + wait_then_refresh(), + wait_then_refresh(), + ) + + mock_create.assert_called_once_with('my-profile') + assert holder.session is new_session + @pytest.mark.asyncio async def test_async_get_credentials_delegates_to_session(self): """async_get_credentials calls session.get_credentials in a thread.""" From f2314762236ff6179a101ed7a3cd8fe352e84d82 Mon Sep 17 00:00:00 2001 From: nix-tkobayashi Date: Wed, 20 May 2026 21:46:38 +0900 Subject: [PATCH 4/4] fix: shield refresh from cancellation to prevent duplicate STS calls When the task holding _refresh_lock is cancelled during asyncio.to_thread(), the lock would be released while the worker thread is still running. A second caller could then acquire the lock, see _needs_refresh == True, and start a duplicate refresh. Shield the to_thread call so that cancellation waits for the refresh to complete before releasing the lock and re-raising CancelledError. Co-Authored-By: Claude Opus 4.6 --- mcp_proxy_for_aws/sigv4_helper.py | 10 +++- tests/unit/test_sigv4_helper.py | 78 +++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index 187abe1..b464633 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -150,13 +150,21 @@ async def async_refresh_if_needed(self) -> None: when multiple requests are in flight. The outer check avoids acquiring the lock on the fast path; the inner check ensures only one caller actually performs the refresh. + + The refresh is shielded from cancellation so that a cancelled caller + does not release the lock while the worker thread is still running. """ if not self._needs_refresh: return async with self._refresh_lock: if not self._needs_refresh: return - await asyncio.to_thread(self.refresh_if_needed) + refresh = asyncio.ensure_future(asyncio.to_thread(self.refresh_if_needed)) + try: + await asyncio.shield(refresh) + except asyncio.CancelledError: + await refresh + raise async def async_get_credentials(self): """Resolve credentials without blocking the event loop. diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index e44caec..92ba2d0 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -449,6 +449,84 @@ async def first_refresh_and_unblock(): mock_create.assert_called_once_with('my-profile') assert holder.session is new_session + @pytest.mark.asyncio + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + async def test_async_refresh_cancellation_does_not_duplicate(self, mock_create): + """Cancelling the caller does not release the lock prematurely. + + Scenario: task1 holds the lock and is inside create_aws_session. + task2 is waiting on the lock. task1 is cancelled. Without the + shield, the lock would be released while the worker thread is still + running, allowing task2 to start a duplicate refresh. + + The test distinguishes first and second create_aws_session calls + and verifies that a second call never happens. + """ + import threading + + old_session = Mock() + new_session = Mock() + entered_first = threading.Event() + finish_first = threading.Event() + entered_second = threading.Event() + + call_count = 0 + + def create_side_effect(profile): + nonlocal call_count + call_count += 1 + if call_count == 1: + entered_first.set() + assert finish_first.wait(timeout=5), 'finish_first event was never set' + return new_session + else: + entered_second.set() + return new_session + + mock_create.side_effect = create_side_effect + holder = SessionHolder(old_session, profile='my-profile') + holder.mark_needs_refresh() + + loop = asyncio.get_running_loop() + + # task1 enters create_aws_session and blocks on finish_first. + task1 = asyncio.create_task(holder.async_refresh_if_needed()) + assert await loop.run_in_executor(None, entered_first.wait, 5), ( + 'entered_first event was never set' + ) + + # task2 passes the outer _needs_refresh check and blocks on the lock. + task2 = asyncio.create_task(holder.async_refresh_if_needed()) + await asyncio.sleep(0) + + # Cancel task1 while the worker thread is still blocked. + task1.cancel() + + # Give the event loop time to process the cancellation and let + # task2 potentially acquire the lock and start a second refresh. + # run_in_executor yields control to the event loop while waiting. + duplicate_started = await loop.run_in_executor( + None, entered_second.wait, 0.5 + ) + assert not duplicate_started, ( + 'task2 started a duplicate refresh — lock was released prematurely' + ) + + # Now let the first worker thread finish. + finish_first.set() + + # task1 re-raises CancelledError after the refresh completes. + with pytest.raises(asyncio.CancelledError): + await task1 + + # task2 acquires the lock, sees _needs_refresh == False, returns. + await task2 + + # create_aws_session was called exactly once. + mock_create.assert_called_once_with('my-profile') + assert holder.session is new_session + assert not entered_second.is_set() + @pytest.mark.asyncio async def test_async_get_credentials_delegates_to_session(self): """async_get_credentials calls session.get_credentials in a thread."""