Skip to content

Commit ad9458c

Browse files
authored
Unskip sagemaker serve integration tests (aws#5810)
1 parent 8fd1fe6 commit ad9458c

2 files changed

Lines changed: 75 additions & 99 deletions

File tree

sagemaker-serve/tests/integ/test_model_customization_deployment.py

Lines changed: 60 additions & 95 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,22 @@
1313
"""Integration tests for ModelBuilder model customization deployment."""
1414
from __future__ import absolute_import
1515

16+
import boto3
1617
import pytest
1718
import random
1819

20+
from sagemaker.core.helper.session_helper import Session
21+
22+
# This test relies on resources in a specific region
23+
AWS_REGION = "us-west-2"
24+
25+
26+
@pytest.fixture(scope="module")
27+
def sagemaker_session():
28+
"""Create a SageMaker session with explicit region."""
29+
boto_session = boto3.Session(region_name=AWS_REGION)
30+
return Session(boto_session=boto_session)
31+
1932

2033
@pytest.fixture(scope="module")
2134
def training_job_name():
@@ -48,51 +61,6 @@ def endpoint_name():
4861
return f"e2e-{int(time.time())}-{random.randint(100, 10000)}"
4962

5063

51-
@pytest.fixture(scope="session", autouse=True)
52-
def cleanup_e2e_endpoints():
53-
"""Cleanup e2e endpoints before and after tests."""
54-
import os
55-
from botocore.exceptions import ClientError
56-
57-
# This file's tests use us-west-2 resources. Set SAGEMAKER_REGION so the
58-
# SDK's SageMakerClient creates sessions in the correct region from the start.
59-
# Save/restore to avoid leaking into other test files.
60-
original_sm_region = os.environ.get("SAGEMAKER_REGION")
61-
os.environ["SAGEMAKER_REGION"] = "us-west-2"
62-
63-
from sagemaker.core.resources import Endpoint
64-
65-
# Cleanup before tests
66-
try:
67-
for endpoint in Endpoint.get_all():
68-
try:
69-
if endpoint.endpoint_name.startswith('e2e-'):
70-
endpoint.delete()
71-
except (ClientError, Exception):
72-
pass
73-
except (ClientError, Exception):
74-
pass
75-
76-
yield
77-
78-
# Cleanup after tests
79-
try:
80-
for endpoint in Endpoint.get_all():
81-
try:
82-
if endpoint.endpoint_name.startswith('e2e-'):
83-
endpoint.delete()
84-
except (ClientError, Exception):
85-
pass
86-
except (ClientError, Exception):
87-
pass
88-
89-
# Restore original SAGEMAKER_REGION
90-
if original_sm_region:
91-
os.environ["SAGEMAKER_REGION"] = original_sm_region
92-
elif "SAGEMAKER_REGION" in os.environ:
93-
del os.environ["SAGEMAKER_REGION"]
94-
95-
9664
@pytest.fixture(scope="module")
9765
def cleanup_endpoints():
9866
"""Track endpoints to cleanup after tests."""
@@ -102,7 +70,7 @@ def cleanup_endpoints():
10270
for ep_name in endpoints_to_cleanup:
10371
try:
10472
from sagemaker.core.resources import Endpoint
105-
endpoint = Endpoint.get(endpoint_name=ep_name)
73+
endpoint = Endpoint.get(endpoint_name=ep_name, region=AWS_REGION)
10674
endpoint.delete()
10775
except Exception:
10876
pass
@@ -111,24 +79,23 @@ def cleanup_endpoints():
11179
class TestModelCustomizationFromTrainingJob:
11280
"""Test model customization deployment from TrainingJob."""
11381

114-
def test_build_from_training_job(self, training_job_name):
82+
def test_build_from_training_job(self, training_job_name, sagemaker_session):
11583
"""Test building model from training job."""
11684
from sagemaker.core.resources import TrainingJob
11785
from sagemaker.serve import ModelBuilder
11886
import time
11987

120-
training_job = TrainingJob.get(training_job_name=training_job_name)
121-
model_builder = ModelBuilder(model=training_job)
88+
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
89+
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
12290
model_builder.accept_eula = True
123-
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
91+
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}", region=AWS_REGION)
12492

12593
assert model is not None
12694
assert model.model_arn is not None
12795
assert model_builder.image_uri is not None
12896
assert model_builder.instance_type is not None
12997

130-
@pytest.mark.skip(reason="Skipped: parallel cleanup race condition under investigation")
131-
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints):
98+
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints, sagemaker_session):
13299
"""Test deploying model from training job.
133100
134101
For LORA models, this verifies the two-step deployment:
@@ -138,10 +105,10 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
138105
from sagemaker.serve import ModelBuilder
139106
import time
140107

141-
training_job = TrainingJob.get(training_job_name=training_job_name)
142-
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
108+
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
109+
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge", sagemaker_session=sagemaker_session)
143110
model_builder.accept_eula = True
144-
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
111+
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}", region=AWS_REGION)
145112

146113
peft_type = model_builder._fetch_peft()
147114
adapter_name = f"{endpoint_name}-adapter"
@@ -160,52 +127,52 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
160127
if peft_type == "LORA":
161128
# Verify base IC was created
162129
base_ic_name = f"{endpoint_name}-inference-component"
163-
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
130+
base_ic = InferenceComponent.get(inference_component_name=base_ic_name, region=AWS_REGION)
164131
assert base_ic is not None
165132
assert base_ic.inference_component_status == "InService"
166133

167134
# Verify adapter IC was created
168-
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name)
135+
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name, region=AWS_REGION)
169136
assert adapter_ic is not None
170137

171-
def test_fetch_endpoint_names_for_base_model(self, training_job_name):
138+
def test_fetch_endpoint_names_for_base_model(self, training_job_name, sagemaker_session):
172139
"""Test fetching endpoint names for base model."""
173140
from sagemaker.core.resources import TrainingJob
174141
from sagemaker.serve import ModelBuilder
175142

176-
training_job = TrainingJob.get(training_job_name=training_job_name)
177-
model_builder = ModelBuilder(model=training_job)
143+
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
144+
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
178145
endpoint_names = model_builder.fetch_endpoint_names_for_base_model()
179146

180147
assert isinstance(endpoint_names, set)
181148

182149

183150
class TestModelCustomizationFromModelPackage:
184151

185-
def test_build_from_model_package(self, model_package_arn):
152+
def test_build_from_model_package(self, model_package_arn, sagemaker_session):
186153
"""Test building model from model package."""
187154
from sagemaker.core.resources import ModelPackage
188155
from sagemaker.serve import ModelBuilder
189156

190-
model_package = ModelPackage.get(model_package_name=model_package_arn)
191-
model_builder = ModelBuilder(model=model_package)
157+
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
158+
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
192159
model_builder.accept_eula = True
193-
model = model_builder.build()
160+
model = model_builder.build(region=AWS_REGION)
194161

195162
assert model is not None
196163
assert model.model_arn is not None
197164

198-
def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
165+
def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints, sagemaker_session):
199166
"""Test deploying model from model package."""
200167
from sagemaker.core.resources import ModelPackage
201168
from sagemaker.serve import ModelBuilder
202169
import time
203170

204-
model_package = ModelPackage.get(model_package_name=model_package_arn)
171+
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
205172
endpoint_name = f"e2e-{int(time.time())}-{random.randint(100, 10000)}"
206-
model_builder = ModelBuilder(model=model_package)
173+
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
207174
model_builder.accept_eula = True
208-
model_builder.build()
175+
model_builder.build(region=AWS_REGION)
209176
endpoint = model_builder.deploy(endpoint_name=endpoint_name)
210177

211178
cleanup_endpoints.append(endpoint_name)
@@ -217,15 +184,15 @@ def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
217184
class TestInstanceTypeAutoDetection:
218185
"""Test automatic instance type detection."""
219186

220-
def test_instance_type_from_recipe(self, training_job_name):
187+
def test_instance_type_from_recipe(self, training_job_name, sagemaker_session):
221188
"""Test instance type auto-detection from recipe."""
222189
from sagemaker.core.resources import TrainingJob
223190
from sagemaker.serve import ModelBuilder
224191

225-
training_job = TrainingJob.get(training_job_name=training_job_name)
226-
model_builder = ModelBuilder(model=training_job)
192+
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
193+
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
227194
model_builder.accept_eula = True
228-
model_builder.build()
195+
model_builder.build(region=AWS_REGION)
229196

230197
assert model_builder.instance_type is not None
231198
assert "ml." in model_builder.instance_type
@@ -234,33 +201,33 @@ def test_instance_type_from_recipe(self, training_job_name):
234201
class TestModelCustomizationDetection:
235202
"""Test model customization detection logic."""
236203

237-
def test_is_model_customization_training_job(self, training_job_name):
204+
def test_is_model_customization_training_job(self, training_job_name, sagemaker_session):
238205
"""Test detection from training job."""
239206
from sagemaker.core.resources import TrainingJob
240207
from sagemaker.serve import ModelBuilder
241208

242-
training_job = TrainingJob.get(training_job_name=training_job_name)
243-
model_builder = ModelBuilder(model=training_job)
209+
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
210+
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
244211

245212
assert model_builder._is_model_customization() is True
246213

247-
def test_is_model_customization_model_package(self, model_package_arn):
214+
def test_is_model_customization_model_package(self, model_package_arn, sagemaker_session):
248215
"""Test detection from model package."""
249216
from sagemaker.core.resources import ModelPackage
250217
from sagemaker.serve import ModelBuilder
251218

252-
model_package = ModelPackage.get(model_package_name=model_package_arn)
253-
model_builder = ModelBuilder(model=model_package)
219+
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
220+
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
254221

255222
assert model_builder._is_model_customization() is True
256223

257-
def test_fetch_model_package_arn(self, training_job_name):
224+
def test_fetch_model_package_arn(self, training_job_name, sagemaker_session):
258225
"""Test fetching model package ARN."""
259226
from sagemaker.core.resources import TrainingJob
260227
from sagemaker.serve import ModelBuilder
261228

262-
training_job = TrainingJob.get(training_job_name=training_job_name)
263-
model_builder = ModelBuilder(model=training_job)
229+
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
230+
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
264231

265232
arn = model_builder._fetch_model_package_arn()
266233

@@ -271,14 +238,14 @@ def test_fetch_model_package_arn(self, training_job_name):
271238
class TestTrainerIntegration:
272239
"""Test ModelBuilder integration with SFTTrainer and DPOTrainer."""
273240

274-
def test_sft_trainer_build(self, training_job_name):
241+
def test_sft_trainer_build(self, training_job_name, sagemaker_session):
275242
"""Test building model from SFTTrainer."""
276243
from sagemaker.core.resources import TrainingJob
277244
from sagemaker.train.sft_trainer import SFTTrainer
278245
from sagemaker.serve import ModelBuilder
279246

280247
training_job = TrainingJob.get(
281-
training_job_name=training_job_name
248+
training_job_name=training_job_name, region=AWS_REGION
282249
)
283250

284251
trainer = SFTTrainer(
@@ -289,21 +256,21 @@ def test_sft_trainer_build(self, training_job_name):
289256
)
290257
trainer._latest_training_job = training_job
291258

292-
model_builder = ModelBuilder(model=trainer)
293-
model = model_builder.build()
259+
model_builder = ModelBuilder(model=trainer, sagemaker_session=sagemaker_session)
260+
model = model_builder.build(region=AWS_REGION)
294261

295262
assert model is not None
296263
assert model.model_arn is not None
297264

298-
def test_dpo_trainer_build(self, training_job_name):
265+
def test_dpo_trainer_build(self, training_job_name, sagemaker_session):
299266
"""Test building model from DPOTrainer."""
300267
from sagemaker.core.resources import TrainingJob
301268
from sagemaker.train.dpo_trainer import DPOTrainer
302269
from sagemaker.serve import ModelBuilder
303270
from unittest.mock import patch
304271

305272
training_job = TrainingJob.get(
306-
training_job_name=training_job_name
273+
training_job_name=training_job_name, region=AWS_REGION
307274
)
308275

309276
with patch('sagemaker.train.common_utils.finetune_utils._get_fine_tuning_options_and_model_arn',
@@ -316,8 +283,8 @@ def test_dpo_trainer_build(self, training_job_name):
316283
)
317284
trainer._latest_training_job = training_job
318285

319-
model_builder = ModelBuilder(model=trainer)
320-
model = model_builder.build()
286+
model_builder = ModelBuilder(model=trainer, sagemaker_session=sagemaker_session)
287+
model = model_builder.build(region=AWS_REGION)
321288

322289
assert model is not None
323290
assert model.model_arn is not None
@@ -335,8 +302,6 @@ def test_dpo_trainer_build(self, training_job_name):
335302

336303
import json
337304
import time
338-
import random
339-
import boto3
340305
import pytest
341306
from sagemaker.core.resources import TrainingJob, ModelPackage
342307
from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder
@@ -361,6 +326,7 @@ def training_job(self, setup_config):
361326
"""Get the training job."""
362327
return TrainingJob.get(
363328
training_job_name=setup_config["training_job_name"],
329+
region=setup_config["region"],
364330
)
365331

366332
@pytest.fixture(scope="class")
@@ -432,7 +398,7 @@ def _setup_model_files(self, training_job, s3_client, setup_config):
432398
base_s3_path = training_job.model_artifacts.s3_model_artifacts
433399
elif hasattr(training_job, 'output_model_package_arn'):
434400
# If training job has model package ARN, get artifacts from model package
435-
model_package = ModelPackage.get(training_job.output_model_package_arn)
401+
model_package = ModelPackage.get(training_job.output_model_package_arn, region=AWS_REGION)
436402
if hasattr(model_package,
437403
'inference_specification') and model_package.inference_specification.containers:
438404
container = model_package.inference_specification.containers[0]
@@ -561,8 +527,7 @@ def test_zzz_cleanup_deployed_model(self, bedrock_client):
561527
def test_model_customization_workflow(training_job_name):
562528
"""Standalone test function for pytest discovery.
563529
564-
Relies on SAGEMAKER_REGION being set by the cleanup_e2e_endpoints
565-
session fixture (us-west-2).
530+
Uses explicit region parameter for all SDK calls.
566531
"""
567532
config = {
568533
"training_job_name": training_job_name,
@@ -572,7 +537,7 @@ def test_model_customization_workflow(training_job_name):
572537

573538
try:
574539
s3_client = boto3.client('s3', region_name=config["region"])
575-
training_job = TrainingJob.get(training_job_name=config["training_job_name"])
540+
training_job = TrainingJob.get(training_job_name=config["training_job_name"], region=config["region"])
576541

577542
test_class = TestModelCustomizationDeployment()
578543
test_class.test_training_job_exists(training_job)

0 commit comments

Comments
 (0)