Skip to content

Commit cbf30e0

Browse files
authored
Make timeout configurable. (#59)
1 parent d5648d2 commit cbf30e0

7 files changed

Lines changed: 138 additions & 27 deletions

File tree

README.md

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,21 @@ docker build -t mcp-proxy-for-aws .
4040

4141
## Configuration Parameters
4242

43-
|Parameter |Description |Default |Required |
44-
|--- |--- |--- |--- |
45-
|`endpoint` |MCP endpoint URL (e.g., `https://your-service.us-east-1.amazonaws.com/mcp`) |N/A |Yes |
46-
|--- |--- |--- |--- |
47-
|`--service` |AWS service name for SigV4 signing |Inferred from endpoint if not provided |No |
48-
|`--profile` |AWS profile for AWS credentials to use |Uses `AWS_PROFILE` environment variable if not set|No |
49-
|`--region` |AWS region to use |Uses `AWS_REGION` environment variable if not set, defaults to `us-east-1` |No |
50-
|`--read-only` |Disable tools which may require write permissions (tools which DO NOT require write permissions are annotated with [`readOnlyHint=true`](https://modelcontextprotocol.io/specification/2025-06-18/schema#toolannotations-readonlyhint))|`False` |No |
51-
| `--retries` |Configures number of retries done when calling upstream services, setting this to 0 disables retries. | 0 |No |
52-
|`--log-level` |Set the logging level (`DEBUG/INFO/WARNING/ERROR/CRITICAL`) |`INFO` |No |
43+
| Parameter | Description | Default |Required |
44+
|----------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------|--- |
45+
| `endpoint` | MCP endpoint URL (e.g., `https://your-service.us-east-1.amazonaws.com/mcp`) | N/A |Yes |
46+
| --- | --- | --- |--- |
47+
| `--service` | AWS service name for SigV4 signing | Inferred from endpoint if not provided |No |
48+
| `--profile` | AWS profile for AWS credentials to use | Uses `AWS_PROFILE` environment variable if not set |No |
49+
| `--region` | AWS region to use | Uses `AWS_REGION` environment variable if not set, defaults to `us-east-1` |No |
50+
| `--read-only` | Disable tools which may require write permissions (tools which DO NOT require write permissions are annotated with [`readOnlyHint=true`](https://modelcontextprotocol.io/specification/2025-06-18/schema#toolannotations-readonlyhint)) | `False` |No |
51+
| `--retries` | Configures number of retries done when calling upstream services, setting this to 0 disables retries. | 0 |No |
52+
| `--log-level` | Set the logging level (`DEBUG/INFO/WARNING/ERROR/CRITICAL`) | `INFO` |No |
53+
| `--timeout` | Set desired timeout in seconds across all operations | 180 |No |
54+
| `--connect-timeout` | Set desired connect timeout in seconds | 60 |No |
55+
| `--read-timeout` | Set desired read timeout in seconds | 120 |No |
56+
| `--write-timeout` | Set desired write timeout in seconds | 180 |No |
57+
5358

5459
## Optional Environment Variables
5560

mcp_proxy_for_aws/server.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import argparse
2626
import asyncio
27+
import httpx
2728
import logging
2829
import os
2930
from fastmcp.server.middleware.error_handling import RetryMiddleware
@@ -36,6 +37,7 @@
3637
create_transport_with_sigv4,
3738
determine_aws_region,
3839
determine_service_name,
40+
within_range,
3941
)
4042
from typing import Any
4143

@@ -62,8 +64,15 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
6264
logger.info('Using service: %s, region: %s, profile: %s', service, region, profile)
6365
logger.info('Running in MCP mode')
6466

67+
timeout = httpx.Timeout(
68+
args.timeout,
69+
connect=args.connect_timeout,
70+
read=args.read_timeout,
71+
write=args.write_timeout,
72+
)
73+
6574
# Create transport with SigV4 authentication
66-
transport = create_transport_with_sigv4(args.endpoint, service, region, profile)
75+
transport = create_transport_with_sigv4(args.endpoint, service, region, timeout, profile)
6776

6877
# Create proxy with the transport
6978
proxy = FastMCP.as_proxy(transport)
@@ -180,6 +189,34 @@ def parse_args():
180189
help='Number of retries when calling endpoint mcp (default: 0) - setting this to 0 disables retries.',
181190
)
182191

192+
parser.add_argument(
193+
'--timeout',
194+
type=within_range(0),
195+
default=180.0,
196+
help='Timeout (seconds) when connecting to endpoint (default: 180)',
197+
)
198+
199+
parser.add_argument(
200+
'--connect-timeout',
201+
type=within_range(0),
202+
default=60.0,
203+
help='Connection timeout (seconds) when connecting to endpoint (default: 60)',
204+
)
205+
206+
parser.add_argument(
207+
'--read-timeout',
208+
type=within_range(0),
209+
default=120.0,
210+
help='Read timeout (seconds) when connecting to endpoint (default: 120)',
211+
)
212+
213+
parser.add_argument(
214+
'--write-timeout',
215+
type=within_range(0),
216+
default=180.0,
217+
help='Write timeout (seconds) when connecting to endpoint (default: 180)',
218+
)
219+
183220
return parser.parse_args()
184221

185222

mcp_proxy_for_aws/sigv4_helper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def create_sigv4_auth(service: str, region: str, profile: Optional[str] = None)
181181
def create_sigv4_client(
182182
service: str,
183183
region: str,
184+
timeout: Optional[httpx.Timeout] = None,
184185
profile: Optional[str] = None,
185186
headers: Optional[Dict[str, str]] = None,
186187
auth: Optional[httpx.Auth] = None,
@@ -192,6 +193,7 @@ def create_sigv4_client(
192193
service: AWS service name for SigV4 signing
193194
profile: AWS profile to use (optional)
194195
region: AWS region (optional, defaults to AWS_REGION env var or us-east-1)
196+
timeout: Timeout configuration for the HTTP client
195197
headers: Headers to include in requests
196198
auth: Auth parameter (ignored as we provide our own)
197199
**kwargs: Additional arguments to pass to httpx.AsyncClient
@@ -202,7 +204,7 @@ def create_sigv4_client(
202204
# Create a copy of kwargs to avoid modifying the passed dict
203205
client_kwargs = {
204206
'follow_redirects': True,
205-
'timeout': httpx.Timeout(180.0, connect=60.0, read=120.0, write=180.0),
207+
'timeout': timeout,
206208
'limits': httpx.Limits(max_keepalive_connections=1, max_connections=5),
207209
**kwargs,
208210
}

mcp_proxy_for_aws/utils.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Utility functions for the MCP Proxy for AWS."""
1616

17+
import argparse
1718
import httpx
1819
import logging
1920
import os
@@ -31,15 +32,18 @@ def create_transport_with_sigv4(
3132
url: str,
3233
service: str,
3334
region: str,
35+
custom_timeout: httpx.Timeout,
3436
profile: Optional[str] = None,
3537
) -> StreamableHttpTransport:
3638
"""Create a StreamableHttpTransport with SigV4 authentication.
3739
3840
Args:
3941
url: The endpoint URL
4042
service: AWS service name for SigV4 signing
43+
region: AWS region to use
44+
custom_timeout: httpx.Timeout used to connect to the endpoint
4145
profile: AWS profile to use (optional)
42-
region: AWS region to use (Optional)
46+
4347
4448
Returns:
4549
StreamableHttpTransport instance with SigV4 authentication
@@ -55,7 +59,7 @@ def client_factory(
5559
profile=profile,
5660
region=region,
5761
headers=headers,
58-
timeout=timeout,
62+
timeout=custom_timeout,
5963
auth=auth,
6064
)
6165

@@ -133,3 +137,35 @@ def determine_aws_region(endpoint: str, region: Optional[str]) -> str:
133137
f"Could not determine AWS region from endpoint '{endpoint}' or from environment variable AWS_REGION. "
134138
'Please provide the region explicitly using --region argument.'
135139
)
140+
141+
142+
def within_range(min_value: float, max_value: Optional[float] = None):
143+
"""Factory function to create range validators.
144+
145+
Args:
146+
min_value: Minimum value
147+
max_value: Maximum value
148+
149+
150+
Returns:
151+
The argparse validator function
152+
153+
Raises:
154+
argparse.ArgumentTypeError: If min and max are not within range
155+
"""
156+
157+
def validator(value):
158+
try:
159+
fvalue = float(value)
160+
except ValueError:
161+
raise argparse.ArgumentTypeError(f"'{value}' is not a valid integer")
162+
163+
if min_value is not None and fvalue < min_value:
164+
raise argparse.ArgumentTypeError(f"'{value}' must be >= {min_value}")
165+
166+
if max_value is not None and fvalue > max_value:
167+
raise argparse.ArgumentTypeError(f"'{value}' must be <= {max_value}")
168+
169+
return fvalue
170+
171+
return validator

tests/unit/test_server.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ async def test_setup_mcp_mode(
5656
mock_args.profile = None
5757
mock_args.read_only = True
5858
mock_args.retries = 1
59+
# Add timeout parameters
60+
mock_args.timeout = 180.0
61+
mock_args.connect_timeout = 60.0
62+
mock_args.read_timeout = 120.0
63+
mock_args.write_timeout = 180.0
64+
mock_args.log_level = 'INFO'
5965

6066
# Mock return values
6167
mock_determine_service.return_value = 'test-service'
@@ -74,9 +80,14 @@ async def test_setup_mcp_mode(
7480
# Assert
7581
mock_determine_service.assert_called_once_with('https://test.example.com', 'test-service')
7682
mock_determine_region.assert_called_once_with('https://test.example.com', 'us-east-1')
77-
mock_create_transport.assert_called_once_with(
78-
'https://test.example.com', 'test-service', 'us-east-1', None
79-
)
83+
# Verify create_transport was called (we check args differently since Timeout object comparison is complex)
84+
assert mock_create_transport.call_count == 1
85+
call_args = mock_create_transport.call_args
86+
assert call_args[0][0] == 'https://test.example.com'
87+
assert call_args[0][1] == 'test-service'
88+
assert call_args[0][2] == 'us-east-1'
89+
# call_args[0][3] is the Timeout object
90+
assert call_args[0][4] is None # profile
8091
mock_as_proxy.assert_called_once_with(mock_transport)
8192
mock_add_filtering.assert_called_once_with(mock_proxy, True)
8293
mock_add_retry.assert_called_once_with(mock_proxy, 1)
@@ -105,6 +116,12 @@ async def test_setup_mcp_mode_no_retries(
105116
mock_args.profile = 'test-profile'
106117
mock_args.read_only = False
107118
mock_args.retries = 0 # No retries
119+
# Add timeout parameters
120+
mock_args.timeout = 180.0
121+
mock_args.connect_timeout = 60.0
122+
mock_args.read_timeout = 120.0
123+
mock_args.write_timeout = 180.0
124+
mock_args.log_level = 'INFO'
108125

109126
# Mock return values
110127
mock_determine_service.return_value = 'test-service'
@@ -123,9 +140,14 @@ async def test_setup_mcp_mode_no_retries(
123140
# Assert
124141
mock_determine_service.assert_called_once_with('https://test.example.com', 'test-service')
125142
mock_determine_region.assert_called_once_with('https://test.example.com', 'us-east-1')
126-
mock_create_transport.assert_called_once_with(
127-
'https://test.example.com', 'test-service', 'us-east-1', 'test-profile'
128-
)
143+
# Verify create_transport was called (we check args differently since Timeout object comparison is complex)
144+
assert mock_create_transport.call_count == 1
145+
call_args = mock_create_transport.call_args
146+
assert call_args[0][0] == 'https://test.example.com'
147+
assert call_args[0][1] == 'test-service'
148+
assert call_args[0][2] == 'us-east-1'
149+
# call_args[0][3] is the Timeout object
150+
assert call_args[0][4] == 'test-profile' # profile
129151
mock_as_proxy.assert_called_once_with(mock_transport)
130152
mock_add_filtering.assert_called_once_with(mock_proxy, False)
131153
mock_proxy.run_async.assert_called_once()

tests/unit/test_utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,18 @@ class TestCreateTransportWithSigv4:
3030
@patch('mcp_proxy_for_aws.utils.create_sigv4_client')
3131
def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
3232
"""Test creating StreamableHttpTransport with SigV4 authentication."""
33+
from httpx import Timeout
34+
3335
mock_client = MagicMock()
3436
mock_create_sigv4_client.return_value = mock_client
3537

3638
url = 'https://test-service.us-west-2.api.aws/mcp'
3739
service = 'test-service'
3840
profile = 'test-profile'
3941
region = 'us-east-1'
42+
custom_timeout = Timeout(30.0)
4043

41-
result = create_transport_with_sigv4(url, service, region, profile)
44+
result = create_transport_with_sigv4(url, service, region, custom_timeout, profile)
4245

4346
# Verify result is StreamableHttpTransport
4447
assert isinstance(result, StreamableHttpTransport)
@@ -47,8 +50,6 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
4750
# Test that the httpx_client_factory calls create_sigv4_client correctly
4851
# We need to access the factory through the transport's internal structure
4952
if hasattr(result, 'httpx_client_factory') and result.httpx_client_factory:
50-
from httpx import Timeout
51-
5253
factory = result.httpx_client_factory
5354
test_kwargs = {'headers': {'test': 'header'}, 'timeout': Timeout(30.0), 'auth': None}
5455
factory(**test_kwargs)
@@ -58,7 +59,7 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
5859
profile=profile,
5960
region=region,
6061
headers={'test': 'header'},
61-
timeout=Timeout(30.0),
62+
timeout=custom_timeout,
6263
auth=None,
6364
)
6465
else:
@@ -68,11 +69,14 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client):
6869
@patch('mcp_proxy_for_aws.utils.create_sigv4_client')
6970
def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client):
7071
"""Test creating transport without profile."""
72+
from httpx import Timeout
73+
7174
url = 'https://test-service.us-west-2.api.aws/mcp'
7275
service = 'test-service'
7376
region = 'test-region'
77+
custom_timeout = Timeout(60.0)
7478

75-
result = create_transport_with_sigv4(url, service, region)
79+
result = create_transport_with_sigv4(url, service, region, custom_timeout)
7680

7781
# Test that the httpx_client_factory calls create_sigv4_client correctly
7882
# We need to access the factory through the transport's internal structure
@@ -81,7 +85,12 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client):
8185
factory(headers=None, timeout=None, auth=None)
8286

8387
mock_create_sigv4_client.assert_called_once_with(
84-
service=service, region=region, profile=None, headers=None, timeout=None, auth=None
88+
service=service,
89+
region=region,
90+
profile=None,
91+
headers=None,
92+
timeout=custom_timeout,
93+
auth=None,
8594
)
8695
else:
8796
# If we can't access the factory directly, just verify the transport was created

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)