-
Notifications
You must be signed in to change notification settings - Fork 0
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,14 +26,15 @@ | |
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger, ParameterFloat | ||
| from sagemaker.train.model_trainer import ModelTrainer, Mode | ||
| from sagemaker.train.configs import ( | ||
| Compute, | ||
| StoppingCondition, | ||
| OutputDataConfig, | ||
| ) | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE | ||
| from sagemaker.train.utils import safe_serialize | ||
|
|
||
|
|
||
| DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" | ||
|
|
@@ -117,6 +118,121 @@ def test_environment_values_accept_parameter_string(self): | |
| assert trainer.environment["STATIC_VAR"] == "hello" | ||
|
|
||
|
|
||
| class TestSafeSerializePipelineVariable: | ||
| """Test that safe_serialize correctly handles PipelineVariable objects (GH#5504).""" | ||
|
|
||
| def test_safe_serialize_returns_pipeline_variable_as_is(self): | ||
| """safe_serialize should return PipelineVariable objects without JSON serialization.""" | ||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
| def test_safe_serialize_returns_parameter_string_as_is(self): | ||
| """safe_serialize should return ParameterString objects without JSON serialization.""" | ||
| param = ParameterString(name="Algorithm", default_value="xgboost") | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
| def test_safe_serialize_returns_parameter_float_as_is(self): | ||
| """safe_serialize should return ParameterFloat objects without JSON serialization.""" | ||
| param = ParameterFloat(name="LearningRate", default_value=0.01) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
|
|
||
| def test_safe_serialize_still_handles_plain_string(self): | ||
| """safe_serialize should return plain strings as-is.""" | ||
| result = safe_serialize("hello") | ||
| assert result == "hello" | ||
|
|
||
| def test_safe_serialize_still_handles_int(self): | ||
| """safe_serialize should JSON-encode integers.""" | ||
| result = safe_serialize(42) | ||
| assert result == "42" | ||
|
|
||
| def test_safe_serialize_still_handles_dict(self): | ||
| """safe_serialize should JSON-encode dicts.""" | ||
| result = safe_serialize({"key": "value"}) | ||
| assert result == '{"key": "value"}' | ||
|
|
||
|
|
||
| class TestModelTrainerHyperparametersPipelineVariable: | ||
| """Test that ModelTrainer hyperparameters accept PipelineVariable objects (GH#5504).""" | ||
|
|
||
| def test_hyperparameters_accept_parameter_integer(self): | ||
| """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504). | ||
|
|
||
| This is the exact bug scenario: ParameterInteger in hyperparameters | ||
| caused TypeError in safe_serialize before the fix. | ||
| """ | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"max_depth": max_depth}, | ||
| ) | ||
| assert trainer.hyperparameters["max_depth"] is max_depth | ||
|
|
||
| def test_hyperparameters_accept_parameter_string(self): | ||
| """ModelTrainer hyperparameters should accept ParameterString (GH#5504).""" | ||
| objective = ParameterString(name="Objective", default_value="reg:squarederror") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same naming concern as above — the test name suggests `safe_serialize` is being exercised, but this test only checks that the constructor stores the hyperparameter. |
||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"objective": objective}, | ||
| ) | ||
| assert trainer.hyperparameters["objective"] is objective | ||
|
|
||
| def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self): | ||
| """ModelTrainer hyperparameters should accept a mix of PipelineVariable and plain values.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={ | ||
| "max_depth": max_depth, | ||
| "eta": 0.1, | ||
| "objective": "reg:squarederror", | ||
| }, | ||
| ) | ||
| assert trainer.hyperparameters["max_depth"] is max_depth | ||
| assert trainer.hyperparameters["eta"] == 0.1 | ||
| assert trainer.hyperparameters["objective"] == "reg:squarederror" | ||
|
|
||
| @patch("sagemaker.train.model_trainer._get_unique_name", return_value="test-job-20240101") | ||
| def test_create_training_job_args_preserves_pipeline_variable_in_hyperparameters( | ||
| self, mock_unique_name | ||
| ): | ||
| """_create_training_job_args should preserve PipelineVariable in hyper_parameters dict. | ||
|
|
||
| When safe_serialize is called on a PipelineVariable, it should return the | ||
| PipelineVariable object as-is, not attempt JSON serialization. | ||
| """ | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good test — this is the most valuable one in the PR, as it actually exercises the code path (`assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable)`). |
||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={"max_depth": max_depth, "eta": 0.1}, | ||
| ) | ||
| args = trainer._create_training_job_args() | ||
| # PipelineVariable should be preserved as-is by safe_serialize | ||
| assert args["hyper_parameters"]["max_depth"] is max_depth | ||
| assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable) | ||
| # Plain values should be JSON-serialized to strings | ||
| assert args["hyper_parameters"]["eta"] == "0.1" | ||
|
|
||
|
|
||
| class TestModelTrainerRealValuesStillWork: | ||
| """Regression tests: verify that passing real values still works after the change.""" | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This test name includes `via_safe_serialize`, but it's only testing that the `ModelTrainer` constructor stores the hyperparameter — it doesn't actually exercise `safe_serialize`. The `safe_serialize` call happens in `_create_training_job_args()`, which is tested separately below. Consider renaming to something like `test_hyperparameters_accept_parameter_integer` to avoid implying `safe_serialize` is being tested here.