Skip to content

Commit 862ff2d

Browse files
committed
fix: address review comments (iteration #2)
1 parent dec47ab commit 862ff2d

File tree

3 files changed

+60
-10
lines changed

3 files changed

+60
-10
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,11 +1505,15 @@ def _build_training_job_definition(self, inputs):
15051505
)
15061506

15071507
# Get environment variables from model_trainer.
1508-
# environment is a defined attribute on ModelTrainer (dict | None).
1509-
# We pass it through as-is; even an empty dict is valid for the API.
1508+
# environment is a defined attribute on ModelTrainer (typed as dict | None).
1509+
# We access it directly (consistent with how role, compute, etc. are accessed).
1510+
# We pass it through as-is when it's a dict — even an empty dict is valid for the API.
1511+
# When it's None or not a dict, we omit it from the constructor so the Pydantic
1512+
# model keeps its default (Unassigned), which is then excluded during serialization.
15101513
env = model_trainer.environment
15111514

1512-
definition = HyperParameterTrainingJobDefinition(
1515+
# Build base kwargs for the definition
1516+
definition_kwargs = dict(
15131517
algorithm_specification=algorithm_spec,
15141518
role_arn=model_trainer.role,
15151519
input_data_config=input_data_config if input_data_config else None,
@@ -1518,9 +1522,16 @@ def _build_training_job_definition(self, inputs):
15181522
stopping_condition=stopping_condition,
15191523
static_hyper_parameters=getattr(self, "static_hyperparameters", None) or {},
15201524
enable_managed_spot_training=model_trainer.compute.enable_managed_spot_training,
1521-
environment=env,
15221525
)
15231526

1527+
# Only include environment when it's a dict (including empty dict).
1528+
# This avoids Pydantic validation errors for non-dict values and keeps
1529+
# the field as Unassigned (excluded from serialization) when not set.
1530+
if isinstance(env, dict):
1531+
definition_kwargs["environment"] = env
1532+
1533+
definition = HyperParameterTrainingJobDefinition(**definition_kwargs)
1534+
15241535
# Pass through VPC config from model_trainer
15251536
networking = getattr(model_trainer, "networking", None)
15261537
if networking and hasattr(networking, "_to_vpc_config"):

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,13 @@ def test_build_training_job_definition_includes_environment_variables(self):
624624
}, "Environment variables should match those set on ModelTrainer"
625625

626626
def test_build_training_job_definition_with_none_environment(self):
627-
"""Test that _build_training_job_definition handles None environment gracefully."""
627+
"""Test that _build_training_job_definition handles None environment gracefully.
628+
629+
When environment is None, it should not be passed to the Pydantic constructor,
630+
so the field stays as Unassigned (excluded from serialization).
631+
"""
632+
from sagemaker.core.utils.utils import Unassigned
633+
628634
mock_trainer = _create_mock_model_trainer()
629635
mock_trainer.environment = None
630636

@@ -636,7 +642,9 @@ def test_build_training_job_definition_with_none_environment(self):
636642

637643
definition = tuner._build_training_job_definition(None)
638644

639-
assert definition.environment is None, "Environment should be None when not set"
645+
assert isinstance(definition.environment, Unassigned), (
646+
"Environment should be Unassigned when model_trainer.environment is None"
647+
)
640648

641649
def test_build_training_job_definition_with_empty_environment(self):
642650
"""Test that _build_training_job_definition passes through empty environment.

sagemaker-train/tests/unit/train/test_tuner_driver_channels.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,31 @@ def test_passes_environment_variables(self):
405405
definition = tuner._build_training_job_definition(inputs=None)
406406
assert definition.environment == {"MY_VAR": "value", "OTHER": "123"}
407407

408+
def test_passes_empty_environment(self):
409+
"""Should pass through empty dict environment as-is.
410+
411+
An empty dict is valid for the SageMaker API, so we pass it through
412+
rather than silently converting it to None/Unassigned.
413+
"""
414+
trainer = _mock_model_trainer(environment={})
415+
416+
tuner = HyperparameterTuner(
417+
model_trainer=trainer,
418+
objective_metric_name="accuracy",
419+
hyperparameter_ranges=_hp_ranges(),
420+
)
421+
422+
definition = tuner._build_training_job_definition(inputs=None)
423+
assert definition.environment == {}, (
424+
"Empty dict environment should be passed through as-is"
425+
)
426+
408427
def test_skips_environment_when_none(self):
409-
"""Should not set environment when model_trainer.environment is None."""
428+
"""Should not set environment when model_trainer.environment is None.
429+
430+
When environment is None, it is not passed to the Pydantic constructor,
431+
so the field stays as Unassigned (excluded from serialization).
432+
"""
410433
trainer = _mock_model_trainer(environment=None)
411434

412435
tuner = HyperparameterTuner(
@@ -416,10 +439,16 @@ def test_skips_environment_when_none(self):
416439
)
417440

418441
definition = tuner._build_training_job_definition(inputs=None)
419-
assert _is_unassigned(definition.environment)
442+
assert _is_unassigned(definition.environment), (
443+
"Environment should be Unassigned when model_trainer.environment is None"
444+
)
420445

421446
def test_skips_environment_when_not_dict(self):
422-
"""Should not set environment when it's not a dict (e.g. MagicMock)."""
447+
"""Should not set environment when it's not a dict (e.g. MagicMock).
448+
449+
Non-dict values are not passed to the Pydantic constructor to avoid
450+
validation errors. The field stays as Unassigned.
451+
"""
423452
trainer = _mock_model_trainer(environment=MagicMock())
424453

425454
tuner = HyperparameterTuner(
@@ -429,7 +458,9 @@ def test_skips_environment_when_not_dict(self):
429458
)
430459

431460
definition = tuner._build_training_job_definition(inputs=None)
432-
assert _is_unassigned(definition.environment)
461+
assert _is_unassigned(definition.environment), (
462+
"Environment should be Unassigned when model_trainer.environment is not a dict"
463+
)
433464

434465
def test_passes_vpc_config(self):
435466
"""Should set definition.vpc_config from model_trainer.networking._to_vpc_config()."""

0 commit comments

Comments
 (0)