Skip to content

Commit 2ad1d4d

Browse files
author
Kyon Caldera
committed
refactor(siv4_helper.py): move signing logic from client creation to an event hook
1 parent 4ffb75d commit 2ad1d4d

6 files changed

Lines changed: 181 additions & 257 deletions

File tree

mcp_proxy_for_aws/hooks.py

Lines changed: 19 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import httpx
1818
import json
1919
import logging
20-
from botocore.auth import SigV4Auth
21-
from botocore.awsrequest import AWSRequest
2220
from httpx._content import ByteStream
2321
from typing import Any, Dict, Optional
2422

@@ -76,73 +74,50 @@ async def _handle_error_response(response: httpx.Response) -> None:
7674
)
7775

7876

79-
def _resign_request_with_sigv4(
80-
request: httpx.Request,
77+
async def _sign_request_hook(
8178
region: str,
8279
service: str,
83-
profile: Optional[str] = None,
80+
profile: Optional[str],
81+
request: httpx.Request,
8482
) -> None:
85-
"""Re-sign an HTTP request with AWS SigV4 after content modification.
83+
"""Request hook to sign HTTP requests with AWS SigV4.
8684
87-
This function removes old signature headers, creates a new signature based on
88-
the current request content, and updates the request headers with the new signature.
85+
This hook signs the request with AWS SigV4 credentials and adds signature headers.
86+
87+
This should be the last hook called to ensure the signature includes any modifications.
8988
9089
Args:
91-
request: The HTTP request object to re-sign (modified in-place)
9290
region: AWS region for SigV4 signing
9391
service: AWS service name for SigV4 signing
9492
profile: AWS profile to use (optional)
93+
request: The HTTP request object to sign (modified in-place)
9594
"""
96-
# Import here to avoid circular dependency
97-
from mcp_proxy_for_aws.sigv4_helper import create_aws_session
98-
99-
# Remove old signature headers before re-signing
100-
headers_to_remove = ['Content-Length', 'x-amz-date', 'x-amz-security-token', 'authorization']
101-
for header in headers_to_remove:
102-
request.headers.pop(header, None)
95+
# Import here to avoid circular dependency and for compatibility
96+
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
10397

104-
# Set the new Content-Length
98+
# Set Content-Length for signing
10599
request.headers['Content-Length'] = str(len(request.content))
106100

107-
logger.info('Headers after cleanup: %s', request.headers)
108-
109101
# Get AWS credentials
110102
session = create_aws_session(profile)
111103
credentials = session.get_credentials()
112-
logger.info('Re-signing request with credentials for access key: %s', credentials.access_key)
113-
114-
# Create headers dict for signing, removing connection header like in auth_flow
115-
headers_for_signing = dict(request.headers)
116-
headers_for_signing.pop('connection', None) # Remove connection header for signing
104+
logger.info('Signing request with credentials for access key: %s', credentials.access_key)
117105

118-
# Create SigV4 signer and AWS request
119-
signer = SigV4Auth(credentials, service, region)
120-
aws_request = AWSRequest(
121-
method=request.method,
122-
url=str(request.url),
123-
data=request.content,
124-
headers=headers_for_signing,
125-
)
106+
# Create SigV4 auth and use its signing logic
107+
auth = SigV4HTTPXAuth(credentials, service, region)
126108

127-
# Sign the request
128-
logger.info('AWS request before signing: %s', aws_request.headers)
129-
signer.add_auth(aws_request)
130-
logger.info('AWS request after signing: %s', aws_request.headers)
109+
# Call auth_flow to sign the request (it modifies request in-place)
110+
auth_flow = auth.auth_flow(request)
111+
next(auth_flow) # Execute the generator to perform signing
131112

132-
# Update request headers with signed headers
133-
request.headers.update(dict(aws_request.headers))
134-
logger.info('Request headers after re-signing: %s', request.headers)
113+
logger.debug('Request headers after signing: %s', request.headers)
135114

136115

137-
async def _inject_metadata_hook(
138-
metadata: Dict[str, Any], region: str, service: str, request: httpx.Request
139-
) -> None:
116+
async def _inject_metadata_hook(metadata: Dict[str, Any], request: httpx.Request) -> None:
140117
"""Request hook to inject metadata into MCP calls.
141118
142119
Args:
143120
metadata: Dictionary of metadata to inject into _meta field
144-
region: AWS region for SigV4 re-signing after metadata injection
145-
service: AWS service name for SigV4 re-signing after metadata injection
146121
request: The HTTP request object
147122
"""
148123
logger.info('=== Outgoing Request ===')
@@ -189,9 +164,6 @@ async def _inject_metadata_hook(
189164
request.stream = ByteStream(new_content)
190165
request._content = new_content
191166

192-
# Re-sign the request with the new content
193-
_resign_request_with_sigv4(request, region, service)
194-
195167
logger.info('Injected metadata into _meta: %s', body['params']['_meta'])
196168

197169
except (json.JSONDecodeError, KeyError, TypeError) as e:

mcp_proxy_for_aws/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
6464

6565
# Log server configuration
6666
logger.info(
67-
'Using service: %s, region: %s, forwarding region: %s, profile: %s',
67+
'Using service: %s, region: %s, metadata: %s, profile: %s',
6868
service,
6969
region,
70-
metadata.get('AWS_REGION'),
70+
metadata,
7171
profile,
7272
)
7373
logger.info('Running in MCP mode')

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
from botocore.awsrequest import AWSRequest
2222
from botocore.credentials import Credentials
2323
from functools import partial
24-
from mcp_proxy_for_aws.hooks import _handle_error_response, _inject_metadata_hook
24+
from mcp_proxy_for_aws.hooks import (
25+
_handle_error_response,
26+
_inject_metadata_hook,
27+
_sign_request_hook,
28+
)
2529
from typing import Any, Dict, Generator, Optional
2630

2731

@@ -102,42 +106,12 @@ def create_aws_session(profile: Optional[str] = None) -> boto3.Session:
102106
return session
103107

104108

105-
def create_sigv4_auth(service: str, region: str, profile: Optional[str] = None) -> SigV4HTTPXAuth:
106-
"""Create SigV4 authentication for AWS requests.
107-
108-
Args:
109-
service: AWS service name for SigV4 signing
110-
profile: AWS profile to use (optional)
111-
region: AWS region (defaults to AWS_REGION env var or us-east-1)
112-
113-
Returns:
114-
SigV4HTTPXAuth instance
115-
116-
Raises:
117-
ValueError: If credentials cannot be obtained
118-
"""
119-
# Create session and get credentials
120-
session = create_aws_session(profile)
121-
credentials = session.get_credentials()
122-
123-
# Create SigV4Auth with explicit credentials
124-
sigv4_auth = SigV4HTTPXAuth(
125-
credentials=credentials,
126-
service=service,
127-
region=region,
128-
)
129-
130-
logger.info("Created SigV4 authentication for service '%s' in region '%s'", service, region)
131-
return sigv4_auth
132-
133-
134109
def create_sigv4_client(
135110
service: str,
136111
region: str,
137112
timeout: Optional[httpx.Timeout] = None,
138113
profile: Optional[str] = None,
139114
headers: Optional[Dict[str, str]] = None,
140-
auth: Optional[httpx.Auth] = None,
141115
metadata: Optional[Dict[str, Any]] = None,
142116
**kwargs: Any,
143117
) -> httpx.AsyncClient:
@@ -149,7 +123,6 @@ def create_sigv4_client(
149123
region: AWS region (optional, defaults to AWS_REGION env var or us-east-1)
150124
timeout: Timeout configuration for the HTTP client
151125
headers: Headers to include in requests
152-
auth: Auth parameter (ignored as we provide our own)
153126
metadata: Metadata to inject into MCP _meta field
154127
**kwargs: Additional arguments to pass to httpx.AsyncClient
155128
@@ -174,17 +147,15 @@ def create_sigv4_client(
174147
'Creating httpx.AsyncClient with custom headers: %s', client_kwargs.get('headers', {})
175148
)
176149

177-
# Create SigV4 auth
178-
sigv4_auth = create_sigv4_auth(service, region, profile)
179-
180-
# Create the client with SigV4 auth and error handling event hook
181-
logger.info("Creating httpx.AsyncClient with SigV4 authentication for service '%s'", service)
150+
logger.info("Creating httpx.AsyncClient with SigV4 request hooks for service '%s'", service)
182151

183152
return httpx.AsyncClient(
184-
auth=sigv4_auth,
185153
**client_kwargs,
186154
event_hooks={
187155
'response': [_handle_error_response],
188-
'request': [partial(_inject_metadata_hook, metadata or {}, region, service)],
156+
'request': [
157+
partial(_inject_metadata_hook, metadata or {}),
158+
partial(_sign_request_hook, region, service, profile),
159+
],
189160
},
190161
)

0 commit comments

Comments
 (0)