Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions sagemaker-train/src/sagemaker/train/tuner.py
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,11 @@ def _build_training_job_definition(self, inputs):
model_trainer.stopping_condition.max_wait_time_in_seconds
)

Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
# Get environment variables from model_trainer.
Comment thread
aviruthen marked this conversation as resolved.
Outdated
# environment is a defined attribute on ModelTrainer (dict | None).
# We pass it through as-is; even an empty dict is valid for the API.
env = model_trainer.environment
Comment thread
aviruthen marked this conversation as resolved.

definition = HyperParameterTrainingJobDefinition(
algorithm_specification=algorithm_spec,
role_arn=model_trainer.role,
Expand All @@ -1513,13 +1518,9 @@ def _build_training_job_definition(self, inputs):
stopping_condition=stopping_condition,
static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
environment=env,
)

# Pass through environment variables from model_trainer
env = getattr(model_trainer, "environment", None)
if env and isinstance(env, dict):
definition.environment = env

# Pass through VPC config from model_trainer
networking = getattr(model_trainer, "networking", None)
if networking and hasattr(networking, "_to_vpc_config"):
Expand Down
62 changes: 62 additions & 0 deletions sagemaker-train/tests/unit/train/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,65 @@ def test_build_training_job_definition_includes_spot_params(self):
assert isinstance(
definition.stopping_condition.max_wait_time_in_seconds, int
Comment thread
aviruthen marked this conversation as resolved.
), "Max wait time should be set"

Comment thread
aviruthen marked this conversation as resolved.
def test_build_training_job_definition_includes_environment_variables(self):
"""Test that _build_training_job_definition includes environment variables.

This test verifies the fix for GitHub issue #5613 where tuning jobs were
missing environment variables that were set on the ModelTrainer.
"""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {
"FOO": "bar",
"RANDOM_STATE": "42",
}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is not None, "Environment should not be None"
assert definition.environment == {
"FOO": "bar",
"RANDOM_STATE": "42",
}, "Environment variables should match those set on ModelTrainer"

def test_build_training_job_definition_with_none_environment(self):
"""Test that _build_training_job_definition handles None environment gracefully."""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = None

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment is None, "Environment should be None when not set"

def test_build_training_job_definition_with_empty_environment(self):
"""Test that _build_training_job_definition passes through empty environment.

An empty dict is valid for the SageMaker API, so we pass it through as-is
rather than silently converting it to None.
"""
mock_trainer = _create_mock_model_trainer()
mock_trainer.environment = {}

tuner = HyperparameterTuner(
model_trainer=mock_trainer,
objective_metric_name="accuracy",
hyperparameter_ranges=_create_single_hp_range(),
)

definition = tuner._build_training_job_definition(None)

assert definition.environment == {}, (
"Empty dict environment should be passed through as-is"
)
Loading