Conversation
…il in ModelTrain (5504)
sagemaker-bot
left a comment
🤖 AI Code Review
This PR adds test coverage for PipelineVariable support in ModelTrainer hyperparameters, verifying that the existing safe_serialize function correctly handles ParameterInteger, ParameterString, and ParameterFloat objects. The tests are well-structured and follow SDK conventions. A few minor improvements could be made.
| """ModelTrainer hyperparameters should accept ParameterInteger (GH#5504). | ||
|
|
||
| This is the exact bug scenario: ParameterInteger in hyperparameters | ||
| caused TypeError in safe_serialize before the fix. |
This test name includes via_safe_serialize but it's only testing that the ModelTrainer constructor stores the hyperparameter — it doesn't actually exercise safe_serialize. The safe_serialize call happens in _create_training_job_args(), which is tested separately below. Consider renaming to something like test_hyperparameters_accept_parameter_integer to avoid implying safe_serialize is being tested here.
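As context for the docstring's claim, the pre-fix failure mode would be the generic `json.dumps` TypeError on objects it cannot encode. A minimal stand-alone demonstration, using a stand-in class rather than the real SDK type:

```python
import json

class ParameterInteger:
    """Stand-in for the SDK parameter type; not JSON-serializable by default."""
    def __init__(self, name, default_value=None):
        self.name = name
        self.default_value = default_value

error = None
try:
    # Without a PipelineVariable pass-through branch, serializing the raw
    # object this way raises TypeError.
    json.dumps(ParameterInteger(name="MaxDepth", default_value=5))
except TypeError as exc:
    error = str(exc)

assert error is not None and "JSON serializable" in error
```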
```python
def test_hyperparameters_accept_parameter_string_via_safe_serialize(self):
    """ModelTrainer hyperparameters should accept ParameterString (GH#5504)."""
    objective = ParameterString(name="Objective", default_value="reg:squarederror")
```
Same naming concern as above — via_safe_serialize is misleading since this test only checks constructor assignment, not serialization. Consider test_hyperparameters_accept_parameter_string.
| """ | ||
| max_depth = ParameterInteger(name="MaxDepth", default_value=5) | ||
| trainer = ModelTrainer( | ||
| training_image=DEFAULT_IMAGE, |
Good test — this is the most valuable one in the PR as it actually exercises the code path (_create_training_job_args → safe_serialize) that was broken in the original bug report. Consider also asserting that args["hyper_parameters"]["max_depth"] is an instance of PipelineVariable (in addition to the identity check) for extra clarity:
```python
assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable)
```
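For illustration, here is how the identity check and the suggested `isinstance` check fit together, using hypothetical stand-in classes rather than the real SDK imports (the shape of `args` is assumed from the test):

```python
# Stand-ins for sagemaker.core.workflow.parameters types, for illustration only.
class PipelineVariable:
    """Placeholder base class for pipeline parameters."""

class ParameterInteger(PipelineVariable):
    def __init__(self, name, default_value=None):
        self.name = name
        self.default_value = default_value

# Pretend this dict came back from _create_training_job_args().
max_depth = ParameterInteger(name="MaxDepth", default_value=5)
args = {"hyper_parameters": {"max_depth": max_depth}}

# Identity check from the test, plus the suggested isinstance check.
assert args["hyper_parameters"]["max_depth"] is max_depth
assert isinstance(args["hyper_parameters"]["max_depth"], PipelineVariable)
```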
🤖 Iteration #1 — Review Comments Addressed

Add tests for PipelineVariable support in ModelTrainer hyperparameters

This PR adds comprehensive test coverage for the exact bug scenario described in the issue.
Changes from review feedback
Comments reviewed: 3
Description
Add tests for PipelineVariable support in ModelTrainer hyperparameters
The `safe_serialize` function in `sagemaker-train/src/sagemaker/train/utils.py` already correctly handles `PipelineVariable` objects (including `ParameterInteger`, `ParameterString`, `ParameterFloat`) by checking `isinstance(data, PipelineVariable)` and returning the object as-is, avoiding the `TypeError` from `json.dumps()`.

The downstream `serialize()` function in `sagemaker-core` also handles `PipelineVariable` objects correctly by returning them as-is.

This PR adds comprehensive test coverage for the exact bug scenario described in the issue:
- `TestSafeSerializePipelineVariable`: direct unit tests for `safe_serialize()` with `ParameterInteger`, `ParameterString`, `ParameterFloat`, and plain values
- `TestModelTrainerHyperparametersPipelineVariable`: integration tests verifying that `ModelTrainer` accepts `PipelineVariable` objects in `hyperparameters` and that `_create_training_job_args()` preserves them correctly through `safe_serialize()`

Testing
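The two behaviors those test classes cover can be sketched together with stand-in types. `safe_serialize` below is an illustrative reimplementation of the contract described in this PR, and `build_hyperparameter_args` is a hypothetical helper mirroring the loop in `_create_training_job_args()`:

```python
import json

class PipelineVariable:
    """Stand-in for sagemaker.core.workflow.parameters.PipelineVariable."""

class ParameterInteger(PipelineVariable):
    def __init__(self, name, default_value=None):
        self.name = name
        self.default_value = default_value

def safe_serialize(data):
    # Sketch of the described contract: strings and PipelineVariables pass
    # through untouched; everything else is JSON-encoded.
    if isinstance(data, str):
        return data
    if isinstance(data, PipelineVariable):
        return data  # json.dumps(data) would raise TypeError here
    return json.dumps(data)

def build_hyperparameter_args(hyperparameters):
    # Mirrors the loop described for _create_training_job_args(): every
    # value goes through safe_serialize, so PipelineVariable objects
    # survive for later pipeline JSON rendering.
    return {key: safe_serialize(value) for key, value in hyperparameters.items()}

max_depth = ParameterInteger(name="MaxDepth", default_value=5)
args = build_hyperparameter_args({"max_depth": max_depth, "eta": 0.1})
assert args["max_depth"] is max_depth  # preserved as-is
assert args["eta"] == "0.1"            # plain value JSON-encoded
```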
Related Issue
Related issue: 5504
Changes Made
The issue described in aws#5504 has already been fixed in the current codebase. The `safe_serialize` function in `sagemaker-train/src/sagemaker/train/utils.py` already handles `PipelineVariable` objects correctly. Specifically:

`safe_serialize()` (lines 179-201) has an explicit `elif isinstance(data, PipelineVariable): return data` branch that returns `PipelineVariable` objects as-is, without attempting JSON serialization.

The import at the top of the file (`from sagemaker.core.workflow.parameters import PipelineVariable`) correctly imports the base `PipelineVariable` class that `ParameterInteger`, `ParameterString`, etc. all inherit from.

The `serialize()` function in `sagemaker-core/src/sagemaker/core/utils/utils.py` also handles `PipelineVariable` objects by returning them as-is, so downstream serialization in `_create_training_job_args()` (when using PipelineSession) also works correctly.

The `_create_training_job_args()` method in `model_trainer.py` iterates over hyperparameters and calls `safe_serialize(value)` for each value. When the session is a PipelineSession, the training request goes through `serialize()`, which preserves `PipelineVariable` objects for proper pipeline JSON rendering.

Additionally, the existing test file
`test_model_trainer_pipeline_variable.py` already contains tests for `PipelineVariable` acceptance in various `ModelTrainer` fields (training_image, algorithm_name, training_input_mode, environment), confirming the pattern is established.

AI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
`prefix: description` format