3939# ---------------------------------------------------------------------------
4040
4141
42- def _create_mock_model_trainer (with_internal_channels = False ):
42+ def _create_mock_model_trainer (with_internal_channels = False , environment = None ):
4343 """Create a mock ModelTrainer with common attributes.
4444
4545 Args:
4646 with_internal_channels: If True, adds internal channels (code, sm_drivers)
4747 to input_data_config for testing channel inclusion in tuning jobs.
48+ environment: Optional dict of environment variables to set on the trainer.
4849 """
4950 trainer = MagicMock ()
5051 trainer .sagemaker_session = MagicMock ()
@@ -61,6 +62,7 @@ def _create_mock_model_trainer(with_internal_channels=False):
6162 trainer .stopping_condition = MagicMock ()
6263 trainer .stopping_condition .max_runtime_in_seconds = 3600
6364 trainer .input_data_config = None
65+ trainer .environment = environment if environment is not None else {}
6466
6567 if with_internal_channels :
6668 trainer .input_data_config = [
@@ -574,3 +576,90 @@ def test_build_training_job_definition_includes_internal_channels(self):
574576 assert "train" in channel_names , "User 'train' channel should be included"
575577 assert "validation" in channel_names , "User 'validation' channel should be included"
576578 assert len (channel_names ) == 4 , "Should have exactly 4 channels"
579+
580+ def test_build_training_job_definition_includes_environment_variables (self ):
581+ """Test that _build_training_job_definition includes environment variables.
582+
583+ This test verifies the fix for GitHub issue #5613 where tuning jobs were missing
584+ environment variables that were set on the ModelTrainer.
585+ """
586+ env_vars = {"RANDOM_STATE" : "42" , "MY_VAR" : "hello" }
587+ mock_trainer = _create_mock_model_trainer (environment = env_vars )
588+
589+ tuner = HyperparameterTuner (
590+ model_trainer = mock_trainer ,
591+ objective_metric_name = "accuracy" ,
592+ hyperparameter_ranges = _create_single_hp_range (),
593+ )
594+
595+ definition = tuner ._build_training_job_definition (None )
596+
597+ # The definition should contain the environment variables
598+ assert hasattr (definition , "environment" ) or hasattr (definition , "Environment" ), \
599+ "Training job definition should have environment attribute"
600+ definition_env = getattr (definition , "environment" , None ) or getattr (definition , "Environment" , None )
601+ assert definition_env == env_vars , \
602+ f"Environment should be { env_vars } , got { definition_env } "
603+
604+ def test_build_training_job_definition_with_empty_environment (self ):
605+ """Test that _build_training_job_definition handles empty environment."""
606+ mock_trainer = _create_mock_model_trainer (environment = {})
607+
608+ tuner = HyperparameterTuner (
609+ model_trainer = mock_trainer ,
610+ objective_metric_name = "accuracy" ,
611+ hyperparameter_ranges = _create_single_hp_range (),
612+ )
613+
614+ # Should not raise an error
615+ definition = tuner ._build_training_job_definition (None )
616+ assert definition is not None
617+
618+ def test_build_training_job_definition_with_none_environment (self ):
619+ """Test that _build_training_job_definition handles None environment."""
620+ mock_trainer = _create_mock_model_trainer ()
621+ mock_trainer .environment = None
622+
623+ tuner = HyperparameterTuner (
624+ model_trainer = mock_trainer ,
625+ objective_metric_name = "accuracy" ,
626+ hyperparameter_ranges = _create_single_hp_range (),
627+ )
628+
629+ # Should not raise an error
630+ definition = tuner ._build_training_job_definition (None )
631+ assert definition is not None
632+
633+
634+ class TestGetModelTrainerEnvironment :
635+ """Test _get_model_trainer_environment helper method."""
636+
637+ def test_returns_environment_when_set (self ):
638+ """Test that environment is returned when set on model trainer."""
639+ env_vars = {"KEY1" : "val1" , "KEY2" : "val2" }
640+ mock_trainer = _create_mock_model_trainer (environment = env_vars )
641+
642+ result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
643+ assert result == env_vars
644+
645+ def test_returns_none_when_empty (self ):
646+ """Test that None is returned when environment is empty."""
647+ mock_trainer = _create_mock_model_trainer (environment = {})
648+
649+ result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
650+ assert result is None
651+
652+ def test_returns_none_when_none (self ):
653+ """Test that None is returned when environment is None."""
654+ mock_trainer = _create_mock_model_trainer ()
655+ mock_trainer .environment = None
656+
657+ result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
658+ assert result is None
659+
660+ def test_returns_none_when_attribute_missing (self ):
661+ """Test that None is returned when environment attribute doesn't exist."""
662+ mock_trainer = MagicMock (spec = [])
663+
664+ result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
665+ assert result is None
0 commit comments