-
Notifications
You must be signed in to change notification settings - Fork 0
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,13 +10,18 @@ | |
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Tests for PipelineVariable support in ModelTrainer (GH#5524). | ||
| """Tests for PipelineVariable support in ModelTrainer. | ||
|
|
||
| Verifies that ModelTrainer fields accept PipelineVariable objects | ||
| Verify that ModelTrainer fields accept PipelineVariable objects | ||
| (e.g., ParameterString) in addition to their concrete types, following | ||
| the existing V3 pattern established by SourceCode and OutputDataConfig. | ||
|
|
||
| See: https://github.com/aws/sagemaker-python-sdk/issues/5524 | ||
| Also verify that safe_serialize correctly handles PipelineVariable objects | ||
| in hyperparameters (returning them as-is instead of attempting json.dumps), | ||
| and that _create_training_job_args preserves PipelineVariable objects through | ||
| the serialization pipeline. | ||
|
|
||
| See: https://github.com/aws/sagemaker-python-sdk/issues/5504 | ||
| """ | ||
| from __future__ import absolute_import | ||
|
|
||
|
|
@@ -26,13 +31,18 @@ | |
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
| from sagemaker.core.workflow.parameters import ParameterString | ||
| from sagemaker.core.workflow.parameters import ( | ||
| ParameterString, | ||
| ParameterInteger, | ||
| ParameterFloat, | ||
| ) | ||
| from sagemaker.train.model_trainer import ModelTrainer, Mode | ||
| from sagemaker.train.configs import ( | ||
| Compute, | ||
| StoppingCondition, | ||
| OutputDataConfig, | ||
| ) | ||
| from sagemaker.train.utils import safe_serialize | ||
| from sagemaker.train.defaults import DEFAULT_INSTANCE_TYPE | ||
|
|
||
|
|
||
|
|
@@ -48,7 +58,7 @@ | |
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module", autouse=True) | ||
| @pytest.fixture(scope="module") | ||
| def modules_session(): | ||
| with patch("sagemaker.train.Session", spec=Session) as session_mock: | ||
| session_instance = session_mock.return_value | ||
|
|
@@ -64,7 +74,7 @@ class TestModelTrainerPipelineVariableAcceptance: | |
| """Test that ModelTrainer fields accept PipelineVariable objects.""" | ||
|
|
||
| def test_training_image_accepts_parameter_string(self): | ||
| """ModelTrainer.training_image should accept ParameterString (GH#5524).""" | ||
| """Verify ModelTrainer.training_image accepts ParameterString (GH#5504).""" | ||
| param = ParameterString(name="TrainingImage", default_value=DEFAULT_IMAGE) | ||
| trainer = ModelTrainer( | ||
| training_image=param, | ||
|
|
@@ -77,7 +87,7 @@ def test_training_image_accepts_parameter_string(self): | |
| assert trainer.training_image is param | ||
|
|
||
| def test_algorithm_name_accepts_parameter_string(self): | ||
| """ModelTrainer.algorithm_name should accept ParameterString.""" | ||
| """Verify ModelTrainer.algorithm_name accepts ParameterString.""" | ||
| param = ParameterString(name="AlgorithmName", default_value="my-algo-arn") | ||
| trainer = ModelTrainer( | ||
| algorithm_name=param, | ||
|
|
@@ -90,7 +100,7 @@ def test_algorithm_name_accepts_parameter_string(self): | |
| assert trainer.algorithm_name is param | ||
|
|
||
| def test_training_input_mode_accepts_parameter_string(self): | ||
| """ModelTrainer.training_input_mode should accept ParameterString.""" | ||
| """Verify ModelTrainer.training_input_mode accepts ParameterString.""" | ||
| param = ParameterString(name="InputMode", default_value="File") | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
|
|
@@ -103,7 +113,7 @@ def test_training_input_mode_accepts_parameter_string(self): | |
| assert trainer.training_input_mode is param | ||
|
|
||
| def test_environment_values_accept_parameter_string(self): | ||
| """ModelTrainer.environment dict values should accept ParameterString.""" | ||
| """Verify ModelTrainer.environment dict values accept ParameterString.""" | ||
| param = ParameterString(name="DatasetVersion", default_value="v1") | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
|
|
@@ -121,7 +131,7 @@ class TestModelTrainerRealValuesStillWork: | |
| """Regression tests: verify that passing real values still works after the change.""" | ||
|
|
||
| def test_training_image_accepts_real_string(self): | ||
| """ModelTrainer.training_image should still accept a plain string.""" | ||
| """Verify ModelTrainer.training_image still accepts a plain string.""" | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
|
|
@@ -132,7 +142,7 @@ def test_training_image_accepts_real_string(self): | |
| assert trainer.training_image == DEFAULT_IMAGE | ||
|
|
||
| def test_algorithm_name_accepts_real_string(self): | ||
| """ModelTrainer.algorithm_name should still accept a plain string.""" | ||
| """Verify ModelTrainer.algorithm_name still accepts a plain string.""" | ||
| trainer = ModelTrainer( | ||
| algorithm_name="arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo", | ||
| role=DEFAULT_ROLE, | ||
|
|
@@ -143,7 +153,7 @@ def test_algorithm_name_accepts_real_string(self): | |
| assert trainer.algorithm_name == "arn:aws:sagemaker:us-west-2:000000000000:algorithm/my-algo" | ||
|
|
||
| def test_training_input_mode_accepts_real_string(self): | ||
| """ModelTrainer.training_input_mode should still accept a plain string.""" | ||
| """Verify ModelTrainer.training_input_mode still accepts a plain string.""" | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| training_input_mode="Pipe", | ||
|
|
@@ -155,7 +165,7 @@ def test_training_input_mode_accepts_real_string(self): | |
| assert trainer.training_input_mode == "Pipe" | ||
|
|
||
| def test_environment_accepts_real_string_values(self): | ||
| """ModelTrainer.environment should still accept plain string values.""" | ||
| """Verify ModelTrainer.environment still accepts plain string values.""" | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| environment={"KEY1": "value1", "KEY2": "value2"}, | ||
|
|
@@ -167,7 +177,7 @@ def test_environment_accepts_real_string_values(self): | |
| assert trainer.environment == {"KEY1": "value1", "KEY2": "value2"} | ||
|
|
||
| def test_training_image_rejects_invalid_type(self): | ||
| """ModelTrainer.training_image should still reject invalid types (e.g., int).""" | ||
| """Verify ModelTrainer.training_image still rejects invalid types (e.g., int).""" | ||
| with pytest.raises(ValidationError): | ||
| ModelTrainer( | ||
| training_image=12345, | ||
|
|
@@ -176,3 +186,99 @@ def test_training_image_rejects_invalid_type(self): | |
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| ) | ||
|
|
||
|
|
||
| class TestSafeSerializeWithPipelineVariables: | ||
| """Verify that safe_serialize handles PipelineVariable objects correctly. | ||
|
|
||
| The safe_serialize function must return PipelineVariable objects as-is | ||
| instead of attempting json.dumps(), which would raise TypeError. | ||
| See: https://github.com/aws/sagemaker-python-sdk/issues/5504 | ||
| """ | ||
|
|
||
| @pytest.mark.parametrize("param", [ | ||
| ParameterInteger(name="MaxDepth", default_value=5), | ||
| ParameterString(name="Optimizer", default_value="adam"), | ||
| ParameterFloat(name="LearningRate", default_value=0.01), | ||
| ]) | ||
| def test_safe_serialize_returns_pipeline_variable_as_is(self, param): | ||
| """Verify safe_serialize returns PipelineVariable objects as-is.""" | ||
| result = safe_serialize(param) | ||
| assert result is param | ||
| assert isinstance(result, PipelineVariable) | ||
|
|
||
| @pytest.mark.parametrize("input_val,expected", [ | ||
| ("hello", "hello"), | ||
| (42, "42"), | ||
| ({"key": "value"}, '{"key": "value"}'), | ||
| (0.01, "0.01"), | ||
| (True, "true"), | ||
| (False, "false"), | ||
| ]) | ||
| def test_safe_serialize_handles_normal_types(self, input_val, expected): | ||
| """Verify safe_serialize correctly serializes normal (non-PipelineVariable) types.""" | ||
| result = safe_serialize(input_val) | ||
| assert result == expected | ||
|
|
||
|
|
||
| class TestModelTrainerHyperparametersWithPipelineVariables: | ||
| """Verify that ModelTrainer accepts PipelineVariable objects in hyperparameters. | ||
|
|
||
| See: https://github.com/aws/sagemaker-python-sdk/issues/5504 | ||
| """ | ||
|
|
||
| def test_hyperparameters_accept_pipeline_variable_values(self): | ||
| """Verify ModelTrainer accepts PipelineVariable objects as hyperparameter values.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| learning_rate = ParameterFloat(name="LearningRate", default_value=0.01) | ||
| optimizer = ParameterString(name="Optimizer", default_value="adam") | ||
|
|
||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| hyperparameters={ | ||
| "max_depth": max_depth, | ||
| "learning_rate": learning_rate, | ||
| "optimizer": optimizer, | ||
| "static_param": 10, | ||
| }, | ||
| ) | ||
| assert trainer.hyperparameters["max_depth"] is max_depth | ||
| assert trainer.hyperparameters["learning_rate"] is learning_rate | ||
| assert trainer.hyperparameters["optimizer"] is optimizer | ||
| assert trainer.hyperparameters["static_param"] == 10 | ||
|
|
||
| def test_create_training_job_args_with_pipeline_variable_hyperparameters( | ||
| self, modules_session | ||
| ): | ||
| """Verify _create_training_job_args preserves PipelineVariable in hyper_parameters.""" | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| learning_rate = ParameterFloat(name="LearningRate", default_value=0.01) | ||
|
|
||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, | ||
| role=DEFAULT_ROLE, | ||
| compute=DEFAULT_COMPUTE, | ||
| stopping_condition=DEFAULT_STOPPING, | ||
| output_data_config=DEFAULT_OUTPUT, | ||
| sagemaker_session=modules_session, | ||
| hyperparameters={ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Reviewer comment] This test calls a fixture-dependent helper (the referenced call was lost in rendering — likely `_create_training_job_args`). Please verify this test passes in isolation ( |
||
| "max_depth": max_depth, | ||
| "learning_rate": learning_rate, | ||
| "epochs": 10, | ||
| "verbose": "true", | ||
| }, | ||
| ) | ||
|
|
||
| training_args = trainer._create_training_job_args() | ||
| hyper_params = training_args["hyper_parameters"] | ||
|
|
||
| # PipelineVariable objects should be preserved as-is by safe_serialize | ||
| assert hyper_params["max_depth"] is max_depth | ||
| assert hyper_params["learning_rate"] is learning_rate | ||
| # Regular values should be serialized to strings | ||
| assert hyper_params["epochs"] == "10" | ||
| assert hyper_params["verbose"] == "true" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add back the GitHub issue number in this particular comment? You can also add it back on line 19 where it previously said "See: aws#5524"