Skip to content

Commit 23dc08c

Browse files
gauravmadarkalGaurav Madarkal
andauthored
feat(train): Add wait_timeout parameter to train() (#5786)
* feat(train): Add wait_timeout parameter to train() Updated trainers: SFT, DPO, RLAIF, RLVR, and BaseTrainer. * feat(train): added unit tests for wait_timeout --------- Co-authored-by: Gaurav Madarkal <mdgaurav@amazon.com>
1 parent d4f392f commit 23dc08c

9 files changed

Lines changed: 550 additions & 9 deletions

File tree

sagemaker-train/src/sagemaker/train/base_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,6 @@ def _is_nova_model_for_telemetry(self) -> bool:
7676
return _is_nova_model(model_name) if model_name else False
7777

7878
@abstractmethod
79-
def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True):
79+
def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True, wait_timeout: Optional[int] = None):
8080
"""Common training method that calls the specific implementation."""
8181
pass

sagemaker-train/src/sagemaker/train/dpo_trainer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def _process_hyperparameters(self):
180180
def train(self,
181181
training_dataset: Optional[Union[str, DataSet]] = None,
182182
validation_dataset: Optional[Union[str, DataSet]] = None,
183-
wait: bool = True):
183+
wait: bool = True,
184+
wait_timeout: Optional[int] = None):
184185
"""Execute the DPO training job.
185186
186187
Parameters:
@@ -192,6 +193,9 @@ def train(self,
192193
Can be an S3 URI, dataset ARN, or DataSet object.
193194
wait (bool):
194195
Whether to wait for the training job to complete. Defaults to True.
196+
wait_timeout (Optional[int]):
197+
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
198+
If None, uses the default timeout from the wait utility.
195199
196200
Returns:
197201
TrainingJob: The SageMaker training job object.
@@ -276,7 +280,10 @@ def train(self,
276280
from sagemaker.train.common_utils.trainer_wait import wait as _wait
277281
from sagemaker.core.utils.exceptions import TimeoutExceededError
278282
try :
279-
_wait(training_job)
283+
wait_kwargs = {}
284+
if wait_timeout is not None:
285+
wait_kwargs['timeout'] = wait_timeout
286+
_wait(training_job, **wait_kwargs)
280287
except TimeoutExceededError as e:
281288
logger.error("Error: %s", e)
282289

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _validate_reward_model_id(self, reward_model_id):
197197

198198

199199
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
200-
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
200+
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None):
201201
"""Execute the RLAIF training job.
202202
203203
Parameters:
@@ -209,6 +209,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
209209
Can be an S3 URI, dataset ARN, or DataSet object.
210210
wait (bool):
211211
Whether to wait for the training job to complete. Defaults to True.
212+
wait_timeout (Optional[int]):
213+
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
214+
If None, uses the default timeout from the wait utility.
212215
213216
Returns:
214217
TrainingJob: The SageMaker training job object.
@@ -295,7 +298,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
295298
from sagemaker.train.common_utils.trainer_wait import wait as _wait
296299
from sagemaker.core.utils.exceptions import TimeoutExceededError
297300
try :
298-
_wait(training_job)
301+
wait_kwargs = {}
302+
if wait_timeout is not None:
303+
wait_kwargs['timeout'] = wait_timeout
304+
_wait(training_job, **wait_kwargs)
299305
except TimeoutExceededError as e:
300306
logger.error("Error: %s", e)
301307

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _process_hyperparameters(self):
183183

184184
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train")
185185
def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
186-
validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
186+
validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None):
187187
"""Execute the RLVR training job.
188188
189189
Parameters:
@@ -195,6 +195,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
195195
Can be an S3 URI, dataset ARN, or DataSet object.
196196
wait (bool):
197197
Whether to wait for the training job to complete. Defaults to True.
198+
wait_timeout (Optional[int]):
199+
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
200+
If None, uses the default timeout from the wait utility.
198201
199202
Returns:
200203
TrainingJob: The SageMaker training job object.
@@ -283,7 +286,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
283286
from sagemaker.train.common_utils.trainer_wait import wait as _wait
284287
from sagemaker.core.utils.exceptions import TimeoutExceededError
285288
try:
286-
_wait(training_job)
289+
wait_kwargs = {}
290+
if wait_timeout is not None:
291+
wait_kwargs['timeout'] = wait_timeout
292+
_wait(training_job, **wait_kwargs)
287293
except TimeoutExceededError as e:
288294
logger.error("Error: %s", e)
289295

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _process_hyperparameters(self):
180180
self.hyperparameters._specs.pop('validation_data_path', None)
181181

182182
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train")
183-
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
183+
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None):
184184
"""Execute the SFT training job.
185185
186186
Parameters:
@@ -192,6 +192,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
192192
Can be an S3 URI, dataset ARN, or DataSet object.
193193
wait (bool):
194194
Whether to wait for the training job to complete. Defaults to True.
195+
wait_timeout (Optional[int]):
196+
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
197+
If None, uses the default timeout from the wait utility.
195198
196199
Returns:
197200
TrainingJob: The SageMaker training job object.
@@ -277,7 +280,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
277280
from sagemaker.train.common_utils.trainer_wait import wait as _wait
278281
from sagemaker.core.utils.exceptions import TimeoutExceededError
279282
try :
280-
_wait(training_job)
283+
wait_kwargs = {}
284+
if wait_timeout is not None:
285+
wait_kwargs['timeout'] = wait_timeout
286+
_wait(training_job, **wait_kwargs)
281287
except TimeoutExceededError as e:
282288
logger.error("Error: %s", e)
283289

sagemaker-train/tests/unit/train/test_dpo_trainer.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,132 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):
378378

379379
assert trainer.stopping_condition == stopping_condition
380380
assert trainer.stopping_condition.max_runtime_in_seconds == 14400
381+
382+
@patch('sagemaker.train.common_utils.trainer_wait.wait')
383+
@patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
384+
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
385+
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
386+
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
387+
@patch('sagemaker.train.dpo_trainer._get_unique_name')
388+
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
389+
@patch('sagemaker.train.dpo_trainer._create_input_data_config')
390+
@patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
391+
@patch('sagemaker.train.dpo_trainer._create_output_config')
392+
@patch('sagemaker.train.dpo_trainer._create_serverless_config')
393+
@patch('sagemaker.train.dpo_trainer._create_mlflow_config')
394+
@patch('sagemaker.train.dpo_trainer._create_model_package_config')
395+
@patch('sagemaker.core.resources.TrainingJob.create')
396+
def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_package_config,
397+
mock_mlflow_config, mock_serverless_config, mock_output_config,
398+
mock_convert_channels, mock_input_config, mock_validate_group,
399+
mock_unique_name, mock_get_sagemaker_session, mock_get_role,
400+
mock_get_options, mock_resolve_model, mock_wait):
401+
"""Test that wait_timeout is passed to _wait as timeout kwarg."""
402+
mock_validate_group.return_value = "test-group"
403+
mock_resolve_model.return_value = ("test-model", "test-model")
404+
mock_get_sagemaker_session.return_value = Mock()
405+
mock_fine_tuning_options = Mock()
406+
mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"}
407+
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
408+
mock_get_role.return_value = "test-role"
409+
mock_unique_name.return_value = "test-job-name"
410+
mock_input_config.return_value = [Mock()]
411+
mock_convert_channels.return_value = [Mock()]
412+
mock_output_config.return_value = Mock()
413+
mock_serverless_config.return_value = Mock()
414+
mock_mlflow_config.return_value = Mock()
415+
mock_model_package_config.return_value = Mock()
416+
mock_training_job = Mock()
417+
mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
418+
mock_training_job_create.return_value = mock_training_job
419+
420+
trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
421+
trainer.train(wait=True, wait_timeout=600)
422+
423+
mock_wait.assert_called_once_with(mock_training_job, timeout=600)
424+
425+
@patch('sagemaker.train.common_utils.trainer_wait.wait')
426+
@patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
427+
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
428+
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
429+
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
430+
@patch('sagemaker.train.dpo_trainer._get_unique_name')
431+
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
432+
@patch('sagemaker.train.dpo_trainer._create_input_data_config')
433+
@patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
434+
@patch('sagemaker.train.dpo_trainer._create_output_config')
435+
@patch('sagemaker.train.dpo_trainer._create_serverless_config')
436+
@patch('sagemaker.train.dpo_trainer._create_mlflow_config')
437+
@patch('sagemaker.train.dpo_trainer._create_model_package_config')
438+
@patch('sagemaker.core.resources.TrainingJob.create')
439+
def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, mock_model_package_config,
440+
mock_mlflow_config, mock_serverless_config, mock_output_config,
441+
mock_convert_channels, mock_input_config, mock_validate_group,
442+
mock_unique_name, mock_get_sagemaker_session, mock_get_role,
443+
mock_get_options, mock_resolve_model, mock_wait):
444+
"""Test that _wait is called without timeout kwarg when wait_timeout is None."""
445+
mock_validate_group.return_value = "test-group"
446+
mock_resolve_model.return_value = ("test-model", "test-model")
447+
mock_get_sagemaker_session.return_value = Mock()
448+
mock_fine_tuning_options = Mock()
449+
mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"}
450+
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
451+
mock_get_role.return_value = "test-role"
452+
mock_unique_name.return_value = "test-job-name"
453+
mock_input_config.return_value = [Mock()]
454+
mock_convert_channels.return_value = [Mock()]
455+
mock_output_config.return_value = Mock()
456+
mock_serverless_config.return_value = Mock()
457+
mock_mlflow_config.return_value = Mock()
458+
mock_model_package_config.return_value = Mock()
459+
mock_training_job = Mock()
460+
mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
461+
mock_training_job_create.return_value = mock_training_job
462+
463+
trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
464+
trainer.train(wait=True)
465+
466+
mock_wait.assert_called_once_with(mock_training_job)
467+
468+
@patch('sagemaker.train.common_utils.trainer_wait.wait')
469+
@patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
470+
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
471+
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
472+
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
473+
@patch('sagemaker.train.dpo_trainer._get_unique_name')
474+
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
475+
@patch('sagemaker.train.dpo_trainer._create_input_data_config')
476+
@patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
477+
@patch('sagemaker.train.dpo_trainer._create_output_config')
478+
@patch('sagemaker.train.dpo_trainer._create_serverless_config')
479+
@patch('sagemaker.train.dpo_trainer._create_mlflow_config')
480+
@patch('sagemaker.train.dpo_trainer._create_model_package_config')
481+
@patch('sagemaker.core.resources.TrainingJob.create')
482+
def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_package_config,
483+
mock_mlflow_config, mock_serverless_config, mock_output_config,
484+
mock_convert_channels, mock_input_config, mock_validate_group,
485+
mock_unique_name, mock_get_sagemaker_session, mock_get_role,
486+
mock_get_options, mock_resolve_model, mock_wait):
487+
"""Test that _wait is not called when wait=False."""
488+
mock_validate_group.return_value = "test-group"
489+
mock_resolve_model.return_value = ("test-model", "test-model")
490+
mock_get_sagemaker_session.return_value = Mock()
491+
mock_fine_tuning_options = Mock()
492+
mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"}
493+
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
494+
mock_get_role.return_value = "test-role"
495+
mock_unique_name.return_value = "test-job-name"
496+
mock_input_config.return_value = [Mock()]
497+
mock_convert_channels.return_value = [Mock()]
498+
mock_output_config.return_value = Mock()
499+
mock_serverless_config.return_value = Mock()
500+
mock_mlflow_config.return_value = Mock()
501+
mock_model_package_config.return_value = Mock()
502+
mock_training_job = Mock()
503+
mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
504+
mock_training_job_create.return_value = mock_training_job
505+
506+
trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
507+
trainer.train(wait=False, wait_timeout=600)
508+
509+
mock_wait.assert_not_called()

0 commit comments

Comments
 (0)