@@ -245,14 +245,13 @@ def test_model_trainer_param_validation(test_case, modules_session):
245245 assert trainer .base_job_name == DEFAULT_BASE_NAME
246246
247247
248+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
248249@patch ("sagemaker.train.model_trainer.TrainingJob" )
249- def test_train_with_default_params (mock_training_job , model_trainer ):
250+ def test_train_with_default_params (mock_training_job , mock_trainer_wait , model_trainer ):
250251 model_trainer .train ()
251252
252253 mock_training_job .create .assert_called_once ()
253-
254- training_job_instance = mock_training_job .create .return_value
255- training_job_instance .wait .assert_called_once_with (logs = True )
254+ mock_trainer_wait .assert_called_once ()
256255
257256
258257@pytest .mark .parametrize (
@@ -292,13 +291,15 @@ def test_train_with_default_params(mock_training_job, model_trainer):
292291 },
293292 ],
294293)
294+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
295295@patch ("sagemaker.train.model_trainer.TrainingJob" )
296296@patch ("sagemaker.train.model_trainer.SageMakerConfig" )
297297@patch ("sagemaker.train.model_trainer.ModelTrainer.create_input_data_channel" )
298298def test_train_with_intelligent_defaults (
299299 mock_create_input_data_channel ,
300300 mock_sagemaker_config ,
301301 mock_training_job ,
302+ mock_trainer_wait ,
302303 default_config ,
303304 model_trainer ,
304305):
@@ -314,15 +315,14 @@ def test_train_with_intelligent_defaults(
314315 model_trainer .train ()
315316
316317 mock_training_job .create .assert_called_once ()
317-
318- training_job_instance = mock_training_job .create .return_value
319- training_job_instance .wait .assert_called_once_with (logs = True )
318+ mock_trainer_wait .assert_called_once ()
320319
321320
321+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
322322@patch ("sagemaker.train.model_trainer.TrainingJob" )
323323@patch ("sagemaker.train.model_trainer.SageMakerConfig" )
324324def test_train_with_intelligent_defaults_training_job_space (
325- mock_sagemaker_config , mock_training_job , model_trainer
325+ mock_sagemaker_config , mock_training_job , mock_trainer_wait , model_trainer
326326):
327327 mock_config_instance = MagicMock ()
328328 mock_sagemaker_config .return_value = mock_config_instance
@@ -379,12 +379,13 @@ def test_train_with_intelligent_defaults_training_job_space(
379379 )
380380
381- training_job_instance = mock_training_job .create .return_value
382- training_job_instance . wait . assert_called_once_with ( logs = True )
381+ # The wait step is asserted on the patched trainer_wait helper; the created job mock is not needed here.
382+ mock_trainer_wait . assert_called_once ( )
383383
384384
385+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
385386@patch ("sagemaker.train.model_trainer.TrainingJob" )
386387@patch .object (ModelTrainer , "_get_input_data_config" )
387- def test_train_with_input_data_channels (mock_get_input_config , mock_training_job , model_trainer ):
388+ def test_train_with_input_data_channels (mock_get_input_config , mock_training_job , mock_trainer_wait , model_trainer ):
388389 train_data = InputData (channel_name = "train" , data_source = "train/dir" )
389390 test_data = InputData (channel_name = "test" , data_source = "test/dir" )
390391 mock_input_data_config = [train_data , test_data ]
@@ -517,13 +518,15 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
517518 "mpi" ,
518519 ],
519520)
521+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
520522@patch ("sagemaker.train.model_trainer.TrainingJob" )
521523@patch ("sagemaker.train.model_trainer.TemporaryDirectory" )
522524@patch ("sagemaker.train.model_trainer.SageMakerConfig" )
523525def test_train_with_distributed_config (
524526 mock_sagemaker_config ,
525527 mock_tmp_dir ,
526528 mock_training_job ,
529+ mock_trainer_wait ,
527530 test_case ,
528531 request ,
529532 modules_session ,
@@ -580,16 +583,18 @@ def test_train_with_distributed_config(
580583 assert not os .path .exists (tmp_dir .name )
581584
582585
586+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
583587@patch ("sagemaker.train.model_trainer.TrainingJob" )
584- def test_train_stores_created_training_job (mock_training_job , model_trainer ):
588+ def test_train_stores_created_training_job (mock_training_job , mock_trainer_wait , model_trainer ):
585589 mock_training_job .create .return_value = TrainingJob (training_job_name = "Created-job" )
586590 model_trainer .train (wait = False )
587591 assert model_trainer ._latest_training_job is not None
588592 assert model_trainer ._latest_training_job == TrainingJob (training_job_name = "Created-job" )
589593
590594
595+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
591596@patch ("sagemaker.train.model_trainer.TrainingJob" )
592- def test_tensorboard_output_config (mock_training_job , modules_session ):
597+ def test_tensorboard_output_config (mock_training_job , mock_trainer_wait , modules_session ):
593598 image_uri = DEFAULT_IMAGE
594599 role = DEFAULT_ROLE
595600 tensorboard_output_config = TensorBoardOutputConfig (
@@ -616,8 +621,9 @@ def test_tensorboard_output_config(mock_training_job, modules_session):
616621 )
617622
618623
624+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
619625@patch ("sagemaker.train.model_trainer.TrainingJob" )
620- def test_retry_strategy (mock_training_job , modules_session ):
626+ def test_retry_strategy (mock_training_job , mock_trainer_wait , modules_session ):
621627 image_uri = DEFAULT_IMAGE
622628 role = DEFAULT_ROLE
623629 retry_strategy = RetryStrategy (
@@ -640,8 +646,9 @@ def test_retry_strategy(mock_training_job, modules_session):
640646 assert mock_training_job .create .call_args .kwargs ["retry_strategy" ] == retry_strategy
641647
642648
649+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
643650@patch ("sagemaker.train.model_trainer.TrainingJob" )
644- def test_infra_check_config (mock_training_job , modules_session ):
651+ def test_infra_check_config (mock_training_job , mock_trainer_wait , modules_session ):
645652 image_uri = DEFAULT_IMAGE
646653 role = DEFAULT_ROLE
647654 infra_check_config = InfraCheckConfig (
@@ -664,8 +671,9 @@ def test_infra_check_config(mock_training_job, modules_session):
664671 assert mock_training_job .create .call_args .kwargs ["infra_check_config" ] == infra_check_config
665672
666673
674+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
667675@patch ("sagemaker.train.model_trainer.TrainingJob" )
668- def test_session_chaining_config (mock_training_job , modules_session ):
676+ def test_session_chaining_config (mock_training_job , mock_trainer_wait , modules_session ):
669677 image_uri = DEFAULT_IMAGE
670678 role = DEFAULT_ROLE
671679 session_chaining_config = SessionChainingConfig (
@@ -691,8 +699,9 @@ def test_session_chaining_config(mock_training_job, modules_session):
691699 )
692700
693701
702+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
694703@patch ("sagemaker.train.model_trainer.TrainingJob" )
695- def test_remote_debug_config (mock_training_job , modules_session ):
704+ def test_remote_debug_config (mock_training_job , mock_trainer_wait , modules_session ):
696705 image_uri = DEFAULT_IMAGE
697706 role = DEFAULT_ROLE
698707 remote_debug_config = RemoteDebugConfig (
@@ -717,9 +726,10 @@ def test_remote_debug_config(mock_training_job, modules_session):
717726 )
718727
719728
729+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
720730@patch ("sagemaker.train.model_trainer._get_unique_name" )
721731@patch ("sagemaker.train.model_trainer.TrainingJob" )
722- def test_model_trainer_full_init (mock_training_job , mock_unique_name , modules_session ):
732+ def test_model_trainer_full_init (mock_training_job , mock_unique_name , mock_trainer_wait , modules_session ):
723733 def mock_upload_data (path , bucket , key_prefix ):
724734 return f"s3://{ bucket } /{ key_prefix } "
725735
@@ -1249,9 +1259,10 @@ def test_hyperparameters_invalid(mock_exists, modules_session):
12491259 )
12501260
12511261
1262+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
12521263@patch ("sagemaker.train.model_trainer._get_unique_name" )
12531264@patch ("sagemaker.train.model_trainer.TrainingJob" )
1254- def test_model_trainer_default_paths (mock_training_job , mock_unique_name , modules_session ):
1265+ def test_model_trainer_default_paths (mock_training_job , mock_unique_name , mock_trainer_wait , modules_session ):
12551266 def mock_upload_data (path , bucket , key_prefix ):
12561267 return f"s3://{ bucket } /{ key_prefix } "
12571268
@@ -1287,8 +1298,9 @@ def mock_upload_data(path, bucket, key_prefix):
12871298 assert kwargs ["tensor_board_output_config" ].local_path == "/opt/ml/output/tensorboard"
12881299
12891300
1301+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
12901302@patch ("sagemaker.train.model_trainer.TrainingJob" )
1291- def test_input_merge (mock_training_job , modules_session ):
1303+ def test_input_merge (mock_training_job , mock_trainer_wait , modules_session ):
12921304 model_input = InputData (channel_name = "model" , data_source = "s3://bucket/model/model.tar.gz" )
12931305 model_trainer = ModelTrainer (
12941306 training_image = DEFAULT_IMAGE ,
@@ -1327,8 +1339,9 @@ def test_input_merge(mock_training_job, modules_session):
13271339 ),
13281340 ]
13291341
1342+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
13301343@patch ("sagemaker.train.model_trainer.TrainingJob" )
1331- def test_metric_definitions (mock_training_job , modules_session ):
1344+ def test_metric_definitions (mock_training_job , mock_trainer_wait , modules_session ):
13321345 image_uri = DEFAULT_IMAGE
13331346 role = DEFAULT_ROLE
13341347 metric_definitions = [
@@ -1352,9 +1365,10 @@ def test_metric_definitions(mock_training_job, modules_session):
13521365 )
13531366
13541367
1368+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
13551369@patch ("sagemaker.train.model_trainer._get_unique_name" )
13561370@patch ("sagemaker.core.resources.TrainingJob" )
1357- def test_nova_recipe (mock_training_job , mock_unique_name , modules_session ):
1371+ def test_nova_recipe (mock_training_job , mock_unique_name , mock_trainer_wait , modules_session ):
13581372 def mock_upload_data (path , bucket , key_prefix ):
13591373 if os .path .isfile (path ):
13601374 file_name = os .path .basename (path )
@@ -1442,9 +1456,10 @@ def test_nova_recipe_with_distillation(modules_session):
14421456 os .unlink (recipe .name )
14431457
14441458
1459+ @patch ("sagemaker.train.model_trainer.trainer_wait" )
14451460@patch ("sagemaker.train.model_trainer._get_unique_name" )
14461461@patch ("sagemaker.train.model_trainer.TrainingJob" )
1447- def test_llmft_recipe (mock_training_job , mock_unique_name , modules_session ):
1462+ def test_llmft_recipe (mock_training_job , mock_unique_name , mock_trainer_wait , modules_session ):
14481463 def mock_upload_data (path , bucket , key_prefix ):
14491464 if os .path .isfile (path ):
14501465 file_name = os .path .basename (path )
0 commit comments