Skip to content

Commit b22e8a6

Browse files
committed
fix: ModelTrainer and HyperparameterTuner missing environment variables (5613)
1 parent ee420cc commit b22e8a6

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,23 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke
442442

443443
return new_static_hyperparameters, auto_parameters
444444

445+
@staticmethod
def _get_model_trainer_environment(model_trainer):
    """Extract environment variables from a ModelTrainer instance.

    The ``environment`` attribute is read defensively with ``getattr`` so
    that objects lacking the attribute are treated the same as objects
    whose environment is ``None`` or empty.

    Args:
        model_trainer: ModelTrainer instance (or any object that may expose
            an ``environment`` attribute).

    Returns:
        dict or None: A new dict built from the trainer's environment when
        it is non-empty; ``None`` when the attribute is missing, ``None``,
        or empty.
    """
    env = getattr(model_trainer, "environment", None)
    if env:
        # Build a fresh dict so callers do not share the trainer's own
        # mapping object.
        return dict(env)
    return None
461+
445462
@classmethod
446463
def _prepare_model_trainer_for_tuning(cls, model_trainer, inputs=None, job_name=None, **kwargs):
447464
"""Prepare ModelTrainer before tuning by building sm_drivers and code channels.

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@
3939
# ---------------------------------------------------------------------------
4040

4141

42-
def _create_mock_model_trainer(with_internal_channels=False):
42+
def _create_mock_model_trainer(with_internal_channels=False, environment=None):
4343
"""Create a mock ModelTrainer with common attributes.
4444
4545
Args:
4646
with_internal_channels: If True, adds internal channels (code, sm_drivers)
4747
to input_data_config for testing channel inclusion in tuning jobs.
48+
environment: Optional dict of environment variables to set on the trainer.
4849
"""
4950
trainer = MagicMock()
5051
trainer.sagemaker_session = MagicMock()
@@ -61,6 +62,7 @@ def _create_mock_model_trainer(with_internal_channels=False):
6162
trainer.stopping_condition = MagicMock()
6263
trainer.stopping_condition.max_runtime_in_seconds = 3600
6364
trainer.input_data_config = None
65+
trainer.environment = environment if environment is not None else {}
6466

6567
if with_internal_channels:
6668
trainer.input_data_config = [
@@ -574,3 +576,90 @@ def test_build_training_job_definition_includes_internal_channels(self):
574576
assert "train" in channel_names, "User 'train' channel should be included"
575577
assert "validation" in channel_names, "User 'validation' channel should be included"
576578
assert len(channel_names) == 4, "Should have exactly 4 channels"
579+
580+
def test_build_training_job_definition_includes_environment_variables(self):
    """Check that the built training job definition carries the trainer's env vars.

    Regression test for GitHub issue #5613, where tuning jobs were missing
    environment variables that were set on the ModelTrainer.
    """
    env_vars = {"RANDOM_STATE": "42", "MY_VAR": "hello"}
    tuner = HyperparameterTuner(
        model_trainer=_create_mock_model_trainer(environment=env_vars),
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    definition = tuner._build_training_job_definition(None)

    # Accept either attribute spelling on the definition object.
    has_env_attr = any(
        hasattr(definition, attr_name) for attr_name in ("environment", "Environment")
    )
    assert has_env_attr, (
        "Training job definition should have environment attribute"
    )

    definition_env = getattr(definition, "environment", None) or getattr(
        definition, "Environment", None
    )
    assert definition_env == env_vars, (
        f"Environment should be {env_vars}, got {definition_env}"
    )
603+
604+
def test_build_training_job_definition_with_empty_environment(self):
    """Building the definition must succeed when the trainer's environment is ``{}``."""
    trainer = _create_mock_model_trainer(environment={})
    tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    # The call itself is the check: it must complete without raising.
    built = tuner._build_training_job_definition(None)

    assert built is not None
617+
618+
def test_build_training_job_definition_with_none_environment(self):
    """Building the definition must succeed when the trainer's environment is ``None``."""
    trainer = _create_mock_model_trainer()
    # Override the helper's value so the trainer reports no environment at all.
    trainer.environment = None

    tuner = HyperparameterTuner(
        model_trainer=trainer,
        objective_metric_name="accuracy",
        hyperparameter_ranges=_create_single_hp_range(),
    )

    # The call itself is the check: it must complete without raising.
    built = tuner._build_training_job_definition(None)

    assert built is not None
632+
633+
634+
class TestGetModelTrainerEnvironment:
    """Test _get_model_trainer_environment helper method."""

    def test_returns_environment_when_set(self):
        """Test that a copy of the environment is returned when it is set."""
        env_vars = {"KEY1": "val1", "KEY2": "val2"}
        mock_trainer = _create_mock_model_trainer(environment=env_vars)

        result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)

        assert result == env_vars
        # The helper returns dict(env), so the result must be a distinct
        # object rather than the trainer's own mapping.
        assert result is not env_vars

    def test_returns_none_when_empty(self):
        """Test that None is returned when environment is empty."""
        mock_trainer = _create_mock_model_trainer(environment={})

        result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)

        assert result is None

    def test_returns_none_when_none(self):
        """Test that None is returned when environment is None."""
        mock_trainer = _create_mock_model_trainer()
        mock_trainer.environment = None

        result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)

        assert result is None

    def test_returns_none_when_attribute_missing(self):
        """Test that None is returned when environment attribute doesn't exist."""
        # spec=[] yields a mock with no attributes, so reading
        # ``environment`` raises AttributeError instead of auto-creating
        # a child mock.
        mock_trainer = MagicMock(spec=[])

        result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)

        assert result is None

0 commit comments

Comments
 (0)