Skip to content
47 changes: 43 additions & 4 deletions mcp_proxy_for_aws/sigv4_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""SigV4 Helper for AWS request signing functionality."""

import asyncio
import boto3
import httpx
import json
Expand Down Expand Up @@ -123,6 +124,7 @@ def __init__(self, session: boto3.Session, profile: Optional[str] = None) -> Non
self.session = session
self._profile = profile
self._needs_refresh = False
self._refresh_lock = asyncio.Lock()

def mark_needs_refresh(self) -> None:
"""Mark that the next request should use a fresh session."""
Expand All @@ -141,6 +143,39 @@ def refresh_if_needed(self) -> None:
'Failed to create fresh AWS session, keeping current session', exc_info=True
)

async def async_refresh_if_needed(self) -> None:
"""Async version of refresh_if_needed that offloads blocking I/O to a thread.

Uses a lock with double-check to prevent concurrent refresh races
when multiple requests are in flight. The outer check avoids
acquiring the lock on the fast path; the inner check ensures only
one caller actually performs the refresh.

The refresh is shielded from cancellation so that a cancelled caller
does not release the lock while the worker thread is still running.
"""
if not self._needs_refresh:
return
async with self._refresh_lock:
if not self._needs_refresh:
return
refresh = asyncio.ensure_future(asyncio.to_thread(self.refresh_if_needed))
try:
await asyncio.shield(refresh)
except asyncio.CancelledError:
await refresh
raise

async def async_get_credentials(self):
"""Resolve credentials without blocking the event loop.

``boto3.Session.get_credentials()`` may trigger a synchronous STS
``AssumeRole`` call when the profile uses chained credentials.
Offloading the call to a worker thread prevents the event loop from
stalling.
"""
return await asyncio.to_thread(self.session.get_credentials)


def create_sigv4_client(
service: str,
Expand Down Expand Up @@ -293,10 +328,14 @@ async def _sign_request_hook(
# Refresh session if a previous request got an auth error.
# Done here (at signing time) so the new session reads credentials
# that the user may have refreshed since the error occurred.
session_holder.refresh_if_needed()

# Get AWS credentials from the session
credentials = session_holder.session.get_credentials()
# Use async variant to avoid blocking the event loop when the profile
# requires an STS AssumeRole call (chained credentials).
await session_holder.async_refresh_if_needed()

# Get AWS credentials from the session.
# Offloaded to a thread because get_credentials() may trigger a
# synchronous STS call for assumed-role profiles.
credentials = await session_holder.async_get_credentials()

if credentials is None:
if skip_auth:
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
_inject_metadata_hook,
_sign_request_hook,
)
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch


def create_request_with_sigv4_headers(
Expand Down Expand Up @@ -55,8 +55,15 @@ def create_mock_session():

def create_mock_session_holder():
"""Helper to create a mocked SessionHolder."""
mock_credentials = MagicMock()
mock_credentials.access_key = 'test-access-key'
mock_credentials.secret_key = 'test-secret-key'
mock_credentials.token = 'test-token'

holder = MagicMock(spec=SessionHolder)
holder.session = create_mock_session()
holder.async_refresh_if_needed = AsyncMock()
holder.async_get_credentials = AsyncMock(return_value=mock_credentials)
return holder


Expand Down Expand Up @@ -402,14 +409,14 @@ class TestSignRequestHook:

@pytest.mark.asyncio
async def test_sign_request_hook_calls_refresh_if_needed(self):
"""Signing hook calls refresh_if_needed before signing."""
"""Signing hook calls async_refresh_if_needed before signing."""
holder = create_mock_session_holder()
request_body = b'{"test": "data"}'
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)

await _sign_request_hook('us-east-1', 'execute-api', holder, False, request)

holder.refresh_if_needed.assert_called_once()
holder.async_refresh_if_needed.assert_awaited_once()

@pytest.mark.asyncio
async def test_sign_request_hook_signs_request(self):
Expand Down
Loading