Skip to content

Commit a65cf35

Browse files
authored
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) (aws#5725)
* fix: ModelTrainer and HyperparameterTuner missing environment variables (5613) * fix: address review comments (iteration #1) * fix: address review comments (iteration #2) * fix: address review comments (iteration #1)
1 parent 98683ac commit a65cf35

File tree

3 files changed

+117
-9
lines changed

3 files changed

+117
-9
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,7 +1504,13 @@ def _build_training_job_definition(self, inputs):
15041504
model_trainer.stopping_condition.max_wait_time_in_seconds
15051505
)
15061506

1507-
definition = HyperParameterTrainingJobDefinition(
1507+
# Propagate environment variables from ModelTrainer.
1508+
# Only include when it's a dict (even empty); omit otherwise so the
1509+
# Pydantic field stays Unassigned and is excluded during serialization.
1510+
env = model_trainer.environment
1511+
1512+
# Build base kwargs for the definition
1513+
definition_kwargs = dict(
15081514
algorithm_specification=algorithm_spec,
15091515
role_arn=model_trainer.role,
15101516
input_data_config=input_data_config if input_data_config else None,
@@ -1515,10 +1521,11 @@ def _build_training_job_definition(self, inputs):
15151521
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
15161522
)
15171523

1518-
# Pass through environment variables from model_trainer
1519-
env = getattr(model_trainer, "environment", None)
1520-
if env and isinstance(env, dict):
1521-
definition.environment = env
1524+
# Include environment only when it's a dict (including empty).
1525+
if isinstance(env, dict):
1526+
definition_kwargs["environment"] = env
1527+
1528+
definition = HyperParameterTrainingJobDefinition(**definition_kwargs)
15221529

15231530
# Pass through VPC config from model_trainer
15241531
networking = getattr(model_trainer, "networking", None)

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,3 +596,73 @@ def test_build_training_job_definition_includes_spot_params(self):
596596
assert isinstance(
597597
definition.stopping_condition.max_wait_time_in_seconds, int
598598
), "Max wait time should be set"
599+
600+
def test_build_training_job_definition_includes_environment_variables(self):
601+
"""Test that _build_training_job_definition includes environment variables.
602+
603+
This test verifies the fix for GitHub issue #5613 where tuning jobs were
604+
missing environment variables that were set on the ModelTrainer.
605+
"""
606+
mock_trainer = _create_mock_model_trainer()
607+
mock_trainer.environment = {
608+
"FOO": "bar",
609+
"RANDOM_STATE": "42",
610+
}
611+
612+
tuner = HyperparameterTuner(
613+
model_trainer=mock_trainer,
614+
objective_metric_name="accuracy",
615+
hyperparameter_ranges=_create_single_hp_range(),
616+
)
617+
618+
definition = tuner._build_training_job_definition(None)
619+
620+
assert definition.environment is not None, "Environment should not be None"
621+
assert definition.environment == {
622+
"FOO": "bar",
623+
"RANDOM_STATE": "42",
624+
}, "Environment variables should match those set on ModelTrainer"
625+
626+
def test_build_training_job_definition_with_none_environment(self):
627+
"""Test that _build_training_job_definition handles None environment gracefully.
628+
629+
When environment is None, it should not be passed to the Pydantic constructor,
630+
so the field stays as Unassigned (excluded from serialization).
631+
"""
632+
from sagemaker.core.utils.utils import Unassigned
633+
634+
mock_trainer = _create_mock_model_trainer()
635+
mock_trainer.environment = None
636+
637+
tuner = HyperparameterTuner(
638+
model_trainer=mock_trainer,
639+
objective_metric_name="accuracy",
640+
hyperparameter_ranges=_create_single_hp_range(),
641+
)
642+
643+
definition = tuner._build_training_job_definition(None)
644+
645+
assert isinstance(definition.environment, Unassigned), (
646+
"Environment should be Unassigned when model_trainer.environment is None"
647+
)
648+
649+
def test_build_training_job_definition_with_empty_environment(self):
650+
"""Test that _build_training_job_definition passes through empty environment.
651+
652+
An empty dict is valid for the SageMaker API, so we pass it through as-is
653+
rather than silently converting it to None.
654+
"""
655+
mock_trainer = _create_mock_model_trainer()
656+
mock_trainer.environment = {}
657+
658+
tuner = HyperparameterTuner(
659+
model_trainer=mock_trainer,
660+
objective_metric_name="accuracy",
661+
hyperparameter_ranges=_create_single_hp_range(),
662+
)
663+
664+
definition = tuner._build_training_job_definition(None)
665+
666+
assert definition.environment == {}, (
667+
"Empty dict environment should be passed through as-is"
668+
)

sagemaker-train/tests/unit/train/test_tuner_driver_channels.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,31 @@ def test_passes_environment_variables(self):
405405
definition = tuner._build_training_job_definition(inputs=None)
406406
assert definition.environment == {"MY_VAR": "value", "OTHER": "123"}
407407

408+
def test_passes_empty_environment(self):
409+
"""Should pass through empty dict environment as-is.
410+
411+
An empty dict is valid for the SageMaker API, so we pass it through
412+
rather than silently converting it to None/Unassigned.
413+
"""
414+
trainer = _mock_model_trainer(environment={})
415+
416+
tuner = HyperparameterTuner(
417+
model_trainer=trainer,
418+
objective_metric_name="accuracy",
419+
hyperparameter_ranges=_hp_ranges(),
420+
)
421+
422+
definition = tuner._build_training_job_definition(inputs=None)
423+
assert definition.environment == {}, (
424+
"Empty dict environment should be passed through as-is"
425+
)
426+
408427
def test_skips_environment_when_none(self):
409-
"""Should not set environment when model_trainer.environment is None."""
428+
"""Should not set environment when model_trainer.environment is None.
429+
430+
When environment is None, it is not passed to the Pydantic constructor,
431+
so the field stays as Unassigned (excluded from serialization).
432+
"""
410433
trainer = _mock_model_trainer(environment=None)
411434

412435
tuner = HyperparameterTuner(
@@ -416,10 +439,16 @@ def test_skips_environment_when_none(self):
416439
)
417440

418441
definition = tuner._build_training_job_definition(inputs=None)
419-
assert _is_unassigned(definition.environment)
442+
assert _is_unassigned(definition.environment), (
443+
"Environment should be Unassigned when model_trainer.environment is None"
444+
)
420445

421446
def test_skips_environment_when_not_dict(self):
422-
"""Should not set environment when it's not a dict (e.g. MagicMock)."""
447+
"""Should not set environment when it's not a dict (e.g. MagicMock).
448+
449+
Non-dict values are not passed to the Pydantic constructor to avoid
450+
validation errors. The field stays as Unassigned.
451+
"""
423452
trainer = _mock_model_trainer(environment=MagicMock())
424453

425454
tuner = HyperparameterTuner(
@@ -429,7 +458,9 @@ def test_skips_environment_when_not_dict(self):
429458
)
430459

431460
definition = tuner._build_training_job_definition(inputs=None)
432-
assert _is_unassigned(definition.environment)
461+
assert _is_unassigned(definition.environment), (
462+
"Environment should be Unassigned when model_trainer.environment is not a dict"
463+
)
433464

434465
def test_passes_vpc_config(self):
435466
"""Should set definition.vpc_config from model_trainer.networking._to_vpc_config()."""

0 commit comments

Comments
 (0)