|
16 | 16 | import uuid |
17 | 17 | import pytest |
18 | 18 | import logging |
| 19 | +import time |
| 20 | +from botocore.exceptions import ClientError |
19 | 21 |
|
20 | 22 | from sagemaker.serve.model_builder import ModelBuilder |
21 | 23 | from sagemaker.serve.utils.types import ModelServer |
@@ -116,11 +118,26 @@ def make_prediction(core_endpoint): |
116 | 118 | "inputs": "What are falcons?", |
117 | 119 | "parameters": {"max_new_tokens": 32}, |
118 | 120 | } |
119 | | - |
120 | | - result = core_endpoint.invoke( |
121 | | - body=json.dumps(test_data), |
122 | | - content_type="application/json" |
123 | | - ) |
| 121 | + |
| 122 | + # Retry logic to handle endpoint propagation delay in CodeBuild |
| 123 | + max_retries = 5 |
| 124 | + for attempt in range(max_retries): |
| 125 | + try: |
| 126 | + result = core_endpoint.invoke( |
| 127 | + body=json.dumps(test_data), |
| 128 | + content_type="application/json" |
| 129 | + ) |
| 130 | + break |
| 131 | + except ClientError as e: |
| 132 | + if e.response['Error']['Code'] == 'ValidationException' and 'not found' in str(e): |
| 133 | + if attempt < max_retries - 1: |
| 134 | + wait_time = 2 ** attempt # Exponential backoff: 1, 2, 4, 8 seconds |
| 135 | + logger.warning(f"Endpoint not found, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})") |
| 136 | + time.sleep(wait_time) |
| 137 | + else: |
| 138 | + raise |
| 139 | + else: |
| 140 | + raise |
124 | 141 |
|
125 | 142 | # Decode the output of the invocation and print the result |
126 | 143 | prediction = json.loads(result.body.read().decode('utf-8')) |
|
0 commit comments