diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index da68166..82e648c 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -14,6 +14,7 @@ """SigV4 Helper for AWS request signing functionality.""" +import asyncio import boto3 import httpx import json @@ -123,6 +124,7 @@ def __init__(self, session: boto3.Session, profile: Optional[str] = None) -> Non self.session = session self._profile = profile self._needs_refresh = False + self._refresh_lock = asyncio.Lock() def mark_needs_refresh(self) -> None: """Mark that the next request should use a fresh session.""" @@ -141,6 +143,39 @@ def refresh_if_needed(self) -> None: 'Failed to create fresh AWS session, keeping current session', exc_info=True ) + async def async_refresh_if_needed(self) -> None: + """Async version of refresh_if_needed that offloads blocking I/O to a thread. + + Uses a lock with double-check to prevent concurrent refresh races + when multiple requests are in flight. The outer check avoids + acquiring the lock on the fast path; the inner check ensures only + one caller actually performs the refresh. + + The refresh is shielded from cancellation so that a cancelled caller + does not release the lock while the worker thread is still running. + """ + if not self._needs_refresh: + return + async with self._refresh_lock: + if not self._needs_refresh: + return + refresh = asyncio.ensure_future(asyncio.to_thread(self.refresh_if_needed)) + try: + await asyncio.shield(refresh) + except asyncio.CancelledError: + await refresh + raise + + async def async_get_credentials(self): + """Resolve credentials without blocking the event loop. + + ``boto3.Session.get_credentials()`` may trigger a synchronous STS + ``AssumeRole`` call when the profile uses chained credentials. + Offloading the call to a worker thread prevents the event loop from + stalling. + """ + return await asyncio.to_thread(self.session.get_credentials) + def create_sigv4_client( service: str, @@ -293,10 +328,14 @@ async def _sign_request_hook( # Refresh session if a previous request got an auth error. # Done here (at signing time) so the new session reads credentials # that the user may have refreshed since the error occurred. - session_holder.refresh_if_needed() - - # Get AWS credentials from the session - credentials = session_holder.session.get_credentials() + # Use async variant to avoid blocking the event loop when the profile + # requires an STS AssumeRole call (chained credentials). + await session_holder.async_refresh_if_needed() + + # Get AWS credentials from the session. + # Offloaded to a thread because get_credentials() may trigger a + # synchronous STS call for assumed-role profiles. + credentials = await session_holder.async_get_credentials() if credentials is None: if skip_auth: diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index 5fdecd4..c4d7dd1 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -24,7 +24,7 @@ _inject_metadata_hook, _sign_request_hook, ) -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch def create_request_with_sigv4_headers( @@ -55,8 +55,15 @@ def create_mock_session(): def create_mock_session_holder(): """Helper to create a mocked SessionHolder.""" + mock_credentials = MagicMock() + mock_credentials.access_key = 'test-access-key' + mock_credentials.secret_key = 'test-secret-key' + mock_credentials.token = 'test-token' + holder = MagicMock(spec=SessionHolder) holder.session = create_mock_session() + holder.async_refresh_if_needed = AsyncMock() + holder.async_get_credentials = AsyncMock(return_value=mock_credentials) return holder @@ -402,14 +409,14 @@ class TestSignRequestHook: @pytest.mark.asyncio async def test_sign_request_hook_calls_refresh_if_needed(self): - """Signing hook calls refresh_if_needed before signing.""" + """Signing hook calls async_refresh_if_needed before signing.""" holder = create_mock_session_holder() request_body = b'{"test": "data"}' request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) await _sign_request_hook('us-east-1', 'execute-api', holder, False, request) - holder.refresh_if_needed.assert_called_once() + holder.async_refresh_if_needed.assert_awaited_once() @pytest.mark.asyncio async def test_sign_request_hook_signs_request(self): diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index c83bb13..92ba2d0 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -14,6 +14,7 @@ """Unit tests for sigv4_helper module.""" +import asyncio import httpx import logging import pytest @@ -25,10 +26,11 @@ SessionHolder, SigV4HTTPXAuth, _sanitize_headers, + _sign_request_hook, create_aws_session, create_sigv4_client, ) -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch class TestSigV4HTTPXAuth: @@ -369,6 +371,175 @@ def test_refresh_logs_original_error_on_failure(self, mock_create, caplog): assert 'Failed to create fresh AWS session' in caplog.text assert 'session creation boom' in caplog.text + @pytest.mark.asyncio + async def test_async_refresh_if_needed_noop_when_not_marked(self): + """async_refresh_if_needed does nothing when not marked.""" + mock_session = Mock() + holder = SessionHolder(mock_session) + + await holder.async_refresh_if_needed() + + assert holder.session is mock_session + + @pytest.mark.asyncio + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + async def test_async_refresh_if_needed_creates_new_session_when_marked(self, mock_create): + """async_refresh_if_needed replaces session after mark_needs_refresh.""" + old_session = Mock() + new_session = Mock() + mock_create.return_value = new_session + holder = SessionHolder(old_session, profile='my-profile') + + holder.mark_needs_refresh() + await holder.async_refresh_if_needed() + + mock_create.assert_called_once_with('my-profile') + assert holder.session is new_session + + @pytest.mark.asyncio + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + async def test_async_refresh_concurrent_calls_refresh_once(self, mock_create): + """Concurrent async_refresh_if_needed calls only refresh once. + + Uses threading.Event to hold the first refresh inside create_aws_session + while the other callers enter async_refresh_if_needed, ensuring the + lock is actually contended. + """ + import threading + + old_session = Mock() + new_session = Mock() + entered_create = threading.Event() + proceed = threading.Event() + + def blocking_create(profile): + entered_create.set() + assert proceed.wait(timeout=5), 'proceed event was never set' + return new_session + + mock_create.side_effect = blocking_create + holder = SessionHolder(old_session, profile='my-profile') + holder.mark_needs_refresh() + + loop = asyncio.get_running_loop() + + async def wait_then_refresh(): + assert await loop.run_in_executor(None, entered_create.wait, 5), ( + 'entered_create event was never set' + ) + await holder.async_refresh_if_needed() + + async def first_refresh_and_unblock(): + task = asyncio.create_task(holder.async_refresh_if_needed()) + assert await loop.run_in_executor(None, entered_create.wait, 5), ( + 'entered_create event was never set' + ) + # Yield to give wait_then_refresh coroutines a chance to + # enter async_refresh_if_needed before the first refresh completes. + await asyncio.sleep(0) + proceed.set() + await task + + await asyncio.gather( + first_refresh_and_unblock(), + wait_then_refresh(), + wait_then_refresh(), + ) + + mock_create.assert_called_once_with('my-profile') + assert holder.session is new_session + + @pytest.mark.asyncio + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + async def test_async_refresh_cancellation_does_not_duplicate(self, mock_create): + """Cancelling the caller does not release the lock prematurely. + + Scenario: task1 holds the lock and is inside create_aws_session. + task2 is waiting on the lock. task1 is cancelled. Without the + shield, the lock would be released while the worker thread is still + running, allowing task2 to start a duplicate refresh. + + The test distinguishes first and second create_aws_session calls + and verifies that a second call never happens. + """ + import threading + + old_session = Mock() + new_session = Mock() + entered_first = threading.Event() + finish_first = threading.Event() + entered_second = threading.Event() + + call_count = 0 + + def create_side_effect(profile): + nonlocal call_count + call_count += 1 + if call_count == 1: + entered_first.set() + assert finish_first.wait(timeout=5), 'finish_first event was never set' + return new_session + else: + entered_second.set() + return new_session + + mock_create.side_effect = create_side_effect + holder = SessionHolder(old_session, profile='my-profile') + holder.mark_needs_refresh() + + loop = asyncio.get_running_loop() + + # task1 enters create_aws_session and blocks on finish_first. + task1 = asyncio.create_task(holder.async_refresh_if_needed()) + assert await loop.run_in_executor(None, entered_first.wait, 5), ( + 'entered_first event was never set' + ) + + # task2 passes the outer _needs_refresh check and blocks on the lock. + task2 = asyncio.create_task(holder.async_refresh_if_needed()) + await asyncio.sleep(0) + + # Cancel task1 while the worker thread is still blocked. + task1.cancel() + + # Give the event loop time to process the cancellation and let + # task2 potentially acquire the lock and start a second refresh. + # run_in_executor yields control to the event loop while waiting. + duplicate_started = await loop.run_in_executor( + None, entered_second.wait, 0.5 + ) + assert not duplicate_started, ( + 'task2 started a duplicate refresh — lock was released prematurely' + ) + + # Now let the first worker thread finish. + finish_first.set() + + # task1 re-raises CancelledError after the refresh completes. + with pytest.raises(asyncio.CancelledError): + await task1 + + # task2 acquires the lock, sees _needs_refresh == False, returns. + await task2 + + # create_aws_session was called exactly once. + mock_create.assert_called_once_with('my-profile') + assert holder.session is new_session + assert not entered_second.is_set() + + @pytest.mark.asyncio + async def test_async_get_credentials_delegates_to_session(self): + """async_get_credentials calls session.get_credentials in a thread.""" + mock_creds = Mock() + mock_session = Mock() + mock_session.get_credentials.return_value = mock_creds + holder = SessionHolder(mock_session) + + result = await holder.async_get_credentials() + + mock_session.get_credentials.assert_called_once() + assert result is mock_creds + class TestSanitizeHeaders: """Test cases for the _sanitize_headers function.""" @@ -442,3 +613,86 @@ def test_sensitive_headers_constant_is_frozen(self): assert 'authorization' in SENSITIVE_HEADERS assert 'x-amz-security-token' in SENSITIVE_HEADERS assert 'x-amz-date' in SENSITIVE_HEADERS + + +class TestSignRequestHook: + """Test cases for the _sign_request_hook function.""" + + @pytest.mark.asyncio + async def test_sign_request_hook_uses_async_credential_resolution(self): + """Test that _sign_request_hook resolves credentials asynchronously.""" + mock_credentials = Mock() + mock_credentials.access_key = 'test_access_key' + mock_credentials.secret_key = 'test_secret_key' + mock_credentials.token = 'test_token' + + mock_session = Mock() + mock_session.get_credentials.return_value = mock_credentials + holder = SessionHolder(mock_session) + + # Spy on the async methods to verify they are called + holder.async_refresh_if_needed = AsyncMock() + holder.async_get_credentials = AsyncMock(return_value=mock_credentials) + + request = httpx.Request( + 'POST', + 'https://example.com/mcp', + content=b'{"jsonrpc":"2.0"}', + headers={'Host': 'example.com'}, + ) + + await _sign_request_hook('us-east-1', 'aws-mcp', holder, request) + + holder.async_refresh_if_needed.assert_awaited_once() + holder.async_get_credentials.assert_awaited_once() + assert 'Authorization' in request.headers + + @pytest.mark.asyncio + async def test_sign_request_hook_does_not_block_event_loop(self): + """Test that the event loop stays responsive while credentials resolve. + + A ticker coroutine runs alongside the signing hook. If get_credentials() + blocked the loop, the ticker would not advance during the sleep. + """ + import asyncio + import time + + mock_credentials = Mock() + mock_credentials.access_key = 'test_access_key' + mock_credentials.secret_key = 'test_secret_key' + mock_credentials.token = None + + def slow_get_credentials(): + time.sleep(0.3) + return mock_credentials + + mock_session = Mock() + mock_session.get_credentials.side_effect = slow_get_credentials + holder = SessionHolder(mock_session) + + request = httpx.Request( + 'POST', + 'https://example.com/mcp', + content=b'{"jsonrpc":"2.0"}', + headers={'Host': 'example.com'}, + ) + + tick_count = 0 + + async def ticker(): + nonlocal tick_count + while True: + await asyncio.sleep(0.05) + tick_count += 1 + + ticker_task = asyncio.create_task(ticker()) + try: + await _sign_request_hook('us-east-1', 'aws-mcp', holder, request) + finally: + ticker_task.cancel() + + # If get_credentials() blocked the loop, ticker would not have advanced + assert tick_count >= 2, ( + f'ticker only ticked {tick_count} times — event loop was likely blocked' + ) + assert 'Authorization' in request.headers