2525from sagemaker .train .model_trainer import ModelTrainer
2626from sagemaker .train .configs import SourceCode
2727from sagemaker .core .resources import EndpointConfig
28+ from sagemaker .core .helper .session_helper import Session
2829
2930logger = logging .getLogger (__name__ )
3031
3738AWS_REGION = "us-west-2"
3839PYTORCH_TRAINING_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.13.1-cpu-py39"
3940
41+ sagemaker_session = Session ()
42+
4043
4144@pytest .mark .slow_test
4245def test_train_inference_e2e_build_deploy_invoke_cleanup ():
@@ -143,13 +146,6 @@ def create_schema_builder():
143146
144147def train_model ():
145148 """Train model using ModelTrainer."""
146- from sagemaker .core .helper .session_helper import Session
147- import boto3
148-
149- # Create SageMaker session with AWS region
150- boto_session = boto3 .Session (region_name = AWS_REGION )
151- sagemaker_session = Session (boto_session = boto_session )
152-
153149 training_code_dir = create_pytorch_training_code ()
154150 unique_id = str (uuid .uuid4 ())[:8 ]
155151
@@ -192,9 +188,10 @@ def invoke(self, input_object, model):
192188 inference_spec = SimpleInferenceSpec (),
193189 image_uri = PYTORCH_TRAINING_IMAGE .replace ("training" , "inference" ),
194190 dependencies = {"auto" : False },
191+ sagemaker_session = sagemaker_session ,
195192 )
196193
197- core_model = model_builder .build (model_name = f"{ MODEL_NAME_PREFIX } -{ unique_id } " , region = "us-west-2" )
194+ core_model = model_builder .build (model_name = f"{ MODEL_NAME_PREFIX } -{ unique_id } " )
198195 logger .info (f"Model Successfully Created: { core_model .model_name } " )
199196
200197 core_endpoint = model_builder .deploy (
@@ -221,7 +218,9 @@ def make_prediction(core_endpoint):
221218
222219def cleanup_resources (core_model , core_endpoint ):
223220 """Fully clean up model and endpoint creation - preserving exact logic from manual test"""
224- core_endpoint_config = EndpointConfig .get (endpoint_config_name = core_endpoint .endpoint_name )
221+ core_endpoint_config = EndpointConfig .get (
222+ endpoint_config_name = core_endpoint .endpoint_name ,
223+ )
225224
226225 core_model .delete ()
227226 core_endpoint .delete ()
0 commit comments