@@ -378,3 +378,132 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
378378
379379 assert trainer .stopping_condition == stopping_condition
380380 assert trainer .stopping_condition .max_runtime_in_seconds == 14400
381+
382+ @patch ('sagemaker.train.common_utils.trainer_wait.wait' )
383+ @patch ('sagemaker.train.dpo_trainer._resolve_model_and_name' )
384+ @patch ('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn' )
385+ @patch ('sagemaker.train.dpo_trainer.TrainDefaults.get_role' )
386+ @patch ('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session' )
387+ @patch ('sagemaker.train.dpo_trainer._get_unique_name' )
388+ @patch ('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group' )
389+ @patch ('sagemaker.train.dpo_trainer._create_input_data_config' )
390+ @patch ('sagemaker.train.dpo_trainer._convert_input_data_to_channels' )
391+ @patch ('sagemaker.train.dpo_trainer._create_output_config' )
392+ @patch ('sagemaker.train.dpo_trainer._create_serverless_config' )
393+ @patch ('sagemaker.train.dpo_trainer._create_mlflow_config' )
394+ @patch ('sagemaker.train.dpo_trainer._create_model_package_config' )
395+ @patch ('sagemaker.core.resources.TrainingJob.create' )
396+ def test_train_passes_wait_timeout (self , mock_training_job_create , mock_model_package_config ,
397+ mock_mlflow_config , mock_serverless_config , mock_output_config ,
398+ mock_convert_channels , mock_input_config , mock_validate_group ,
399+ mock_unique_name , mock_get_sagemaker_session , mock_get_role ,
400+ mock_get_options , mock_resolve_model , mock_wait ):
401+ """Test that wait_timeout is passed to _wait as timeout kwarg."""
402+ mock_validate_group .return_value = "test-group"
403+ mock_resolve_model .return_value = ("test-model" , "test-model" )
404+ mock_get_sagemaker_session .return_value = Mock ()
405+ mock_fine_tuning_options = Mock ()
406+ mock_fine_tuning_options .to_dict .return_value = {"learning_rate" : "0.001" }
407+ mock_get_options .return_value = (mock_fine_tuning_options , "model-arn" , False )
408+ mock_get_role .return_value = "test-role"
409+ mock_unique_name .return_value = "test-job-name"
410+ mock_input_config .return_value = [Mock ()]
411+ mock_convert_channels .return_value = [Mock ()]
412+ mock_output_config .return_value = Mock ()
413+ mock_serverless_config .return_value = Mock ()
414+ mock_mlflow_config .return_value = Mock ()
415+ mock_model_package_config .return_value = Mock ()
416+ mock_training_job = Mock ()
417+ mock_training_job .arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
418+ mock_training_job_create .return_value = mock_training_job
419+
420+ trainer = DPOTrainer (model = "test-model" , model_package_group = "test-group" , training_dataset = "s3://bucket/train" )
421+ trainer .train (wait = True , wait_timeout = 600 )
422+
423+ mock_wait .assert_called_once_with (mock_training_job , timeout = 600 )
424+
425+ @patch ('sagemaker.train.common_utils.trainer_wait.wait' )
426+ @patch ('sagemaker.train.dpo_trainer._resolve_model_and_name' )
427+ @patch ('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn' )
428+ @patch ('sagemaker.train.dpo_trainer.TrainDefaults.get_role' )
429+ @patch ('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session' )
430+ @patch ('sagemaker.train.dpo_trainer._get_unique_name' )
431+ @patch ('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group' )
432+ @patch ('sagemaker.train.dpo_trainer._create_input_data_config' )
433+ @patch ('sagemaker.train.dpo_trainer._convert_input_data_to_channels' )
434+ @patch ('sagemaker.train.dpo_trainer._create_output_config' )
435+ @patch ('sagemaker.train.dpo_trainer._create_serverless_config' )
436+ @patch ('sagemaker.train.dpo_trainer._create_mlflow_config' )
437+ @patch ('sagemaker.train.dpo_trainer._create_model_package_config' )
438+ @patch ('sagemaker.core.resources.TrainingJob.create' )
439+ def test_train_without_wait_timeout_uses_default (self , mock_training_job_create , mock_model_package_config ,
440+ mock_mlflow_config , mock_serverless_config , mock_output_config ,
441+ mock_convert_channels , mock_input_config , mock_validate_group ,
442+ mock_unique_name , mock_get_sagemaker_session , mock_get_role ,
443+ mock_get_options , mock_resolve_model , mock_wait ):
444+ """Test that _wait is called without timeout kwarg when wait_timeout is None."""
445+ mock_validate_group .return_value = "test-group"
446+ mock_resolve_model .return_value = ("test-model" , "test-model" )
447+ mock_get_sagemaker_session .return_value = Mock ()
448+ mock_fine_tuning_options = Mock ()
449+ mock_fine_tuning_options .to_dict .return_value = {"learning_rate" : "0.001" }
450+ mock_get_options .return_value = (mock_fine_tuning_options , "model-arn" , False )
451+ mock_get_role .return_value = "test-role"
452+ mock_unique_name .return_value = "test-job-name"
453+ mock_input_config .return_value = [Mock ()]
454+ mock_convert_channels .return_value = [Mock ()]
455+ mock_output_config .return_value = Mock ()
456+ mock_serverless_config .return_value = Mock ()
457+ mock_mlflow_config .return_value = Mock ()
458+ mock_model_package_config .return_value = Mock ()
459+ mock_training_job = Mock ()
460+ mock_training_job .arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
461+ mock_training_job_create .return_value = mock_training_job
462+
463+ trainer = DPOTrainer (model = "test-model" , model_package_group = "test-group" , training_dataset = "s3://bucket/train" )
464+ trainer .train (wait = True )
465+
466+ mock_wait .assert_called_once_with (mock_training_job )
467+
468+ @patch ('sagemaker.train.common_utils.trainer_wait.wait' )
469+ @patch ('sagemaker.train.dpo_trainer._resolve_model_and_name' )
470+ @patch ('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn' )
471+ @patch ('sagemaker.train.dpo_trainer.TrainDefaults.get_role' )
472+ @patch ('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session' )
473+ @patch ('sagemaker.train.dpo_trainer._get_unique_name' )
474+ @patch ('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group' )
475+ @patch ('sagemaker.train.dpo_trainer._create_input_data_config' )
476+ @patch ('sagemaker.train.dpo_trainer._convert_input_data_to_channels' )
477+ @patch ('sagemaker.train.dpo_trainer._create_output_config' )
478+ @patch ('sagemaker.train.dpo_trainer._create_serverless_config' )
479+ @patch ('sagemaker.train.dpo_trainer._create_mlflow_config' )
480+ @patch ('sagemaker.train.dpo_trainer._create_model_package_config' )
481+ @patch ('sagemaker.core.resources.TrainingJob.create' )
482+ def test_train_wait_false_skips_wait (self , mock_training_job_create , mock_model_package_config ,
483+ mock_mlflow_config , mock_serverless_config , mock_output_config ,
484+ mock_convert_channels , mock_input_config , mock_validate_group ,
485+ mock_unique_name , mock_get_sagemaker_session , mock_get_role ,
486+ mock_get_options , mock_resolve_model , mock_wait ):
487+ """Test that _wait is not called when wait=False."""
488+ mock_validate_group .return_value = "test-group"
489+ mock_resolve_model .return_value = ("test-model" , "test-model" )
490+ mock_get_sagemaker_session .return_value = Mock ()
491+ mock_fine_tuning_options = Mock ()
492+ mock_fine_tuning_options .to_dict .return_value = {"learning_rate" : "0.001" }
493+ mock_get_options .return_value = (mock_fine_tuning_options , "model-arn" , False )
494+ mock_get_role .return_value = "test-role"
495+ mock_unique_name .return_value = "test-job-name"
496+ mock_input_config .return_value = [Mock ()]
497+ mock_convert_channels .return_value = [Mock ()]
498+ mock_output_config .return_value = Mock ()
499+ mock_serverless_config .return_value = Mock ()
500+ mock_mlflow_config .return_value = Mock ()
501+ mock_model_package_config .return_value = Mock ()
502+ mock_training_job = Mock ()
503+ mock_training_job .arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
504+ mock_training_job_create .return_value = mock_training_job
505+
506+ trainer = DPOTrainer (model = "test-model" , model_package_group = "test-group" , training_dataset = "s3://bucket/train" )
507+ trainer .train (wait = False , wait_timeout = 600 )
508+
509+ mock_wait .assert_not_called ()
0 commit comments