Skip to content

Commit 8144733

Browse files
committed
Fix CI failures
1 parent 506d1b1 commit 8144733

4 files changed

Lines changed: 61 additions & 44 deletions

File tree

sagemaker-core/tests/unit/test_processing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,7 @@ def test_run_with_wait(self, mock_session):
10361036
)
10371037

10381038
mock_job = Mock()
1039+
mock_job.processing_job_name = "test-processing-job"
10391040
mock_job.wait = Mock()
10401041

10411042
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f:
@@ -1049,7 +1050,7 @@ def test_run_with_wait(self, mock_session):
10491050
"sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/code.py"
10501051
):
10511052
processor.run(code=temp_file, wait=True, logs=False)
1052-
mock_job.wait.assert_called_once()
1053+
mock_session._wait_for_processing_job.assert_called_once()
10531054
finally:
10541055
if os.path.exists(temp_file):
10551056
os.unlink(temp_file)

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,33 @@
1818
def _refresh_training_job(training_job, sagemaker_session=None):
1919
"""Refresh a training job using the session-aware client if available.
2020
21-
When sagemaker_session is provided, re-fetches the training job via
22-
TrainingJob.get() with the user's boto_session, which avoids the
23-
NoCredentialsError that occurs when refresh() uses the global default client.
21+
When sagemaker_session is provided, uses the session's sagemaker_client
22+
to describe the training job directly, avoiding the global default client.
2423
2524
Args:
2625
training_job (TrainingJob): The training job to refresh.
2726
sagemaker_session: SageMaker session with the correct credentials.
28-
If None, falls back to the default _refresh_training_job(training_job, sagemaker_session).
27+
If None, falls back to training_job.refresh().
2928
"""
3029
if sagemaker_session is not None:
31-
refreshed = TrainingJob.get(
32-
training_job_name=training_job.training_job_name,
33-
session=sagemaker_session.boto_session,
34-
)
35-
# Copy refreshed attributes back to the original object.
36-
# Skip Unassigned values to avoid Pydantic validation errors.
37-
from sagemaker.core.utils.utils import Unassigned
38-
for attr in ("training_job_status", "secondary_status", "failure_reason"):
39-
if hasattr(refreshed, attr):
40-
value = getattr(refreshed, attr)
41-
if isinstance(value, Unassigned):
42-
continue
43-
try:
44-
setattr(training_job, attr, value)
45-
except (AttributeError, TypeError, ValueError):
46-
pass
30+
try:
31+
response = sagemaker_session.sagemaker_client.describe_training_job(
32+
TrainingJobName=training_job.training_job_name
33+
)
34+
# Update key status attributes from the describe response
35+
for api_key, attr_name in (
36+
("TrainingJobStatus", "training_job_status"),
37+
("SecondaryStatus", "secondary_status"),
38+
("FailureReason", "failure_reason"),
39+
):
40+
if api_key in response:
41+
try:
42+
setattr(training_job, attr_name, response[api_key])
43+
except (AttributeError, TypeError, ValueError):
44+
pass
45+
except Exception:
46+
# Fall back to default refresh if session-aware call fails
47+
training_job.refresh()
4748
else:
4849
training_job.refresh()
4950

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
119119
from sagemaker.core.helper.pipeline_variable import StrPipeVar
120120

121+
from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait
121122
from sagemaker.train.local.local_container import _LocalContainer
122123

123124

@@ -790,7 +791,6 @@ def train(
790791
self._latest_training_job = training_job
791792

792793
if wait:
793-
from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait
794794
trainer_wait(
795795
training_job=training_job,
796796
sagemaker_session=self.sagemaker_session,

sagemaker-train/tests/unit/train/test_model_trainer.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,13 @@ def test_model_trainer_param_validation(test_case, modules_session):
245245
assert trainer.base_job_name == DEFAULT_BASE_NAME
246246

247247

248+
@patch("sagemaker.train.model_trainer.trainer_wait")
248249
@patch("sagemaker.train.model_trainer.TrainingJob")
249-
def test_train_with_default_params(mock_training_job, model_trainer):
250+
def test_train_with_default_params(mock_training_job, mock_trainer_wait, model_trainer):
250251
model_trainer.train()
251252

252253
mock_training_job.create.assert_called_once()
253-
254-
training_job_instance = mock_training_job.create.return_value
255-
training_job_instance.wait.assert_called_once_with(logs=True)
254+
mock_trainer_wait.assert_called_once()
256255

257256

258257
@pytest.mark.parametrize(
@@ -292,13 +291,15 @@ def test_train_with_default_params(mock_training_job, model_trainer):
292291
},
293292
],
294293
)
294+
@patch("sagemaker.train.model_trainer.trainer_wait")
295295
@patch("sagemaker.train.model_trainer.TrainingJob")
296296
@patch("sagemaker.train.model_trainer.SageMakerConfig")
297297
@patch("sagemaker.train.model_trainer.ModelTrainer.create_input_data_channel")
298298
def test_train_with_intelligent_defaults(
299299
mock_create_input_data_channel,
300300
mock_sagemaker_config,
301301
mock_training_job,
302+
mock_trainer_wait,
302303
default_config,
303304
model_trainer,
304305
):
@@ -314,15 +315,14 @@ def test_train_with_intelligent_defaults(
314315
model_trainer.train()
315316

316317
mock_training_job.create.assert_called_once()
317-
318-
training_job_instance = mock_training_job.create.return_value
319-
training_job_instance.wait.assert_called_once_with(logs=True)
318+
mock_trainer_wait.assert_called_once()
320319

321320

321+
@patch("sagemaker.train.model_trainer.trainer_wait")
322322
@patch("sagemaker.train.model_trainer.TrainingJob")
323323
@patch("sagemaker.train.model_trainer.SageMakerConfig")
324324
def test_train_with_intelligent_defaults_training_job_space(
325-
mock_sagemaker_config, mock_training_job, model_trainer
325+
mock_sagemaker_config, mock_training_job, mock_trainer_wait, model_trainer
326326
):
327327
mock_config_instance = MagicMock()
328328
mock_sagemaker_config.return_value = mock_config_instance
@@ -379,12 +379,13 @@ def test_train_with_intelligent_defaults_training_job_space(
379379
)
380380

381381
training_job_instance = mock_training_job.create.return_value
382-
training_job_instance.wait.assert_called_once_with(logs=True)
382+
mock_trainer_wait.assert_called_once()
383383

384384

385+
@patch("sagemaker.train.model_trainer.trainer_wait")
385386
@patch("sagemaker.train.model_trainer.TrainingJob")
386387
@patch.object(ModelTrainer, "_get_input_data_config")
387-
def test_train_with_input_data_channels(mock_get_input_config, mock_training_job, model_trainer):
388+
def test_train_with_input_data_channels(mock_get_input_config, mock_training_job, mock_trainer_wait, model_trainer):
388389
train_data = InputData(channel_name="train", data_source="train/dir")
389390
test_data = InputData(channel_name="test", data_source="test/dir")
390391
mock_input_data_config = [train_data, test_data]
@@ -517,13 +518,15 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
517518
"mpi",
518519
],
519520
)
521+
@patch("sagemaker.train.model_trainer.trainer_wait")
520522
@patch("sagemaker.train.model_trainer.TrainingJob")
521523
@patch("sagemaker.train.model_trainer.TemporaryDirectory")
522524
@patch("sagemaker.train.model_trainer.SageMakerConfig")
523525
def test_train_with_distributed_config(
524526
mock_sagemaker_config,
525527
mock_tmp_dir,
526528
mock_training_job,
529+
mock_trainer_wait,
527530
test_case,
528531
request,
529532
modules_session,
@@ -580,16 +583,18 @@ def test_train_with_distributed_config(
580583
assert not os.path.exists(tmp_dir.name)
581584

582585

586+
@patch("sagemaker.train.model_trainer.trainer_wait")
583587
@patch("sagemaker.train.model_trainer.TrainingJob")
584-
def test_train_stores_created_training_job(mock_training_job, model_trainer):
588+
def test_train_stores_created_training_job(mock_training_job, mock_trainer_wait, model_trainer):
585589
mock_training_job.create.return_value = TrainingJob(training_job_name="Created-job")
586590
model_trainer.train(wait=False)
587591
assert model_trainer._latest_training_job is not None
588592
assert model_trainer._latest_training_job == TrainingJob(training_job_name="Created-job")
589593

590594

595+
@patch("sagemaker.train.model_trainer.trainer_wait")
591596
@patch("sagemaker.train.model_trainer.TrainingJob")
592-
def test_tensorboard_output_config(mock_training_job, modules_session):
597+
def test_tensorboard_output_config(mock_training_job, mock_trainer_wait, modules_session):
593598
image_uri = DEFAULT_IMAGE
594599
role = DEFAULT_ROLE
595600
tensorboard_output_config = TensorBoardOutputConfig(
@@ -616,8 +621,9 @@ def test_tensorboard_output_config(mock_training_job, modules_session):
616621
)
617622

618623

624+
@patch("sagemaker.train.model_trainer.trainer_wait")
619625
@patch("sagemaker.train.model_trainer.TrainingJob")
620-
def test_retry_strategy(mock_training_job, modules_session):
626+
def test_retry_strategy(mock_training_job, mock_trainer_wait, modules_session):
621627
image_uri = DEFAULT_IMAGE
622628
role = DEFAULT_ROLE
623629
retry_strategy = RetryStrategy(
@@ -640,8 +646,9 @@ def test_retry_strategy(mock_training_job, modules_session):
640646
assert mock_training_job.create.call_args.kwargs["retry_strategy"] == retry_strategy
641647

642648

649+
@patch("sagemaker.train.model_trainer.trainer_wait")
643650
@patch("sagemaker.train.model_trainer.TrainingJob")
644-
def test_infra_check_config(mock_training_job, modules_session):
651+
def test_infra_check_config(mock_training_job, mock_trainer_wait, modules_session):
645652
image_uri = DEFAULT_IMAGE
646653
role = DEFAULT_ROLE
647654
infra_check_config = InfraCheckConfig(
@@ -664,8 +671,9 @@ def test_infra_check_config(mock_training_job, modules_session):
664671
assert mock_training_job.create.call_args.kwargs["infra_check_config"] == infra_check_config
665672

666673

674+
@patch("sagemaker.train.model_trainer.trainer_wait")
667675
@patch("sagemaker.train.model_trainer.TrainingJob")
668-
def test_session_chaining_config(mock_training_job, modules_session):
676+
def test_session_chaining_config(mock_training_job, mock_trainer_wait, modules_session):
669677
image_uri = DEFAULT_IMAGE
670678
role = DEFAULT_ROLE
671679
session_chaining_config = SessionChainingConfig(
@@ -691,8 +699,9 @@ def test_session_chaining_config(mock_training_job, modules_session):
691699
)
692700

693701

702+
@patch("sagemaker.train.model_trainer.trainer_wait")
694703
@patch("sagemaker.train.model_trainer.TrainingJob")
695-
def test_remote_debug_config(mock_training_job, modules_session):
704+
def test_remote_debug_config(mock_training_job, mock_trainer_wait, modules_session):
696705
image_uri = DEFAULT_IMAGE
697706
role = DEFAULT_ROLE
698707
remote_debug_config = RemoteDebugConfig(
@@ -717,9 +726,10 @@ def test_remote_debug_config(mock_training_job, modules_session):
717726
)
718727

719728

729+
@patch("sagemaker.train.model_trainer.trainer_wait")
720730
@patch("sagemaker.train.model_trainer._get_unique_name")
721731
@patch("sagemaker.train.model_trainer.TrainingJob")
722-
def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session):
732+
def test_model_trainer_full_init(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session):
723733
def mock_upload_data(path, bucket, key_prefix):
724734
return f"s3://{bucket}/{key_prefix}"
725735

@@ -1249,9 +1259,10 @@ def test_hyperparameters_invalid(mock_exists, modules_session):
12491259
)
12501260

12511261

1262+
@patch("sagemaker.train.model_trainer.trainer_wait")
12521263
@patch("sagemaker.train.model_trainer._get_unique_name")
12531264
@patch("sagemaker.train.model_trainer.TrainingJob")
1254-
def test_model_trainer_default_paths(mock_training_job, mock_unique_name, modules_session):
1265+
def test_model_trainer_default_paths(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session):
12551266
def mock_upload_data(path, bucket, key_prefix):
12561267
return f"s3://{bucket}/{key_prefix}"
12571268

@@ -1287,8 +1298,9 @@ def mock_upload_data(path, bucket, key_prefix):
12871298
assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard"
12881299

12891300

1301+
@patch("sagemaker.train.model_trainer.trainer_wait")
12901302
@patch("sagemaker.train.model_trainer.TrainingJob")
1291-
def test_input_merge(mock_training_job, modules_session):
1303+
def test_input_merge(mock_training_job, mock_trainer_wait, modules_session):
12921304
model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz")
12931305
model_trainer = ModelTrainer(
12941306
training_image=DEFAULT_IMAGE,
@@ -1327,8 +1339,9 @@ def test_input_merge(mock_training_job, modules_session):
13271339
),
13281340
]
13291341

1342+
@patch("sagemaker.train.model_trainer.trainer_wait")
13301343
@patch("sagemaker.train.model_trainer.TrainingJob")
1331-
def test_metric_definitions(mock_training_job, modules_session):
1344+
def test_metric_definitions(mock_training_job, mock_trainer_wait, modules_session):
13321345
image_uri = DEFAULT_IMAGE
13331346
role = DEFAULT_ROLE
13341347
metric_definitions = [
@@ -1352,9 +1365,10 @@ def test_metric_definitions(mock_training_job, modules_session):
13521365
)
13531366

13541367

1368+
@patch("sagemaker.train.model_trainer.trainer_wait")
13551369
@patch("sagemaker.train.model_trainer._get_unique_name")
13561370
@patch("sagemaker.core.resources.TrainingJob")
1357-
def test_nova_recipe(mock_training_job, mock_unique_name, modules_session):
1371+
def test_nova_recipe(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session):
13581372
def mock_upload_data(path, bucket, key_prefix):
13591373
if os.path.isfile(path):
13601374
file_name = os.path.basename(path)
@@ -1442,9 +1456,10 @@ def test_nova_recipe_with_distillation(modules_session):
14421456
os.unlink(recipe.name)
14431457

14441458

1459+
@patch("sagemaker.train.model_trainer.trainer_wait")
14451460
@patch("sagemaker.train.model_trainer._get_unique_name")
14461461
@patch("sagemaker.train.model_trainer.TrainingJob")
1447-
def test_llmft_recipe(mock_training_job, mock_unique_name, modules_session):
1462+
def test_llmft_recipe(mock_training_job, mock_unique_name, mock_trainer_wait, modules_session):
14481463
def mock_upload_data(path, bucket, key_prefix):
14491464
if os.path.isfile(path):
14501465
file_name = os.path.basename(path)

0 commit comments

Comments
 (0)