Skip to content

Commit 0be1e83

Browse files
committed
fix: address review comments (iteration #2)
1 parent 3c99488 commit 0be1e83

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

sagemaker-train/src/sagemaker/train/tuner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
1414

15-
from __future__ import absolute_import
15+
from __future__ import absolute_import, annotations
1616

1717
import logging
1818
from enum import Enum
@@ -445,17 +445,18 @@ def _prepare_auto_parameters(self, static_hyperparameters, hyperparameters_to_ke
445445
@staticmethod
446446
def _get_model_trainer_environment(
447447
model_trainer: "ModelTrainer",
448-
) -> Optional[Dict[str, str]]:
448+
) -> dict[str, str] | None:
449449
"""Extract environment variables from a ModelTrainer instance.
450450
451-
Returns the environment dict if it is non-empty, otherwise None.
451+
Returns a copy of the environment dict if it is non-empty,
452+
otherwise None.
452453
453454
Args:
454455
model_trainer (ModelTrainer): ModelTrainer instance.
455456
456457
Returns:
457-
Optional[Dict[str, str]]: Environment variables dict,
458-
or None if empty/not set.
458+
dict[str, str] | None: A copy of the environment variables
459+
dict, or None if empty/not set.
459460
"""
460461
env = model_trainer.environment
461462
if env:

sagemaker-train/tests/unit/train/test_tuner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,11 @@ def test_build_training_job_definition_includes_environment_variables(self):
601601
f"Environment should be {env_vars}, "
602602
f"got {definition.environment}"
603603
)
604+
# Verify defensive copy: the dict on the definition
605+
# should not be the same object as the original
606+
assert definition.environment is not env_vars, (
607+
"Environment should be a copy, not the same object"
608+
)
604609

605610
def test_build_training_job_definition_with_empty_environment(self):
606611
"""Test that empty env is not propagated to definition."""
@@ -656,6 +661,10 @@ def test_returns_environment_when_set(self):
656661
mock_trainer,
657662
)
658663
assert result == env_vars
664+
# Verify it's a copy, not the same object
665+
assert result is not env_vars, (
666+
"Should return a defensive copy"
667+
)
659668

660669
def test_returns_none_when_empty(self):
661670
"""Test that None is returned when environment is empty."""

0 commit comments

Comments
 (0)