From eec28f2b55fe96ceeb1fe4de73d1fd0e3b1a8134 Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 25 Feb 2026 14:38:42 -0800 Subject: [PATCH] feat: add stop condition to model customization trainers --- .../src/sagemaker/train/dpo_trainer.py | 11 +++++++ .../src/sagemaker/train/rlaif_trainer.py | 11 +++++++ .../src/sagemaker/train/rlvr_trainer.py | 11 +++++++ .../src/sagemaker/train/sft_trainer.py | 11 +++++++ .../tests/unit/train/test_dpo_trainer.py | 21 ++++++++++++ .../tests/unit/train/test_rlaif_trainer.py | 22 +++++++++++++ .../tests/unit/train/test_rlvr_trainer.py | 21 ++++++++++++ .../tests/unit/train/test_sft_trainer.py | 33 +++++++++++++++++++ 8 files changed, 141 insertions(+) diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 690bf30e48..1cbae102e1 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -7,6 +7,7 @@ from sagemaker.core.shapes import VpcConfig from sagemaker.train.defaults import TrainDefaults from sagemaker.train.utils import _get_unique_name, _get_studio_tags +from sagemaker.train.configs import StoppingCondition from sagemaker.train.common_utils.finetune_utils import ( _get_fine_tuning_options_and_model_arn, _validate_and_resolve_model_package_group, @@ -96,6 +97,9 @@ class DPOTrainer(BaseTrainer): The KMS key ID for encrypting training job outputs. networking (Optional[VpcConfig]): The VPC configuration for the training job. + stopping_condition (Optional[StoppingCondition]): + The stopping condition to override training runtime limit. + If not specified, defaults to 1 hour max runtime. """ def __init__( self, @@ -111,6 +115,7 @@ def __init__( kms_key_id: Optional[str] = None, networking: Optional[VpcConfig] = None, accept_eula: bool = False, + stopping_condition: Optional[StoppingCondition] = None, **kwargs, ): super().__init__(**kwargs) @@ -128,6 +133,7 @@ def __init__( self.s3_output_path = s3_output_path self.kms_key_id = kms_key_id self.networking = networking + self.stopping_condition = stopping_condition # Initialize fine-tuning options with beta session fallback self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, @@ -198,6 +204,10 @@ def train(self, current_training_job_name = _get_unique_name( self.base_job_name or f"{self._model_name}-dpo" ) + + stopping_condition = TrainDefaults.get_stopping_condition( + stopping_condition=self.stopping_condition + ) logger.info(f"Training Job Name: {current_training_job_name}") print(f"Training Job Name: {current_training_job_name}") @@ -251,6 +261,7 @@ def train(self, hyper_parameters=final_hyperparameters, model_package_config=model_package_config, vpc_config=vpc_config, + stopping_condition=stopping_condition, session=sagemaker_session.boto_session, region=sagemaker_session.boto_session.region_name, tags=tags, diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index ebadc6bfda..03dbe2fed6 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -9,6 +9,7 @@ from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata from sagemaker.ai_registry.dataset import DataSet from sagemaker.ai_registry.evaluator import Evaluator +from sagemaker.train.configs import StoppingCondition from sagemaker.train.common_utils.finetune_utils import ( _get_beta_session, _get_fine_tuning_options_and_model_arn, @@ -110,6 +111,9 @@ class RLAIFTrainer(BaseTrainer): The KMS key ID for encrypting training job outputs. networking (Optional[VpcConfig]): The VPC configuration for the training job. + stopping_condition (Optional[StoppingCondition]): + The stopping condition to override training runtime limit. + If not specified, defaults to 1 hour max runtime. """ def __init__( @@ -130,6 +134,7 @@ def __init__( # vpc config networking: Optional[VpcConfig] = None, accept_eula: bool = False, + stopping_condition: Optional[StoppingCondition] = None, **kwargs, ): super().__init__(**kwargs) @@ -150,6 +155,7 @@ def __init__( self.s3_output_path = s3_output_path self.kms_key_id = kms_key_id self.networking = networking + self.stopping_condition = stopping_condition # Initialize fine-tuning options with beta session fallback self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, @@ -215,6 +221,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati current_training_job_name = _get_unique_name( self.base_job_name or f"{self._model_name}-rlaif" ) + + stopping_condition = TrainDefaults.get_stopping_condition( + stopping_condition=self.stopping_condition + ) logger.info(f"Training Job Name: {current_training_job_name}") @@ -270,6 +280,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati hyper_parameters=final_hyperparameters, model_package_config=model_package_config, vpc_config=vpc_config, + stopping_condition=stopping_condition, session=sagemaker_session.boto_session, region=sagemaker_session.boto_session.region_name, tags=tags, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index b28c9d865c..5fae8ed316 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -8,6 +8,7 @@ from sagemaker.train.utils import _get_unique_name, _get_studio_tags from sagemaker.ai_registry.dataset import DataSet from sagemaker.ai_registry.evaluator import Evaluator +from sagemaker.train.configs import StoppingCondition from sagemaker.train.common_utils.finetune_utils import ( _get_fine_tuning_options_and_model_arn, _validate_and_resolve_model_package_group, @@ -102,6 +103,9 @@ class RLVRTrainer(BaseTrainer): The KMS key ID for encrypting training job outputs. networking (Optional[VpcConfig]): The VPC configuration for the training job. + stopping_condition (Optional[StoppingCondition]): + The stopping condition to override training runtime limit. + If not specified, defaults to 1 hour max runtime. """ def __init__( @@ -121,6 +125,7 @@ def __init__( # vpc config networking: Optional[VpcConfig] = None, accept_eula: bool = False, + stopping_condition: Optional[StoppingCondition] = None, **kwargs, ): super().__init__(**kwargs) @@ -140,6 +145,7 @@ def __init__( self.s3_output_path = s3_output_path self.kms_key_id = kms_key_id self.networking = networking + self.stopping_condition = stopping_condition # Initialize fine-tuning options with beta session fallback self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, @@ -202,6 +208,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, current_training_job_name = _get_unique_name( self.base_job_name or f"{self._model_name}-rlvr" ) + + stopping_condition = TrainDefaults.get_stopping_condition( + stopping_condition=self.stopping_condition + ) logger.info(f"Training Job Name: {current_training_job_name}") @@ -258,6 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, hyper_parameters=final_hyperparameters, model_package_config=model_package_config, vpc_config=vpc_config, + stopping_condition=stopping_condition, session=sagemaker_session.boto_session, region=sagemaker_session.boto_session.region_name, tags=tags, diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index b2688dce5d..e3fedf6d7f 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -8,6 +8,7 @@ from sagemaker.train.defaults import TrainDefaults from sagemaker.train.utils import _get_unique_name, _get_studio_tags from sagemaker.ai_registry.dataset import DataSet +from sagemaker.train.configs import StoppingCondition from sagemaker.train.common_utils.finetune_utils import ( _get_fine_tuning_options_and_model_arn, _validate_and_resolve_model_package_group, @@ -98,6 +99,9 @@ class SFTTrainer(BaseTrainer): The KMS key ID for encrypting training job outputs. networking (Optional[VpcConfig]): The VPC configuration for the training job. + stopping_condition (Optional[StoppingCondition]): + The stopping condition to override training runtime limit. + If not specified, defaults to 1 hour max runtime. """ def __init__( @@ -114,6 +118,7 @@ def __init__( kms_key_id: Optional[str] = None, networking: Optional[VpcConfig] = None, accept_eula: Optional[bool] = False, + stopping_condition: Optional[StoppingCondition] = None, **kwargs, ): super().__init__(**kwargs) @@ -132,6 +137,7 @@ def __init__( self.s3_output_path = s3_output_path self.kms_key_id = kms_key_id self.networking = networking + self.stopping_condition = stopping_condition # Initialize fine-tuning options with beta session fallback self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name, @@ -199,6 +205,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati current_training_job_name = _get_unique_name( self.base_job_name or f"{self._model_name}-sft" ) + + stopping_condition = TrainDefaults.get_stopping_condition( + stopping_condition=self.stopping_condition + ) logger.info(f"Training Job Name: {current_training_job_name}") @@ -252,6 +262,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati hyper_parameters=final_hyperparameters, model_package_config=model_package_config, vpc_config=vpc_config, + stopping_condition=stopping_condition, session=sagemaker_session.boto_session, region=sagemaker_session.boto_session.region_name, tags=tags, diff --git a/sagemaker-train/tests/unit/train/test_dpo_trainer.py b/sagemaker-train/tests/unit/train/test_dpo_trainer.py index 4f67221029..93a4b18fa9 100644 --- a/sagemaker-train/tests/unit/train/test_dpo_trainer.py +++ b/sagemaker-train/tests/unit/train/test_dpo_trainer.py @@ -357,3 +357,24 @@ def test_process_hyperparameters_with_none_hyperparameters(self): # Should not raise an exception trainer._process_hyperparameters() + + @patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn') + def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): + """Test DPOTrainer accepts stopping_condition parameter.""" + from sagemaker.train.configs import StoppingCondition + + mock_validate.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning.return_value = (mock_hyperparams, "model-arn", False) + + stopping_condition = StoppingCondition(max_runtime_in_seconds=14400) + trainer = DPOTrainer( + model="test-model", + model_package_group="test-group", + stopping_condition=stopping_condition + ) + + assert trainer.stopping_condition == stopping_condition + assert trainer.stopping_condition.max_runtime_in_seconds == 14400 diff --git a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py index 4c45e21ba1..be8b9b96b6 100644 --- a/sagemaker-train/tests/unit/train/test_rlaif_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlaif_trainer.py @@ -532,3 +532,25 @@ def test_validate_reward_model_id_none_model(self): result = trainer._validate_reward_model_id(None) assert result is None + + @patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn') + def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): + """Test RLAIFTrainer accepts stopping_condition parameter.""" + from sagemaker.train.configs import StoppingCondition + + mock_validate.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning.return_value = (mock_hyperparams, "model-arn", False) + + stopping_condition = StoppingCondition(max_runtime_in_seconds=86400) + trainer = RLAIFTrainer( + model="test-model", + model_package_group="test-group", + reward_model_id="openai.gpt-oss-120b-1:0", + stopping_condition=stopping_condition + ) + + assert trainer.stopping_condition == stopping_condition + assert trainer.stopping_condition.max_runtime_in_seconds == 86400 diff --git a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py index c68cd1c94d..4ee785285e 100644 --- a/sagemaker-train/tests/unit/train/test_rlvr_trainer.py +++ b/sagemaker-train/tests/unit/train/test_rlvr_trainer.py @@ -360,3 +360,24 @@ def test_process_hyperparameters_with_none_hyperparameters(self): # Should not raise an exception trainer._process_hyperparameters() + + @patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn') + def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): + """Test RLVRTrainer accepts stopping_condition parameter.""" + from sagemaker.train.configs import StoppingCondition + + mock_validate.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning.return_value = (mock_hyperparams, "model-arn", False) + + stopping_condition = StoppingCondition(max_runtime_in_seconds=259200) + trainer = RLVRTrainer( + model="test-model", + model_package_group="test-group", + stopping_condition=stopping_condition + ) + + assert trainer.stopping_condition == stopping_condition + assert trainer.stopping_condition.max_runtime_in_seconds == 259200 diff --git a/sagemaker-train/tests/unit/train/test_sft_trainer.py b/sagemaker-train/tests/unit/train/test_sft_trainer.py index 38042594d4..6af829e1a7 100644 --- a/sagemaker-train/tests/unit/train/test_sft_trainer.py +++ b/sagemaker-train/tests/unit/train/test_sft_trainer.py @@ -359,3 +359,36 @@ def test_process_hyperparameters_with_none_hyperparameters(self): # Should not raise an exception trainer._process_hyperparameters() + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_accepts_stopping_condition(self, mock_finetuning, mock_validate): + """Test SFTTrainer accepts stopping_condition parameter.""" + from sagemaker.train.configs import StoppingCondition + + mock_validate.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning.return_value = (mock_hyperparams, "model-arn", False) + + stopping_condition = StoppingCondition(max_runtime_in_seconds=7200) + trainer = SFTTrainer( + model="test-model", + model_package_group="test-group", + stopping_condition=stopping_condition + ) + + assert trainer.stopping_condition == stopping_condition + assert trainer.stopping_condition.max_runtime_in_seconds == 7200 + + @patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group') + @patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn') + def test_default_stopping_condition_is_none(self, mock_finetuning, mock_validate): + """Test SFTTrainer defaults stopping_condition to None.""" + mock_validate.return_value = "test-group" + mock_hyperparams = Mock() + mock_hyperparams.to_dict.return_value = {} + mock_finetuning.return_value = (mock_hyperparams, "model-arn", False) + + trainer = SFTTrainer(model="test-model", model_package_group="test-group") + assert trainer.stopping_condition is None