Skip to content

Commit 8fc772f

Browse files
anasstahriddv
andauthored
fix: refresh stale credentials on auth failure without restart (#245)
* fix: refresh stale credentials on auth failure without restart When AWS credentials expire, the proxy now automatically picks up refreshed credentials on the next request without requiring a restart. Problem: - boto3.Session was created once at startup and cached forever - session.get_credentials() returned the same stale Credentials object - Even after refreshing creds on disk (ada, aws sso login), the proxy kept using the old frozen session until restarted Fix: - Add SessionHolder that wraps the boto3 session with lazy refresh - On 401/403, mark the session for refresh (don't refresh immediately, since creds on disk may not be updated yet) - On the next request's signing, create a fresh boto3.Session that reads the current credentials from disk - Improve error messages: credential errors now clearly say 'expired or invalid AWS credentials' instead of 'Unknown tool' The lazy refresh pattern ensures the new session is created at signing time (after the user has refreshed creds), not at error time (when creds may still be stale). * fix: address review comments on credential refresh * chore: fix ruff lint and format issues * test: add unit tests for credential refresh flow * fix: address PR review comments --------- Co-authored-by: Ian de Villiers <iddv@amazon.com>
1 parent 08a93de commit 8fc772f

9 files changed

Lines changed: 331 additions & 135 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,10 @@ For long-running sessions, consider using long-lived credentials:
353353
- Use an AWS profile via `--profile`
354354
- Use IAM Identity Center and run `aws sso login` before starting the proxy
355355

356+
If your credentials do expire during a session, the proxy will automatically detect the auth failure and pick up refreshed credentials on the next request — no restart required. Simply refresh your credentials (e.g., `aws sso login`) and retry.
357+
356358
### Client hangs on tool calls
357-
If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely.
359+
If your MCP client hangs waiting for a tool call response (e.g., due to an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely.
358360

359361
## Development & Contributing
360362

mcp_proxy_for_aws/middleware/tool_error_middleware.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,27 @@ async def on_call_tool(
6262
logger.error('Tool call %r failed: %s.', tool_name, e)
6363
message = f'Tool call {tool_name!r} failed: {e}. Please retry.'
6464
if self._is_credential_error(e):
65-
message += (
66-
' This may be caused by expired or invalid AWS credentials.'
67-
' Consider using long-lived credentials such as an AWS profile'
68-
' (--profile) or IAM Identity Center (aws sso login).'
65+
message = (
66+
f'Tool call {tool_name!r} failed due to expired or invalid AWS credentials.'
67+
' Please refresh your credentials and retry.'
68+
' The proxy will automatically use the new credentials on the next request.'
6969
)
7070
raise ToolError(message) from e
7171

7272
@staticmethod
7373
def _is_credential_error(error: Exception) -> bool:
74-
"""Check if the error is likely caused by expired or invalid credentials."""
75-
return isinstance(error, httpx.HTTPStatusError) and error.response.status_code in (
76-
401,
77-
403,
78-
)
74+
"""Check if the error is likely caused by expired or invalid credentials.
75+
76+
Walks the exception chain (__cause__/__context__) because the
77+
HTTPStatusError may be wrapped. isinstance already respects the
78+
MRO, so subclasses are caught too.
79+
"""
80+
current: BaseException | None = error
81+
while current is not None:
82+
if isinstance(current, httpx.HTTPStatusError) and current.response.status_code in (
83+
401,
84+
403,
85+
):
86+
return True
87+
current = current.__cause__ or current.__context__
88+
return False

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,42 @@ def create_aws_session(profile: Optional[str] = None) -> boto3.Session:
120120
return session
121121

122122

123+
class SessionHolder:
124+
"""Holds a boto3 session that can be refreshed on credential errors.
125+
126+
Wraps a boto3.Session so the signing hook always uses the current session,
127+
and can create a fresh session when the previous request got an auth error.
128+
"""
129+
130+
def __init__(self, session: boto3.Session, profile: Optional[str] = None) -> None:
131+
"""Initialize SessionHolder with the given session and optional profile."""
132+
self.session = session
133+
self._profile = profile
134+
self._needs_refresh = False
135+
136+
def mark_needs_refresh(self) -> None:
137+
"""Mark that the next request should use a fresh session."""
138+
self._needs_refresh = True
139+
140+
def refresh_if_needed(self) -> None:
141+
"""Create a fresh session if a previous request got an auth error."""
142+
if not self._needs_refresh:
143+
return
144+
logger.info('Refreshing AWS session to pick up new credentials')
145+
try:
146+
self.session = create_aws_session(self._profile)
147+
self._needs_refresh = False
148+
except Exception:
149+
logger.warning(
150+
'Failed to create fresh AWS session, keeping current session', exc_info=True
151+
)
152+
153+
123154
def create_sigv4_client(
124155
service: str,
125156
region: str,
157+
session_holder: SessionHolder,
126158
timeout: Optional[httpx.Timeout] = None,
127-
profile: Optional[str] = None,
128-
session: Optional[boto3.Session] = None,
129159
headers: Optional[Dict[str, str]] = None,
130160
metadata: Optional[Dict[str, Any]] = None,
131161
disable_telemetry: bool = False,
@@ -135,9 +165,9 @@ def create_sigv4_client(
135165
136166
Args:
137167
service: AWS service name for SigV4 signing
138-
profile: AWS profile to use (optional, only used if session is not provided)
139-
session: AWS boto3 session to use (optional, takes precedence over profile)
140-
region: AWS region (optional, defaults to AWS_REGION env var or us-east-1)
168+
region: AWS region for SigV4 signing
169+
session_holder: SessionHolder that provides the current boto3 session
170+
and can refresh it on credential errors
141171
timeout: Timeout configuration for the HTTP client
142172
headers: Headers to include in requests
143173
metadata: Metadata to inject into MCP _meta field
@@ -147,10 +177,6 @@ def create_sigv4_client(
147177
Returns:
148178
httpx.AsyncClient with SigV4 authentication
149179
"""
150-
# Create or use provided AWS session
151-
if session is None:
152-
session = create_aws_session(profile)
153-
154180
# Create a copy of kwargs to avoid modifying the passed dict
155181
client_kwargs = {
156182
'follow_redirects': True,
@@ -184,28 +210,33 @@ def create_sigv4_client(
184210
return httpx.AsyncClient(
185211
**client_kwargs,
186212
event_hooks={
187-
'response': [_handle_error_response],
213+
'response': [partial(_handle_error_response, session_holder)],
188214
'request': [
189215
partial(_inject_metadata_hook, metadata or {}),
190-
partial(_sign_request_hook, region, service, session),
216+
partial(_sign_request_hook, region, service, session_holder),
191217
],
192218
},
193219
)
194220

195221

196-
async def _handle_error_response(response: httpx.Response) -> None:
222+
async def _handle_error_response(session_holder: SessionHolder, response: httpx.Response) -> None:
197223
"""Event hook to handle HTTP error responses and extract details.
198224
199225
This function is called for every HTTP response to check for errors
200226
and provide more detailed error information when requests fail.
201227
202228
Args:
229+
session_holder: SessionHolder to refresh on credential errors
203230
response: The HTTP response object
204231
205232
Raises:
206233
No raises. let the mcp http client handle the errors.
207234
"""
208235
if response.is_error:
236+
# Mark session for refresh so the next request picks up new credentials
237+
if response.status_code in (401, 403):
238+
session_holder.mark_needs_refresh()
239+
209240
# warning only because the SDK logs error
210241
log_level = logging.WARNING
211242
if (
@@ -246,7 +277,7 @@ async def _handle_error_response(response: httpx.Response) -> None:
246277
async def _sign_request_hook(
247278
region: str,
248279
service: str,
249-
session: boto3.Session,
280+
session_holder: SessionHolder,
250281
request: httpx.Request,
251282
) -> None:
252283
"""Request hook to sign HTTP requests with AWS SigV4.
@@ -258,14 +289,19 @@ async def _sign_request_hook(
258289
Args:
259290
region: AWS region for SigV4 signing
260291
service: AWS service name for SigV4 signing
261-
session: AWS boto3 session to use for credentials
292+
session_holder: Holder providing the current boto3 session (refreshed on auth errors)
262293
request: The HTTP request object to sign (modified in-place)
263294
"""
264295
# Set Content-Length for signing
265296
request.headers['Content-Length'] = str(len(request.content))
266297

298+
# Refresh session if a previous request got an auth error.
299+
# Done here (at signing time) so the new session reads credentials
300+
# that the user may have refreshed since the error occurred.
301+
session_holder.refresh_if_needed()
302+
267303
# Get AWS credentials from the session
268-
credentials = session.get_credentials()
304+
credentials = session_holder.session.get_credentials()
269305

270306
# Create SigV4 auth and use its signing logic
271307
auth = SigV4HTTPXAuth(credentials, service, region)

mcp_proxy_for_aws/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import os
2121
from fastmcp.client.transports import StreamableHttpTransport
22-
from mcp_proxy_for_aws.sigv4_helper import create_aws_session, create_sigv4_client
22+
from mcp_proxy_for_aws.sigv4_helper import SessionHolder, create_aws_session, create_sigv4_client
2323
from typing import Any, Dict, Optional, Tuple
2424
from urllib.parse import urlparse
2525

@@ -90,9 +90,10 @@ def create_transport_with_sigv4(
9090
Returns:
9191
StreamableHttpTransport instance with SigV4 authentication
9292
"""
93-
# Create AWS session once and reuse it for all httpx clients
93+
# Create AWS session with a holder that can refresh on credential errors
9494
logger.debug('Creating AWS session with profile: %s', profile)
9595
session = create_aws_session(profile)
96+
session_holder = SessionHolder(session, profile)
9697

9798
def client_factory(
9899
headers: Optional[Dict[str, str]] = None,
@@ -102,7 +103,7 @@ def client_factory(
102103
) -> httpx.AsyncClient:
103104
return create_sigv4_client(
104105
service=service,
105-
session=session,
106+
session_holder=session_holder,
106107
region=region,
107108
headers=headers,
108109
timeout=custom_timeout,

0 commit comments

Comments
 (0)