Skip to content

Commit d3770cc

Browse files
feat: add stop condition to model customization trainers (#5579)
1 parent 723455f commit d3770cc

File tree

8 files changed

+141
-0
lines changed

8 files changed

+141
-0
lines changed

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sagemaker.core.shapes import VpcConfig
88
from sagemaker.train.defaults import TrainDefaults
99
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
10+
from sagemaker.train.configs import StoppingCondition
1011
from sagemaker.train.common_utils.finetune_utils import (
1112
_get_fine_tuning_options_and_model_arn,
1213
_validate_and_resolve_model_package_group,
@@ -96,6 +97,9 @@ class DPOTrainer(BaseTrainer):
9697
The KMS key ID for encrypting training job outputs.
9798
networking (Optional[VpcConfig]):
9899
The VPC configuration for the training job.
100+
stopping_condition (Optional[StoppingCondition]):
101+
The stopping condition to override training runtime limit.
102+
If not specified, defaults to 1 hour max runtime.
99103
"""
100104
def __init__(
101105
self,
@@ -111,6 +115,7 @@ def __init__(
111115
kms_key_id: Optional[str] = None,
112116
networking: Optional[VpcConfig] = None,
113117
accept_eula: bool = False,
118+
stopping_condition: Optional[StoppingCondition] = None,
114119
**kwargs,
115120
):
116121
super().__init__(**kwargs)
@@ -128,6 +133,7 @@ def __init__(
128133
self.s3_output_path = s3_output_path
129134
self.kms_key_id = kms_key_id
130135
self.networking = networking
136+
self.stopping_condition = stopping_condition
131137

132138
# Initialize fine-tuning options with beta session fallback
133139
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -198,6 +204,10 @@ def train(self,
198204
current_training_job_name = _get_unique_name(
199205
self.base_job_name or f"{self._model_name}-dpo"
200206
)
207+
208+
stopping_condition = TrainDefaults.get_stopping_condition(
209+
stopping_condition=self.stopping_condition
210+
)
201211

202212
logger.info(f"Training Job Name: {current_training_job_name}")
203213
print(f"Training Job Name: {current_training_job_name}")
@@ -251,6 +261,7 @@ def train(self,
251261
hyper_parameters=final_hyperparameters,
252262
model_package_config=model_package_config,
253263
vpc_config=vpc_config,
264+
stopping_condition=stopping_condition,
254265
session=sagemaker_session.boto_session,
255266
region=sagemaker_session.boto_session.region_name,
256267
tags=tags,

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata
1010
from sagemaker.ai_registry.dataset import DataSet
1111
from sagemaker.ai_registry.evaluator import Evaluator
12+
from sagemaker.train.configs import StoppingCondition
1213
from sagemaker.train.common_utils.finetune_utils import (
1314
_get_beta_session,
1415
_get_fine_tuning_options_and_model_arn,
@@ -110,6 +111,9 @@ class RLAIFTrainer(BaseTrainer):
110111
The KMS key ID for encrypting training job outputs.
111112
networking (Optional[VpcConfig]):
112113
The VPC configuration for the training job.
114+
stopping_condition (Optional[StoppingCondition]):
115+
The stopping condition to override training runtime limit.
116+
If not specified, defaults to 1 hour max runtime.
113117
"""
114118

115119
def __init__(
@@ -130,6 +134,7 @@ def __init__(
130134
# vpc config
131135
networking: Optional[VpcConfig] = None,
132136
accept_eula: bool = False,
137+
stopping_condition: Optional[StoppingCondition] = None,
133138
**kwargs,
134139
):
135140
super().__init__(**kwargs)
@@ -150,6 +155,7 @@ def __init__(
150155
self.s3_output_path = s3_output_path
151156
self.kms_key_id = kms_key_id
152157
self.networking = networking
158+
self.stopping_condition = stopping_condition
153159

154160
# Initialize fine-tuning options with beta session fallback
155161
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -215,6 +221,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
215221
current_training_job_name = _get_unique_name(
216222
self.base_job_name or f"{self._model_name}-rlaif"
217223
)
224+
225+
stopping_condition = TrainDefaults.get_stopping_condition(
226+
stopping_condition=self.stopping_condition
227+
)
218228

219229
logger.info(f"Training Job Name: {current_training_job_name}")
220230

@@ -270,6 +280,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
270280
hyper_parameters=final_hyperparameters,
271281
model_package_config=model_package_config,
272282
vpc_config=vpc_config,
283+
stopping_condition=stopping_condition,
273284
session=sagemaker_session.boto_session,
274285
region=sagemaker_session.boto_session.region_name,
275286
tags=tags,

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
99
from sagemaker.ai_registry.dataset import DataSet
1010
from sagemaker.ai_registry.evaluator import Evaluator
11+
from sagemaker.train.configs import StoppingCondition
1112
from sagemaker.train.common_utils.finetune_utils import (
1213
_get_fine_tuning_options_and_model_arn,
1314
_validate_and_resolve_model_package_group,
@@ -102,6 +103,9 @@ class RLVRTrainer(BaseTrainer):
102103
The KMS key ID for encrypting training job outputs.
103104
networking (Optional[VpcConfig]):
104105
The VPC configuration for the training job.
106+
stopping_condition (Optional[StoppingCondition]):
107+
The stopping condition to override training runtime limit.
108+
If not specified, defaults to 1 hour max runtime.
105109
"""
106110

107111
def __init__(
@@ -121,6 +125,7 @@ def __init__(
121125
# vpc config
122126
networking: Optional[VpcConfig] = None,
123127
accept_eula: bool = False,
128+
stopping_condition: Optional[StoppingCondition] = None,
124129
**kwargs,
125130
):
126131
super().__init__(**kwargs)
@@ -140,6 +145,7 @@ def __init__(
140145
self.s3_output_path = s3_output_path
141146
self.kms_key_id = kms_key_id
142147
self.networking = networking
148+
self.stopping_condition = stopping_condition
143149

144150
# Initialize fine-tuning options with beta session fallback
145151
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -202,6 +208,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
202208
current_training_job_name = _get_unique_name(
203209
self.base_job_name or f"{self._model_name}-rlvr"
204210
)
211+
212+
stopping_condition = TrainDefaults.get_stopping_condition(
213+
stopping_condition=self.stopping_condition
214+
)
205215

206216
logger.info(f"Training Job Name: {current_training_job_name}")
207217

@@ -258,6 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
258268
hyper_parameters=final_hyperparameters,
259269
model_package_config=model_package_config,
260270
vpc_config=vpc_config,
271+
stopping_condition=stopping_condition,
261272
session=sagemaker_session.boto_session,
262273
region=sagemaker_session.boto_session.region_name,
263274
tags=tags,

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sagemaker.train.defaults import TrainDefaults
99
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
1010
from sagemaker.ai_registry.dataset import DataSet
11+
from sagemaker.train.configs import StoppingCondition
1112
from sagemaker.train.common_utils.finetune_utils import (
1213
_get_fine_tuning_options_and_model_arn,
1314
_validate_and_resolve_model_package_group,
@@ -98,6 +99,9 @@ class SFTTrainer(BaseTrainer):
9899
The KMS key ID for encrypting training job outputs.
99100
networking (Optional[VpcConfig]):
100101
The VPC configuration for the training job.
102+
stopping_condition (Optional[StoppingCondition]):
103+
The stopping condition to override training runtime limit.
104+
If not specified, defaults to 1 hour max runtime.
101105
"""
102106

103107
def __init__(
@@ -114,6 +118,7 @@ def __init__(
114118
kms_key_id: Optional[str] = None,
115119
networking: Optional[VpcConfig] = None,
116120
accept_eula: Optional[bool] = False,
121+
stopping_condition: Optional[StoppingCondition] = None,
117122
**kwargs,
118123
):
119124
super().__init__(**kwargs)
@@ -132,6 +137,7 @@ def __init__(
132137
self.s3_output_path = s3_output_path
133138
self.kms_key_id = kms_key_id
134139
self.networking = networking
140+
self.stopping_condition = stopping_condition
135141

136142
# Initialize fine-tuning options with beta session fallback
137143
self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(self._model_name,
@@ -199,6 +205,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
199205
current_training_job_name = _get_unique_name(
200206
self.base_job_name or f"{self._model_name}-sft"
201207
)
208+
209+
stopping_condition = TrainDefaults.get_stopping_condition(
210+
stopping_condition=self.stopping_condition
211+
)
202212

203213
logger.info(f"Training Job Name: {current_training_job_name}")
204214

@@ -252,6 +262,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
252262
hyper_parameters=final_hyperparameters,
253263
model_package_config=model_package_config,
254264
vpc_config=vpc_config,
265+
stopping_condition=stopping_condition,
255266
session=sagemaker_session.boto_session,
256267
region=sagemaker_session.boto_session.region_name,
257268
tags=tags,

sagemaker-train/tests/unit/train/test_dpo_trainer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,24 @@ def test_process_hyperparameters_with_none_hyperparameters(self):
357357

358358
# Should not raise an exception
359359
trainer._process_hyperparameters()
360+
361+
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
362+
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
363+
def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
364+
"""Test DPOTrainer accepts stopping_condition parameter."""
365+
from sagemaker.train.configs import StoppingCondition
366+
367+
mock_validate.return_value = "test-group"
368+
mock_hyperparams = Mock()
369+
mock_hyperparams.to_dict.return_value = {}
370+
mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
371+
372+
stopping_condition = StoppingCondition(max_runtime_in_seconds=14400)
373+
trainer = DPOTrainer(
374+
model="test-model",
375+
model_package_group="test-group",
376+
stopping_condition=stopping_condition
377+
)
378+
379+
assert trainer.stopping_condition == stopping_condition
380+
assert trainer.stopping_condition.max_runtime_in_seconds == 14400

sagemaker-train/tests/unit/train/test_rlaif_trainer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,3 +532,25 @@ def test_validate_reward_model_id_none_model(self):
532532

533533
result = trainer._validate_reward_model_id(None)
534534
assert result is None
535+
536+
@patch('sagemaker.train.rlaif_trainer._validate_and_resolve_model_package_group')
537+
@patch('sagemaker.train.rlaif_trainer._get_fine_tuning_options_and_model_arn')
538+
def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
539+
"""Test RLAIFTrainer accepts stopping_condition parameter."""
540+
from sagemaker.train.configs import StoppingCondition
541+
542+
mock_validate.return_value = "test-group"
543+
mock_hyperparams = Mock()
544+
mock_hyperparams.to_dict.return_value = {}
545+
mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
546+
547+
stopping_condition = StoppingCondition(max_runtime_in_seconds=86400)
548+
trainer = RLAIFTrainer(
549+
model="test-model",
550+
model_package_group="test-group",
551+
reward_model_id="openai.gpt-oss-120b-1:0",
552+
stopping_condition=stopping_condition
553+
)
554+
555+
assert trainer.stopping_condition == stopping_condition
556+
assert trainer.stopping_condition.max_runtime_in_seconds == 86400

sagemaker-train/tests/unit/train/test_rlvr_trainer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,24 @@ def test_process_hyperparameters_with_none_hyperparameters(self):
360360

361361
# Should not raise an exception
362362
trainer._process_hyperparameters()
363+
364+
@patch('sagemaker.train.rlvr_trainer._validate_and_resolve_model_package_group')
365+
@patch('sagemaker.train.rlvr_trainer._get_fine_tuning_options_and_model_arn')
366+
def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
367+
"""Test RLVRTrainer accepts stopping_condition parameter."""
368+
from sagemaker.train.configs import StoppingCondition
369+
370+
mock_validate.return_value = "test-group"
371+
mock_hyperparams = Mock()
372+
mock_hyperparams.to_dict.return_value = {}
373+
mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
374+
375+
stopping_condition = StoppingCondition(max_runtime_in_seconds=259200)
376+
trainer = RLVRTrainer(
377+
model="test-model",
378+
model_package_group="test-group",
379+
stopping_condition=stopping_condition
380+
)
381+
382+
assert trainer.stopping_condition == stopping_condition
383+
assert trainer.stopping_condition.max_runtime_in_seconds == 259200

sagemaker-train/tests/unit/train/test_sft_trainer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,36 @@ def test_process_hyperparameters_with_none_hyperparameters(self):
359359

360360
# Should not raise an exception
361361
trainer._process_hyperparameters()
362+
363+
@patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
364+
@patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
365+
def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
366+
"""Test SFTTrainer accepts stopping_condition parameter."""
367+
from sagemaker.train.configs import StoppingCondition
368+
369+
mock_validate.return_value = "test-group"
370+
mock_hyperparams = Mock()
371+
mock_hyperparams.to_dict.return_value = {}
372+
mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
373+
374+
stopping_condition = StoppingCondition(max_runtime_in_seconds=7200)
375+
trainer = SFTTrainer(
376+
model="test-model",
377+
model_package_group="test-group",
378+
stopping_condition=stopping_condition
379+
)
380+
381+
assert trainer.stopping_condition == stopping_condition
382+
assert trainer.stopping_condition.max_runtime_in_seconds == 7200
383+
384+
@patch('sagemaker.train.sft_trainer._validate_and_resolve_model_package_group')
385+
@patch('sagemaker.train.sft_trainer._get_fine_tuning_options_and_model_arn')
386+
def test_default_stopping_condition_is_none(self, mock_finetuning, mock_validate):
387+
"""Test SFTTrainer defaults stopping_condition to None."""
388+
mock_validate.return_value = "test-group"
389+
mock_hyperparams = Mock()
390+
mock_hyperparams.to_dict.return_value = {}
391+
mock_finetuning.return_value = (mock_hyperparams, "model-arn", False)
392+
393+
trainer = SFTTrainer(model="test-model", model_package_group="test-group")
394+
assert trainer.stopping_condition is None

Comments (0) — no commit comments.