Skip to content

Commit e96f157

Browse files
committed
fix: fix HyperparameterTuner to launch training jobs with provided spot parameters
HyperparameterTuner._build_training_job_definition() was not copying the parameters needed for managed spot training (enable_managed_spot_training and max_wait_time_in_seconds). This caused training jobs to launch with on-demand instances. Changes: (1) include the additional parameters in the job definition; (2) add a unit test.
1 parent 71c8d70 commit e96f157

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,13 +1444,15 @@ def _build_training_job_definition(self, inputs):
14441444

14451445
# Build stopping condition
14461446
stopping_condition = StoppingCondition()
1447-
if (
1448-
model_trainer.stopping_condition
1449-
and model_trainer.stopping_condition.max_runtime_in_seconds
1450-
):
1451-
stopping_condition.max_runtime_in_seconds = (
1452-
model_trainer.stopping_condition.max_runtime_in_seconds
1453-
)
1447+
if model_trainer.stopping_condition:
1448+
if model_trainer.stopping_condition.max_runtime_in_seconds:
1449+
stopping_condition.max_runtime_in_seconds = (
1450+
model_trainer.stopping_condition.max_runtime_in_seconds
1451+
)
1452+
if model_trainer.stopping_condition.max_wait_time_in_seconds:
1453+
stopping_condition.max_wait_time_in_seconds = (
1454+
model_trainer.stopping_condition.max_wait_time_in_seconds
1455+
)
14541456

14551457
definition = HyperParameterTrainingJobDefinition(
14561458
algorithm_specification=algorithm_spec,
@@ -1460,6 +1462,7 @@ def _build_training_job_definition(self, inputs):
14601462
resource_config=resource_config,
14611463
stopping_condition=stopping_condition,
14621464
static_hyper_parameters=self.static_hyperparameters or {},
1465+
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
14631466
)
14641467

14651468
return definition

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@
3939
# ---------------------------------------------------------------------------
4040

4141

42-
def _create_mock_model_trainer(with_internal_channels=False):
42+
def _create_mock_model_trainer(with_internal_channels=False, with_spot_training=False):
4343
"""Create a mock ModelTrainer with common attributes.
4444
4545
Args:
4646
with_internal_channels: If True, adds internal channels (code, sm_drivers)
4747
to input_data_config for testing channel inclusion in tuning jobs.
48+
with_spot_training: If True, sets spot parameters (enable_managed_spot_training,
49+
max_wait_time_in_seconds)
4850
"""
4951
trainer = MagicMock()
5052
trainer.sagemaker_session = MagicMock()
@@ -67,6 +69,9 @@ def _create_mock_model_trainer(with_internal_channels=False):
6769
_create_channel("code", "s3://bucket/code"),
6870
_create_channel("sm_drivers", "s3://bucket/drivers"),
6971
]
72+
if with_spot_training:
73+
trainer.compute.enable_managed_spot_training = True
74+
trainer.stopping_condition.max_wait_time_in_seconds = 3600
7075
return trainer
7176

7277

@@ -574,3 +579,19 @@ def test_build_training_job_definition_includes_internal_channels(self):
574579
assert "train" in channel_names, "User 'train' channel should be included"
575580
assert "validation" in channel_names, "User 'validation' channel should be included"
576581
assert len(channel_names) == 4, "Should have exactly 4 channels"
582+
583+
def test_build_training_job_definition_includes_spot_params(self):
584+
"""Test that _build_training_job_definition includes spot parameters.
585+
"""
586+
tuner = HyperparameterTuner(
587+
model_trainer=_create_mock_model_trainer(with_spot_training=True),
588+
objective_metric_name="accuracy",
589+
hyperparameter_ranges=_create_single_hp_range(),
590+
)
591+
592+
# Build training job definition
593+
definition = tuner._build_training_job_definition(None)
594+
595+
# Verify managed spot training enabled
596+
assert definition.enable_managed_spot_training is True, "Spot should be enabled"
597+
assert isinstance(definition.stopping_condition.max_wait_time_in_seconds, int), "Max wait time should be set"

0 commit comments

Comments
 (0)