@@ -578,13 +578,15 @@ def test_build_training_job_definition_includes_internal_channels(self):
578578 assert len (channel_names ) == 4 , "Should have exactly 4 channels"
579579
580580 def test_build_training_job_definition_includes_environment_variables (self ):
581- """Test that _build_training_job_definition includes environment variables .
581+ """Test that _build_training_job_definition includes env vars .
582582
583- This test verifies the fix for GitHub issue #5613 where tuning jobs were missing
584- environment variables that were set on the ModelTrainer.
583+ This test verifies the fix for GitHub issue #5613 where tuning
584+ jobs were missing environment variables set on the ModelTrainer.
585585 """
586586 env_vars = {"RANDOM_STATE" : "42" , "MY_VAR" : "hello" }
587- mock_trainer = _create_mock_model_trainer (environment = env_vars )
587+ mock_trainer = _create_mock_model_trainer (
588+ environment = env_vars ,
589+ )
588590
589591 tuner = HyperparameterTuner (
590592 model_trainer = mock_trainer ,
@@ -595,14 +597,13 @@ def test_build_training_job_definition_includes_environment_variables(self):
595597 definition = tuner ._build_training_job_definition (None )
596598
597599 # The definition should contain the environment variables
598- assert hasattr (definition , "environment" ) or hasattr (definition , "Environment" ), \
599- "Training job definition should have environment attribute"
600- definition_env = getattr (definition , "environment" , None ) or getattr (definition , "Environment" , None )
601- assert definition_env == env_vars , \
602- f"Environment should be { env_vars } , got { definition_env } "
600+ assert definition .environment == env_vars , (
601+ f"Environment should be { env_vars } , "
602+ f"got { definition .environment } "
603+ )
603604
604605 def test_build_training_job_definition_with_empty_environment (self ):
605- """Test that _build_training_job_definition handles empty environment ."""
606+ """Test that empty env is not propagated to definition ."""
606607 mock_trainer = _create_mock_model_trainer (environment = {})
607608
608609 tuner = HyperparameterTuner (
@@ -611,12 +612,17 @@ def test_build_training_job_definition_with_empty_environment(self):
611612 hyperparameter_ranges = _create_single_hp_range (),
612613 )
613614
614- # Should not raise an error
615615 definition = tuner ._build_training_job_definition (None )
616616 assert definition is not None
617+ # Empty environment should not be set on the definition
618+ env = getattr (definition , "environment" , None )
619+ assert env is None , (
620+ "Empty environment should not be propagated, "
621+ f"got { env } "
622+ )
617623
618624 def test_build_training_job_definition_with_none_environment (self ):
619- """Test that _build_training_job_definition handles None environment ."""
625+ """Test that None env is not propagated to definition ."""
620626 mock_trainer = _create_mock_model_trainer ()
621627 mock_trainer .environment = None
622628
@@ -626,40 +632,135 @@ def test_build_training_job_definition_with_none_environment(self):
626632 hyperparameter_ranges = _create_single_hp_range (),
627633 )
628634
629- # Should not raise an error
630635 definition = tuner ._build_training_job_definition (None )
631636 assert definition is not None
637+ # None environment should not be set on the definition
638+ env = getattr (definition , "environment" , None )
639+ assert env is None , (
640+ "None environment should not be propagated, "
641+ f"got { env } "
642+ )
632643
633644
634645class TestGetModelTrainerEnvironment :
635646 """Test _get_model_trainer_environment helper method."""
636647
637648 def test_returns_environment_when_set (self ):
638- """Test that environment is returned when set on model trainer ."""
649+ """Test that environment is returned when set."""
639650 env_vars = {"KEY1" : "val1" , "KEY2" : "val2" }
640- mock_trainer = _create_mock_model_trainer (environment = env_vars )
651+ mock_trainer = _create_mock_model_trainer (
652+ environment = env_vars ,
653+ )
641654
642- result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
655+ result = HyperparameterTuner ._get_model_trainer_environment (
656+ mock_trainer ,
657+ )
643658 assert result == env_vars
644659
645660 def test_returns_none_when_empty (self ):
646661 """Test that None is returned when environment is empty."""
647662 mock_trainer = _create_mock_model_trainer (environment = {})
648663
649- result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
664+ result = HyperparameterTuner ._get_model_trainer_environment (
665+ mock_trainer ,
666+ )
650667 assert result is None
651668
652669 def test_returns_none_when_none (self ):
653670 """Test that None is returned when environment is None."""
654671 mock_trainer = _create_mock_model_trainer ()
655672 mock_trainer .environment = None
656673
657- result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
674+ result = HyperparameterTuner ._get_model_trainer_environment (
675+ mock_trainer ,
676+ )
658677 assert result is None
659678
660- def test_returns_none_when_attribute_missing (self ):
661- """Test that None is returned when environment attribute doesn't exist."""
662- mock_trainer = MagicMock (spec = [])
663679
664- result = HyperparameterTuner ._get_model_trainer_environment (mock_trainer )
665- assert result is None
680+ class TestMultiTrainerEnvironmentPropagation :
681+ """Test environment propagation for multi-trainer tuning jobs."""
682+
683+ def test_create_multi_trainer_with_environment (self ):
684+ """Test that environment is preserved on trainers in create()."""
685+ env1 = {"VAR_A" : "1" }
686+ env2 = {"VAR_B" : "2" }
687+ trainer1 = _create_mock_model_trainer (environment = env1 )
688+ trainer2 = _create_mock_model_trainer (environment = env2 )
689+
690+ tuner = HyperparameterTuner .create (
691+ model_trainer_dict = {
692+ "trainer1" : trainer1 ,
693+ "trainer2" : trainer2 ,
694+ },
695+ objective_metric_name_dict = {
696+ "trainer1" : "accuracy" ,
697+ "trainer2" : "loss" ,
698+ },
699+ hyperparameter_ranges_dict = {
700+ "trainer1" : _create_single_hp_range (),
701+ "trainer2" : _create_single_hp_range (),
702+ },
703+ )
704+
705+ # Verify environment is preserved on each trainer
706+ assert tuner .model_trainer_dict ["trainer1" ].environment == env1
707+ assert tuner .model_trainer_dict ["trainer2" ].environment == env2
708+
709+ def test_get_environment_for_each_trainer_in_dict (self ):
710+ """Test _get_model_trainer_environment for each trainer."""
711+ env1 = {"VAR_A" : "1" }
712+ env2 = {"VAR_B" : "2" }
713+ trainer1 = _create_mock_model_trainer (environment = env1 )
714+ trainer2 = _create_mock_model_trainer (environment = env2 )
715+
716+ tuner = HyperparameterTuner .create (
717+ model_trainer_dict = {
718+ "trainer1" : trainer1 ,
719+ "trainer2" : trainer2 ,
720+ },
721+ objective_metric_name_dict = {
722+ "trainer1" : "accuracy" ,
723+ "trainer2" : "loss" ,
724+ },
725+ hyperparameter_ranges_dict = {
726+ "trainer1" : _create_single_hp_range (),
727+ "trainer2" : _create_single_hp_range (),
728+ },
729+ )
730+
731+ for name , mt in tuner .model_trainer_dict .items ():
732+ env = HyperparameterTuner ._get_model_trainer_environment (
733+ mt ,
734+ )
735+ if name == "trainer1" :
736+ assert env == env1
737+ elif name == "trainer2" :
738+ assert env == env2
739+
740+ def test_multi_trainer_empty_environment (self ):
741+ """Test multi-trainer with empty environment."""
742+ trainer1 = _create_mock_model_trainer (environment = {})
743+ trainer2 = _create_mock_model_trainer (environment = {})
744+
745+ tuner = HyperparameterTuner .create (
746+ model_trainer_dict = {
747+ "trainer1" : trainer1 ,
748+ "trainer2" : trainer2 ,
749+ },
750+ objective_metric_name_dict = {
751+ "trainer1" : "accuracy" ,
752+ "trainer2" : "loss" ,
753+ },
754+ hyperparameter_ranges_dict = {
755+ "trainer1" : _create_single_hp_range (),
756+ "trainer2" : _create_single_hp_range (),
757+ },
758+ )
759+
760+ for _name , mt in tuner .model_trainer_dict .items ():
761+ env = HyperparameterTuner ._get_model_trainer_environment (
762+ mt ,
763+ )
764+ assert env is None , (
765+ "Empty environment should return None"
766+ )
0 commit comments