Skip to content

Commit 3c99488

Browse files
committed
fix: address review comments (iteration #1)
1 parent b22e8a6 commit 3c99488

2 files changed

Lines changed: 133 additions & 29 deletions

File tree

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -443,18 +443,21 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke
443443
return new_static_hyperparameters, auto_parameters
444444

445445
@staticmethod
446-
def _get_model_trainer_environment(model_trainer):
446+
def _get_model_trainer_environment(
447+
model_trainer: "ModelTrainer",
448+
) -> Optional[Dict[str, str]]:
447449
"""Extract environment variables from a ModelTrainer instance.
448450
449451
Returns the environment dict if it is non-empty, otherwise None.
450452
451453
Args:
452-
model_trainer: ModelTrainer instance
454+
model_trainer (ModelTrainer): ModelTrainer instance.
453455
454456
Returns:
455-
dict or None: Environment variables dict, or None if empty/not set.
457+
Optional[Dict[str, str]]: Environment variables dict,
458+
or None if empty/not set.
456459
"""
457-
env = getattr(model_trainer, "environment", None)
460+
env = model_trainer.environment
458461
if env:
459462
return dict(env)
460463
return None
@@ -1530,8 +1533,8 @@ def _build_training_job_definition(self, inputs):
15301533
)
15311534

15321535
# Pass through environment variables from model_trainer
1533-
env = getattr(model_trainer, "environment", None)
1534-
if env and isinstance(env, dict):
1536+
env = self._get_model_trainer_environment(model_trainer)
1537+
if env is not None:
15351538
definition.environment = env
15361539

15371540
# Pass through VPC config from model_trainer

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 124 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -578,13 +578,15 @@ def test_build_training_job_definition_includes_internal_channels(self):
578578
assert len(channel_names) == 4, "Should have exactly 4 channels"
579579

580580
def test_build_training_job_definition_includes_environment_variables(self):
581-
"""Test that _build_training_job_definition includes environment variables.
581+
"""Test that _build_training_job_definition includes env vars.
582582
583-
This test verifies the fix for GitHub issue #5613 where tuning jobs were missing
584-
environment variables that were set on the ModelTrainer.
583+
This test verifies the fix for GitHub issue #5613 where tuning
584+
jobs were missing environment variables set on the ModelTrainer.
585585
"""
586586
env_vars = {"RANDOM_STATE": "42", "MY_VAR": "hello"}
587-
mock_trainer = _create_mock_model_trainer(environment=env_vars)
587+
mock_trainer = _create_mock_model_trainer(
588+
environment=env_vars,
589+
)
588590

589591
tuner = HyperparameterTuner(
590592
model_trainer=mock_trainer,
@@ -595,14 +597,13 @@ def test_build_training_job_definition_includes_environment_variables(self):
595597
definition = tuner._build_training_job_definition(None)
596598

597599
# The definition should contain the environment variables
598-
assert hasattr(definition, "environment") or hasattr(definition, "Environment"), \
599-
"Training job definition should have environment attribute"
600-
definition_env = getattr(definition, "environment", None) or getattr(definition, "Environment", None)
601-
assert definition_env == env_vars, \
602-
f"Environment should be {env_vars}, got {definition_env}"
600+
assert definition.environment == env_vars, (
601+
f"Environment should be {env_vars}, "
602+
f"got {definition.environment}"
603+
)
603604

604605
def test_build_training_job_definition_with_empty_environment(self):
605-
"""Test that _build_training_job_definition handles empty environment."""
606+
"""Test that empty env is not propagated to definition."""
606607
mock_trainer = _create_mock_model_trainer(environment={})
607608

608609
tuner = HyperparameterTuner(
@@ -611,12 +612,17 @@ def test_build_training_job_definition_with_empty_environment(self):
611612
hyperparameter_ranges=_create_single_hp_range(),
612613
)
613614

614-
# Should not raise an error
615615
definition = tuner._build_training_job_definition(None)
616616
assert definition is not None
617+
# Empty environment should not be set on the definition
618+
env = getattr(definition, "environment", None)
619+
assert env is None, (
620+
"Empty environment should not be propagated, "
621+
f"got {env}"
622+
)
617623

618624
def test_build_training_job_definition_with_none_environment(self):
619-
"""Test that _build_training_job_definition handles None environment."""
625+
"""Test that None env is not propagated to definition."""
620626
mock_trainer = _create_mock_model_trainer()
621627
mock_trainer.environment = None
622628

@@ -626,40 +632,135 @@ def test_build_training_job_definition_with_none_environment(self):
626632
hyperparameter_ranges=_create_single_hp_range(),
627633
)
628634

629-
# Should not raise an error
630635
definition = tuner._build_training_job_definition(None)
631636
assert definition is not None
637+
# None environment should not be set on the definition
638+
env = getattr(definition, "environment", None)
639+
assert env is None, (
640+
"None environment should not be propagated, "
641+
f"got {env}"
642+
)
632643

633644

634645
class TestGetModelTrainerEnvironment:
635646
"""Test _get_model_trainer_environment helper method."""
636647

637648
def test_returns_environment_when_set(self):
638-
"""Test that environment is returned when set on model trainer."""
649+
"""Test that environment is returned when set."""
639650
env_vars = {"KEY1": "val1", "KEY2": "val2"}
640-
mock_trainer = _create_mock_model_trainer(environment=env_vars)
651+
mock_trainer = _create_mock_model_trainer(
652+
environment=env_vars,
653+
)
641654

642-
result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)
655+
result = HyperparameterTuner._get_model_trainer_environment(
656+
mock_trainer,
657+
)
643658
assert result == env_vars
644659

645660
def test_returns_none_when_empty(self):
646661
"""Test that None is returned when environment is empty."""
647662
mock_trainer = _create_mock_model_trainer(environment={})
648663

649-
result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)
664+
result = HyperparameterTuner._get_model_trainer_environment(
665+
mock_trainer,
666+
)
650667
assert result is None
651668

652669
def test_returns_none_when_none(self):
653670
"""Test that None is returned when environment is None."""
654671
mock_trainer = _create_mock_model_trainer()
655672
mock_trainer.environment = None
656673

657-
result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)
674+
result = HyperparameterTuner._get_model_trainer_environment(
675+
mock_trainer,
676+
)
658677
assert result is None
659678

660-
def test_returns_none_when_attribute_missing(self):
661-
"""Test that None is returned when environment attribute doesn't exist."""
662-
mock_trainer = MagicMock(spec=[])
663679

664-
result = HyperparameterTuner._get_model_trainer_environment(mock_trainer)
665-
assert result is None
680+
class TestMultiTrainerEnvironmentPropagation:
681+
"""Test environment propagation for multi-trainer tuning jobs."""
682+
683+
def test_create_multi_trainer_with_environment(self):
684+
"""Test that environment is preserved on trainers in create()."""
685+
env1 = {"VAR_A": "1"}
686+
env2 = {"VAR_B": "2"}
687+
trainer1 = _create_mock_model_trainer(environment=env1)
688+
trainer2 = _create_mock_model_trainer(environment=env2)
689+
690+
tuner = HyperparameterTuner.create(
691+
model_trainer_dict={
692+
"trainer1": trainer1,
693+
"trainer2": trainer2,
694+
},
695+
objective_metric_name_dict={
696+
"trainer1": "accuracy",
697+
"trainer2": "loss",
698+
},
699+
hyperparameter_ranges_dict={
700+
"trainer1": _create_single_hp_range(),
701+
"trainer2": _create_single_hp_range(),
702+
},
703+
)
704+
705+
# Verify environment is preserved on each trainer
706+
assert tuner.model_trainer_dict["trainer1"].environment == env1
707+
assert tuner.model_trainer_dict["trainer2"].environment == env2
708+
709+
def test_get_environment_for_each_trainer_in_dict(self):
710+
"""Test _get_model_trainer_environment for each trainer."""
711+
env1 = {"VAR_A": "1"}
712+
env2 = {"VAR_B": "2"}
713+
trainer1 = _create_mock_model_trainer(environment=env1)
714+
trainer2 = _create_mock_model_trainer(environment=env2)
715+
716+
tuner = HyperparameterTuner.create(
717+
model_trainer_dict={
718+
"trainer1": trainer1,
719+
"trainer2": trainer2,
720+
},
721+
objective_metric_name_dict={
722+
"trainer1": "accuracy",
723+
"trainer2": "loss",
724+
},
725+
hyperparameter_ranges_dict={
726+
"trainer1": _create_single_hp_range(),
727+
"trainer2": _create_single_hp_range(),
728+
},
729+
)
730+
731+
for name, mt in tuner.model_trainer_dict.items():
732+
env = HyperparameterTuner._get_model_trainer_environment(
733+
mt,
734+
)
735+
if name == "trainer1":
736+
assert env == env1
737+
elif name == "trainer2":
738+
assert env == env2
739+
740+
def test_multi_trainer_empty_environment(self):
741+
"""Test multi-trainer with empty environment."""
742+
trainer1 = _create_mock_model_trainer(environment={})
743+
trainer2 = _create_mock_model_trainer(environment={})
744+
745+
tuner = HyperparameterTuner.create(
746+
model_trainer_dict={
747+
"trainer1": trainer1,
748+
"trainer2": trainer2,
749+
},
750+
objective_metric_name_dict={
751+
"trainer1": "accuracy",
752+
"trainer2": "loss",
753+
},
754+
hyperparameter_ranges_dict={
755+
"trainer1": _create_single_hp_range(),
756+
"trainer2": _create_single_hp_range(),
757+
},
758+
)
759+
760+
for _name, mt in tuner.model_trainer_dict.items():
761+
env = HyperparameterTuner._get_model_trainer_environment(
762+
mt,
763+
)
764+
assert env is None, (
765+
"Empty environment should return None"
766+
)

0 commit comments

Comments
 (0)