Skip to content

Commit f1ea9d5

Browse files
committed
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613)
1 parent 272fdbf commit f1ea9d5

File tree

2 files changed

+64
-5
lines changed

2 files changed

+64
-5
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,11 @@ def _build_training_job_definition(self, inputs):
15041504
model_trainer.stopping_condition.max_wait_time_in_seconds
15051505
)
15061506

1507+
# Get environment variables from model_trainer
1508+
env = getattr(model_trainer, "environment", None)
1509+
if not env or not isinstance(env, dict):
1510+
env = None
1511+
15071512
definition = HyperParameterTrainingJobDefinition(
15081513
algorithm_specification=algorithm_spec,
15091514
role_arn=model_trainer.role,
@@ -1513,13 +1518,9 @@ def _build_training_job_definition(self, inputs):
15131518
stopping_condition=stopping_condition,
15141519
static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
15151520
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
1521+
environment=env,
15161522
)
15171523

1518-
# Pass through environment variables from model_trainer
1519-
env = getattr(model_trainer, "environment", None)
1520-
if env and isinstance(env, dict):
1521-
definition.environment = env
1522-
15231524
# Pass through VPC config from model_trainer
15241525
networking = getattr(model_trainer, "networking", None)
15251526
if networking and hasattr(networking, "_to_vpc_config"):

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,3 +596,61 @@ def test_build_training_job_definition_includes_spot_params(self):
596596
assert isinstance(
597597
definition.stopping_condition.max_wait_time_in_seconds, int
598598
), "Max wait time should be set"
599+
600+
def test_build_training_job_definition_includes_environment_variables(self):
601+
"""Test that _build_training_job_definition includes environment variables.
602+
603+
This test verifies the fix for GitHub issue #5613 where tuning jobs were
604+
missing environment variables that were set on the ModelTrainer.
605+
"""
606+
mock_trainer = _create_mock_model_trainer()
607+
mock_trainer.environment = {
608+
"FOO": "bar",
609+
"RANDOM_STATE": "42",
610+
}
611+
612+
tuner = HyperparameterTuner(
613+
model_trainer=mock_trainer,
614+
objective_metric_name="accuracy",
615+
hyperparameter_ranges=_create_single_hp_range(),
616+
)
617+
618+
definition = tuner._build_training_job_definition(None)
619+
620+
assert definition.environment is not None, "Environment should not be None"
621+
assert definition.environment == {
622+
"FOO": "bar",
623+
"RANDOM_STATE": "42",
624+
}, "Environment variables should match those set on ModelTrainer"
625+
626+
def test_build_training_job_definition_with_none_environment(self):
627+
"""Test that _build_training_job_definition handles None environment gracefully."""
628+
mock_trainer = _create_mock_model_trainer()
629+
mock_trainer.environment = None
630+
631+
tuner = HyperparameterTuner(
632+
model_trainer=mock_trainer,
633+
objective_metric_name="accuracy",
634+
hyperparameter_ranges=_create_single_hp_range(),
635+
)
636+
637+
definition = tuner._build_training_job_definition(None)
638+
639+
assert definition.environment is None, "Environment should be None when not set"
640+
641+
def test_build_training_job_definition_with_empty_environment(self):
642+
"""Test that _build_training_job_definition handles empty environment gracefully."""
643+
mock_trainer = _create_mock_model_trainer()
644+
mock_trainer.environment = {}
645+
646+
tuner = HyperparameterTuner(
647+
model_trainer=mock_trainer,
648+
objective_metric_name="accuracy",
649+
hyperparameter_ranges=_create_single_hp_range(),
650+
)
651+
652+
definition = tuner._build_training_job_definition(None)
653+
654+
assert definition.environment is None, (
655+
"Environment should be None when empty dict is provided"
656+
)

0 commit comments

Comments
 (0)