Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: ParameterInteger is imported but ParameterFloat is mentioned in the PR description as being imported. The PR description says "Added ParameterInteger and ParameterFloat imports" but only ParameterInteger is actually imported here. The description is slightly misleading, though the code itself is correct since ParameterFloat isn't used in any test.

from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.train.model_trainer import ModelTrainer, Mode
from sagemaker.train.configs import (
Compute,
Expand Down Expand Up @@ -116,6 +116,123 @@ def test_environment_values_accept_parameter_string(self):
assert trainer.environment["DATASET_VERSION"] is param
assert trainer.environment["STATIC_VAR"] == "hello"

def test_hyperparameters_accept_parameter_integer(self):
    """ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504)."""
    max_depth = ParameterInteger(name="MaxDepth", default_value=5)
    # Shared constructor boilerplate gathered once; only hyperparameters varies.
    trainer = ModelTrainer(
        hyperparameters={"max_depth": max_depth},
        training_image=DEFAULT_IMAGE,
        base_job_name="pipeline-test-job",
        role=DEFAULT_ROLE,
        compute=DEFAULT_COMPUTE,
        stopping_condition=DEFAULT_STOPPING,
        output_data_config=DEFAULT_OUTPUT,
    )
    # Identity check: the pipeline parameter must survive untouched, not be
    # coerced to a string during construction.
    assert trainer.hyperparameters["max_depth"] is max_depth
Comment thread
aviruthen marked this conversation as resolved.

def test_hyperparameters_accept_parameter_string(self):
    """ModelTrainer.hyperparameters should accept ParameterString values (GH#5504)."""
    algorithm = ParameterString(name="Algorithm", default_value="xgboost")
    trainer = ModelTrainer(
        hyperparameters={"algorithm": algorithm},
        training_image=DEFAULT_IMAGE,
        base_job_name="pipeline-test-job",
        role=DEFAULT_ROLE,
        compute=DEFAULT_COMPUTE,
        stopping_condition=DEFAULT_STOPPING,
        output_data_config=DEFAULT_OUTPUT,
    )
    # The exact ParameterString object must be preserved (identity, not equality).
    assert trainer.hyperparameters["algorithm"] is algorithm

def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self):
    """ModelTrainer.hyperparameters should accept a mix of PipelineVariable and plain values.

    Regression test for GH#5504.
    """
    param_int = ParameterInteger(name="MaxDepth", default_value=5)
    param_str = ParameterString(name="Objective", default_value="reg:squarederror")
    trainer = ModelTrainer(
        training_image=DEFAULT_IMAGE,
        hyperparameters={
            "max_depth": param_int,
            "objective": param_str,
            "eta": 0.1,
            "num_round": "100",
        },
        base_job_name="pipeline-test-job",
        role=DEFAULT_ROLE,
        compute=DEFAULT_COMPUTE,
        stopping_condition=DEFAULT_STOPPING,
        output_data_config=DEFAULT_OUTPUT,
    )
    # Pipeline variables must be preserved by identity...
    assert trainer.hyperparameters["max_depth"] is param_int
    assert trainer.hyperparameters["objective"] is param_str
    # ...while plain values are stored unchanged (no serialization at this stage).
    assert trainer.hyperparameters["eta"] == 0.1
    assert trainer.hyperparameters["num_round"] == "100"


class TestSafeSerializePipelineVariable:
    """Test that safe_serialize correctly preserves PipelineVariable objects (GH#5504)."""

    # NOTE(review): safe_serialize is imported inside each test that uses it.
    # Consider hoisting a single module-level
    # `from sagemaker.train.utils import safe_serialize` once the top-of-file
    # import block is settled, to avoid the repeated inline imports.

    def test_safe_serialize_preserves_parameter_integer(self):
        """safe_serialize should return PipelineVariable as-is, not stringify it."""
        from sagemaker.train.utils import safe_serialize

        param = ParameterInteger(name="MaxDepth", default_value=5)
        result = safe_serialize(param)
        # Identity: the same object comes back, so downstream pipeline
        # compilation can still resolve it as a parameter reference.
        assert result is param
        assert isinstance(result, PipelineVariable)

    def test_safe_serialize_preserves_parameter_string(self):
        """safe_serialize should return ParameterString as-is."""
        from sagemaker.train.utils import safe_serialize

        param = ParameterString(name="Objective", default_value="reg:squarederror")
        result = safe_serialize(param)
        assert result is param
        assert isinstance(result, PipelineVariable)

    def test_safe_serialize_still_serializes_plain_values(self):
        """safe_serialize should still JSON-serialize plain values."""
        from sagemaker.train.utils import safe_serialize

        # Plain values are JSON-encoded: note strings gain quotes.
        assert safe_serialize(42) == "42"
        assert safe_serialize("hello") == '"hello"'
        assert safe_serialize(0.1) == "0.1"

    def test_create_training_job_args_preserves_pipeline_hyperparameters(self):
        """_create_training_job_args should preserve PipelineVariable in hyperparameters.

        Regression test for GH#5504.

        NOTE: intentionally exercises the private _create_training_job_args()
        to pin internal behavior; this test will need updating if that
        method's signature changes.
        """
        param_int = ParameterInteger(name="MaxDepth", default_value=5)
        param_str = ParameterString(
            name="Objective", default_value="reg:squarederror"
        )
        trainer = ModelTrainer(
            training_image=DEFAULT_IMAGE,
            hyperparameters={
                "max_depth": param_int,
                "objective": param_str,
                "eta": 0.1,
                "num_round": "100",
            },
            base_job_name="pipeline-test-job",
            role=DEFAULT_ROLE,
            compute=DEFAULT_COMPUTE,
            stopping_condition=DEFAULT_STOPPING,
            output_data_config=DEFAULT_OUTPUT,
        )
        args = trainer._create_training_job_args()
        hp = args["hyperparameters"]
        # Pipeline variables pass through untouched...
        assert hp["max_depth"] is param_int
        assert hp["objective"] is param_str
        # ...while plain values are JSON-serialized to strings.
        assert hp["eta"] == "0.1"
        assert hp["num_round"] == '"100"'


class TestModelTrainerRealValuesStillWork:
"""Regression tests: verify that passing real values still works after the change."""
Expand Down
Loading