-
Notifications
You must be signed in to change notification settings - Fork 0
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ | |
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ParameterString, ParameterInteger | ||
| from sagemaker.train.model_trainer import ModelTrainer, Mode | ||
| from sagemaker.train.configs import ( | ||
| Compute, | ||
|
|
@@ -116,6 +116,123 @@ def test_environment_values_accept_parameter_string(self): | |
| assert trainer.environment["DATASET_VERSION"] is param | ||
| assert trainer.environment["STATIC_VAR"] == "hello" | ||
|
|
||
| def test_hyperparameters_accept_parameter_integer(self): | ||
| """ModelTrainer.hyperparameters should accept ParameterInteger values (GH#5504).""" | ||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| hyperparameters={"max_depth": param}, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
| assert trainer.hyperparameters["max_depth"] is param | ||
|
aviruthen marked this conversation as resolved.
|
||
|
|
||
| def test_hyperparameters_accept_parameter_string(self): | ||
| """ModelTrainer.hyperparameters should accept ParameterString values (GH#5504).""" | ||
| param = ParameterString(name="Algorithm", default_value="xgboost") | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| hyperparameters={"algorithm": param}, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
| assert trainer.hyperparameters["algorithm"] is param | ||
|
|
||
| def test_hyperparameters_accept_mixed_pipeline_and_plain_values(self): | ||
| """ModelTrainer.hyperparameters should accept a mix of PipelineVariable and plain values. | ||
|
|
||
| Regression test for GH#5504. | ||
| """ | ||
| param_int = ParameterInteger(name="MaxDepth", default_value=5) | ||
| param_str = ParameterString(name="Objective", default_value="reg:squarederror") | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| hyperparameters={ | ||
| "max_depth": param_int, | ||
| "objective": param_str, | ||
| "eta": 0.1, | ||
| "num_round": "100", | ||
| }, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
| assert trainer.hyperparameters["max_depth"] is param_int | ||
| assert trainer.hyperparameters["objective"] is param_str | ||
| assert trainer.hyperparameters["eta"] == 0.1 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Importing `safe_serialize` inside individual tests is repetitive — import it once at the top of the file, with the other imports:
from sagemaker.train.utils import safe_serialize
||
| assert trainer.hyperparameters["num_round"] == "100" | ||
|
|
||
|
|
||
| class TestSafeSerializePipelineVariable: | ||
| """Test that safe_serialize correctly preserves PipelineVariable objects (GH#5504).""" | ||
|
|
||
| def test_safe_serialize_preserves_parameter_integer(self): | ||
| """safe_serialize should return PipelineVariable as-is, not stringify it.""" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same inline import issue — `safe_serialize` should be imported once at module level rather than inside the test.
||
| from sagemaker.train.utils import safe_serialize | ||
|
|
||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
| assert isinstance(result, PipelineVariable) | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And again here — third inline import of `safe_serialize`; please consolidate it at the top of the file.
||
| def test_safe_serialize_preserves_parameter_string(self): | ||
| """safe_serialize should return ParameterString as-is.""" | ||
| from sagemaker.train.utils import safe_serialize | ||
|
|
||
| param = ParameterString(name="Objective", default_value="reg:squarederror") | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
| assert isinstance(result, PipelineVariable) | ||
|
|
||
| def test_safe_serialize_still_serializes_plain_values(self): | ||
| """safe_serialize should still JSON-serialize plain values.""" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The inline import appears here as well. Also, does [remainder of comment truncated in page capture]
||
| from sagemaker.train.utils import safe_serialize | ||
|
|
||
| assert safe_serialize(42) == "42" | ||
| assert safe_serialize("hello") == '"hello"' | ||
| assert safe_serialize(0.1) == "0.1" | ||
|
|
||
| def test_create_training_job_args_preserves_pipeline_hyperparameters( | ||
| self, | ||
| ): | ||
| """_create_training_job_args should preserve PipelineVariable in hyperparameters. | ||
|
|
||
| Regression test for GH#5504. | ||
| """ | ||
| param_int = ParameterInteger(name="MaxDepth", default_value=5) | ||
| param_str = ParameterString( | ||
| name="Objective", default_value="reg:squarederror" | ||
| ) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| hyperparameters={ | ||
| "max_depth": param_int, | ||
| "objective": param_str, | ||
| "eta": 0.1, | ||
| "num_round": "100", | ||
| }, | ||
| base_job_name="pipeline-test-job", | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
| args = trainer._create_training_job_args() | ||
| hp = args["hyperparameters"] | ||
| assert hp["max_depth"] is param_int | ||
| assert hp["objective"] is param_str | ||
| # Plain values should be JSON-serialized strings | ||
| assert hp["eta"] == "0.1" | ||
| assert hp["num_round"] == '"100"' | ||
|
|
||
|
|
||
| class TestModelTrainerRealValuesStillWork: | ||
| """Regression tests: verify that passing real values still works after the change.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
`ParameterInteger` is imported, but `ParameterFloat` is mentioned in the PR description as being imported. The PR description says "Added `ParameterInteger` and `ParameterFloat` imports", but only `ParameterInteger` is actually imported here. The description is slightly misleading, though the code itself is correct since `ParameterFloat` isn't used in any test.