Skip to content

Commit 3fbd1fb

Browse files
fix: HyperparameterTuner to pass enable_managed_spot_training flag to training jobs (aws#5597)
* fix: fix HyperparameterTuner to launch training jobs with provided spot parameters HyperparameterTuner._build_training_job_definition() was not copying parameters needed for managed spot training: - enable_managed_spot_training - max_wait_time_in_seconds This caused training jobs to launch with on-demend instances. - Include the additional parameters in the job definition - Add a unit test * fix: fix black-check warnings --------- Co-authored-by: Molly He <mollyhe@amazon.com>
1 parent 387213b commit 3fbd1fb

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,13 +1494,15 @@ def _build_training_job_definition(self, inputs):
14941494

14951495
# Build stopping condition
14961496
stopping_condition = StoppingCondition()
1497-
if (
1498-
model_trainer.stopping_condition
1499-
and model_trainer.stopping_condition.max_runtime_in_seconds
1500-
):
1501-
stopping_condition.max_runtime_in_seconds = (
1502-
model_trainer.stopping_condition.max_runtime_in_seconds
1503-
)
1497+
if model_trainer.stopping_condition:
1498+
if model_trainer.stopping_condition.max_runtime_in_seconds:
1499+
stopping_condition.max_runtime_in_seconds = (
1500+
model_trainer.stopping_condition.max_runtime_in_seconds
1501+
)
1502+
if model_trainer.stopping_condition.max_wait_time_in_seconds:
1503+
stopping_condition.max_wait_time_in_seconds = (
1504+
model_trainer.stopping_condition.max_wait_time_in_seconds
1505+
)
15041506

15051507
definition = HyperParameterTrainingJobDefinition(
15061508
algorithm_specification=algorithm_spec,
@@ -1510,6 +1512,7 @@ def _build_training_job_definition(self, inputs):
15101512
resource_config=resource_config,
15111513
stopping_condition=stopping_condition,
15121514
static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
1515+
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
15131516
)
15141517

15151518
# Pass through environment variables from model_trainer

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@
3939
# ---------------------------------------------------------------------------
4040

4141

42-
def _create_mock_model_trainer(with_internal_channels=False):
42+
def _create_mock_model_trainer(with_internal_channels=False, with_spot_training=False):
4343
"""Create a mock ModelTrainer with common attributes.
4444
4545
Args:
4646
with_internal_channels: If True, adds internal channels (code, sm_drivers)
4747
to input_data_config for testing channel inclusion in tuning jobs.
48+
with_spot_training: If True, sets spot parameters (enable_managed_spot_training,
49+
max_wait_time_in_seconds)
4850
"""
4951
trainer = MagicMock()
5052
trainer.sagemaker_session = MagicMock()
@@ -67,6 +69,9 @@ def _create_mock_model_trainer(with_internal_channels=False):
6769
_create_channel("code", "s3://bucket/code"),
6870
_create_channel("sm_drivers", "s3://bucket/drivers"),
6971
]
72+
if with_spot_training:
73+
trainer.compute.enable_managed_spot_training = True
74+
trainer.stopping_condition.max_wait_time_in_seconds = 3600
7075
return trainer
7176

7277

@@ -574,3 +579,20 @@ def test_build_training_job_definition_includes_internal_channels(self):
574579
assert "train" in channel_names, "User 'train' channel should be included"
575580
assert "validation" in channel_names, "User 'validation' channel should be included"
576581
assert len(channel_names) == 4, "Should have exactly 4 channels"
582+
583+
def test_build_training_job_definition_includes_spot_params(self):
584+
"""Test that _build_training_job_definition includes spot parameters."""
585+
tuner = HyperparameterTuner(
586+
model_trainer=_create_mock_model_trainer(with_spot_training=True),
587+
objective_metric_name="accuracy",
588+
hyperparameter_ranges=_create_single_hp_range(),
589+
)
590+
591+
# Build training job definition
592+
definition = tuner._build_training_job_definition(None)
593+
594+
# Verify managed spot training enabled
595+
assert definition.enable_managed_spot_training is True, "Spot should be enabled"
596+
assert isinstance(
597+
definition.stopping_condition.max_wait_time_in_seconds, int
598+
), "Max wait time should be set"

0 commit comments

Comments
 (0)