Skip to content

Commit 8f8ea38

Browse files
committed
Fix serve integ tests
1 parent 22d30f5 commit 8f8ea38

2 files changed

Lines changed: 20 additions & 15 deletions

File tree

sagemaker-serve/tests/integ/test_train_inference_e2e_integration.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.train.model_trainer import ModelTrainer
2626
from sagemaker.train.configs import SourceCode
2727
from sagemaker.core.resources import EndpointConfig
28+
from sagemaker.core.helper.session_helper import Session
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -37,6 +38,8 @@
3738
AWS_REGION = "us-west-2"
3839
PYTORCH_TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.13.1-cpu-py39"
3940

41+
sagemaker_session = Session()
42+
4043

4144
@pytest.mark.slow_test
4245
def test_train_inference_e2e_build_deploy_invoke_cleanup():
@@ -143,13 +146,6 @@ def create_schema_builder():
143146

144147
def train_model():
145148
"""Train model using ModelTrainer."""
146-
from sagemaker.core.helper.session_helper import Session
147-
import boto3
148-
149-
# Create SageMaker session with AWS region
150-
boto_session = boto3.Session(region_name=AWS_REGION)
151-
sagemaker_session = Session(boto_session=boto_session)
152-
153149
training_code_dir = create_pytorch_training_code()
154150
unique_id = str(uuid.uuid4())[:8]
155151

@@ -192,9 +188,10 @@ def invoke(self, input_object, model):
192188
inference_spec=SimpleInferenceSpec(),
193189
image_uri=PYTORCH_TRAINING_IMAGE.replace("training", "inference"),
194190
dependencies={"auto": False},
191+
sagemaker_session=sagemaker_session,
195192
)
196193

197-
core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}", region="us-west-2")
194+
core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}")
198195
logger.info(f"Model Successfully Created: {core_model.model_name}")
199196

200197
core_endpoint = model_builder.deploy(
@@ -221,7 +218,9 @@ def make_prediction(core_endpoint):
221218

222219
def cleanup_resources(core_model, core_endpoint):
223220
"""Fully clean up model and endpoint creation - preserving exact logic from manual test"""
224-
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)
221+
core_endpoint_config = EndpointConfig.get(
222+
endpoint_config_name=core_endpoint.endpoint_name,
223+
)
225224

226225
core_model.delete()
227226
core_endpoint.delete()

sagemaker-serve/tests/integ/test_triton_integration.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.serve.model_builder import ModelBuilder
2323
from sagemaker.serve.utils.types import ModelServer
2424
from sagemaker.core.resources import EndpointConfig
25+
from sagemaker.core.helper.session_helper import Session
2526

2627
# PyTorch Imports
2728
import torch
@@ -33,6 +34,8 @@
3334
MODEL_NAME_PREFIX = "triton-test-model"
3435
ENDPOINT_NAME_PREFIX = "triton-test-endpoint"
3536

37+
sagemaker_session = Session()
38+
3639

3740
# Create a simple PyTorch model
3841
class SimpleModel(nn.Module):
@@ -96,11 +99,12 @@ def build_and_deploy():
9699
schema_builder = create_schema_builder()
97100

98101
model_builder = ModelBuilder(
99-
model=pytorch_model,
100-
model_path=model_path,
101-
model_server=ModelServer.TRITON,
102-
schema_builder=schema_builder
103-
)
102+
model=pytorch_model,
103+
model_path=model_path,
104+
model_server=ModelServer.TRITON,
105+
schema_builder=schema_builder,
106+
sagemaker_session=sagemaker_session,
107+
)
104108

105109
unique_id = str(uuid.uuid4())[:8]
106110
# Build and deploy your model. Returns SageMaker Core Model and Endpoint objects
@@ -139,7 +143,9 @@ def make_prediction(core_endpoint):
139143

140144
def cleanup_resources(core_model, core_endpoint):
141145
"""Fully clean up model and endpoint creation - preserving exact logic from manual test"""
142-
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)
146+
core_endpoint_config = EndpointConfig.get(
147+
endpoint_config_name=core_endpoint.endpoint_name,
148+
)
143149

144150
core_model.delete()
145151
core_endpoint.delete()

0 commit comments

Comments
 (0)