-
Notifications
You must be signed in to change notification settings - Fork 0
fix: [Bug] Pipeline parameters (ParameterInteger, ParameterString) fail in ModelTrain (5504) #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,7 +116,7 @@ | |
| from sagemaker.core.jumpstart.utils import get_eula_url | ||
| from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline | ||
| from sagemaker.core.helper.pipeline_variable import StrPipeVar | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable, StrPipeVar | ||
|
|
||
| from sagemaker.train.local.local_container import _LocalContainer | ||
|
|
||
|
|
@@ -410,14 +410,21 @@ def __del__(self): | |
| self._temp_code_dir.cleanup() | ||
|
|
||
| def _validate_training_image_and_algorithm_name( | ||
| self, training_image: Optional[str], algorithm_name: Optional[str] | ||
| self, | ||
| training_image: Optional[StrPipeVar], | ||
| algorithm_name: Optional[StrPipeVar], | ||
| ): | ||
| """Validate that only one of 'training_image' or 'algorithm_name' is provided.""" | ||
|
aviruthen marked this conversation as resolved.
|
||
| if not training_image and not algorithm_name: | ||
| # PipelineVariable objects do not support standard boolean coercion | ||
| # (__bool__ raises TypeError), so we use isinstance checks to detect | ||
| # them as truthy values during validation. | ||
| has_image = isinstance(training_image, PipelineVariable) or bool(training_image) | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic has_image = training_image is not None and training_image != ""
has_algo = algorithm_name is not None and algorithm_name != ""This avoids calling |
||
| has_algo = isinstance(algorithm_name, PipelineVariable) or bool(algorithm_name) | ||
| if not has_image and not has_algo: | ||
|
aviruthen marked this conversation as resolved.
|
||
| raise ValueError( | ||
| "Atleast one of 'training_image' or 'algorithm_name' must be provided.", | ||
| ) | ||
| if training_image and algorithm_name: | ||
| if has_image and has_algo: | ||
| raise ValueError( | ||
| "Only one of 'training_image' or 'algorithm_name' must be provided.", | ||
| ) | ||
|
|
@@ -546,9 +553,11 @@ def model_post_init(self, __context: Any): | |
| ) | ||
|
|
||
| if self.training_image: | ||
|
aviruthen marked this conversation as resolved.
|
||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable | ||
| if isinstance(self.training_image, PipelineVariable): | ||
| logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)") | ||
| logger.info( | ||
| "Training image URI: " | ||
| "(PipelineVariable - resolved at pipeline execution)" | ||
| ) | ||
| else: | ||
| logger.info(f"Training image URI: {self.training_image}") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,11 +23,15 @@ | |
| from datetime import datetime | ||
| from typing import Literal, Any | ||
|
|
||
| from typing import Union | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Duplicate import: from typing import Literal, Any, UnionAlso, since the module already imports |
||
|
|
||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.shapes import Unassigned | ||
| from sagemaker.train import logger | ||
| from sagemaker.core.workflow.parameters import PipelineVariable | ||
|
|
||
| _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER = "pipeline-variable-image" | ||
|
|
||
|
|
||
| def _default_bucket_and_prefix(session: Session) -> str: | ||
| """Helper function to get the bucket name with the corresponding prefix if applicable | ||
|
|
@@ -142,7 +146,7 @@ def _get_unique_name(base, max_length=63): | |
| return unique_name | ||
|
|
||
|
|
||
| def _get_repo_name_from_image(image: str) -> str: | ||
| def _get_repo_name_from_image(image: Union[str, PipelineVariable]) -> str: | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The return type annotation says |
||
| """Get the repository name from the image URI. | ||
|
aviruthen marked this conversation as resolved.
|
||
|
|
||
| Example: | ||
|
|
@@ -152,11 +156,13 @@ def _get_repo_name_from_image(image: str) -> str: | |
| ``` | ||
|
|
||
| Args: | ||
| image (str): The image URI | ||
| image (str or PipelineVariable): The image URI. | ||
|
|
||
| Returns: | ||
|
aviruthen marked this conversation as resolved.
|
||
| str: The repository name | ||
| """ | ||
| if isinstance(image, PipelineVariable): | ||
|
aviruthen marked this conversation as resolved.
|
||
| return _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER | ||
| return image.split("/")[-1].split(":")[0].split("@")[0] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,281 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Tests for PipelineVariable support in ModelTrainer.""" | ||
| from __future__ import annotations | ||
|
|
||
|
aviruthen marked this conversation as resolved.
|
||
| import pytest | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from sagemaker.core.workflow.parameters import ( | ||
| ParameterString, | ||
| ParameterInteger, | ||
| ) | ||
| from sagemaker.core.helper.pipeline_variable import PipelineVariable | ||
| from sagemaker.train.utils import ( | ||
| safe_serialize, | ||
| _get_repo_name_from_image, | ||
| _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER, | ||
| ) | ||
|
|
||
| _TEST_IMAGE_URI = ( | ||
| "683313688378.dkr.ecr.us-east-1.amazonaws.com/" | ||
| "sagemaker-xgboost:1.0-1-cpu-py3" | ||
| ) | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: This test image URI contains a hardcoded region ( |
||
|
|
||
|
aviruthen marked this conversation as resolved.
|
||
|
|
||
| class TestSafeSerializeWithPipelineVariable: | ||
| """Tests for safe_serialize handling of PipelineVariable objects.""" | ||
|
|
||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| def test_safe_serialize_string(self): | ||
| """Test that plain strings are returned as-is.""" | ||
| assert safe_serialize("hello") == "hello" | ||
|
|
||
| def test_safe_serialize_int(self): | ||
| """Test that integers are JSON-serialized.""" | ||
| assert safe_serialize(5) == "5" | ||
|
|
||
| def test_safe_serialize_float(self): | ||
| """Test that floats are JSON-serialized.""" | ||
| assert safe_serialize(3.14) == "3.14" | ||
|
|
||
| def test_safe_serialize_dict(self): | ||
| """Test that dicts are JSON-serialized.""" | ||
| result = safe_serialize({"key": "value"}) | ||
| assert result == '{"key": "value"}' | ||
|
|
||
| def test_safe_serialize_pipeline_variable_parameter_string(self): | ||
| """Test that ParameterString is returned as the PipelineVariable object itself.""" | ||
| param = ParameterString(name="MyParam", default_value="test") | ||
| result = safe_serialize(param) | ||
| # Should return the PipelineVariable object, not raise TypeError | ||
| assert isinstance(result, PipelineVariable) | ||
| assert result is param | ||
|
|
||
| def test_safe_serialize_pipeline_variable_parameter_integer(self): | ||
| """Test that ParameterInteger is returned as the PipelineVariable object itself.""" | ||
| param = ParameterInteger(name="MaxDepth", default_value=5) | ||
| result = safe_serialize(param) | ||
| # Should return the PipelineVariable object, not raise TypeError | ||
| assert isinstance(result, PipelineVariable) | ||
| assert result is param | ||
|
|
||
|
|
||
| class TestGetRepoNameFromImage: | ||
| """Tests for _get_repo_name_from_image handling of PipelineVariable objects.""" | ||
|
|
||
| def test_get_repo_name_from_image_string(self): | ||
| """Test that a normal image URI returns the repo name.""" | ||
| result = _get_repo_name_from_image(_TEST_IMAGE_URI) | ||
| assert result == "sagemaker-xgboost" | ||
|
|
||
| def test_get_repo_name_from_image_pipeline_variable(self): | ||
| """Test that a PipelineVariable returns the placeholder constant.""" | ||
| param = ParameterString( | ||
| name="TrainingImage", default_value="some-image" | ||
| ) | ||
| result = _get_repo_name_from_image(param) | ||
| assert result == _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER | ||
|
|
||
| def test_get_repo_name_from_image_simple_string(self): | ||
| """Test with a simple image name.""" | ||
| result = _get_repo_name_from_image("my-repo:latest") | ||
| assert result == "my-repo" | ||
|
|
||
| def test_get_repo_name_from_image_with_digest(self): | ||
| """Test with an image URI containing a digest.""" | ||
| image = ( | ||
| "123456789012.dkr.ecr.us-west-2.amazonaws.com/" | ||
| "my-repo@sha256:abc123" | ||
| ) | ||
| result = _get_repo_name_from_image(image) | ||
| assert result == "my-repo" | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def mock_session(): | ||
| """Create a mock SageMaker session.""" | ||
| session = MagicMock() | ||
| session.boto_region_name = "us-east-1" | ||
| session.default_bucket.return_value = "my-bucket" | ||
| session.default_bucket_prefix = None | ||
| return session | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def mock_train_defaults(): | ||
| """Patch TrainDefaults for ModelTrainer construction.""" | ||
| with patch("sagemaker.train.model_trainer.TrainDefaults") as mock_defaults: | ||
| from sagemaker.train.configs import Compute | ||
|
|
||
| mock_defaults.get_sagemaker_session.return_value = MagicMock() | ||
| mock_defaults.get_role.return_value = ( | ||
| "arn:aws:iam::123456789012:role/SageMakerRole" | ||
| ) | ||
| mock_defaults.get_base_job_name.return_value = "test-job" | ||
| mock_defaults.get_compute.return_value = Compute( | ||
| instance_type="ml.m5.xlarge", instance_count=1 | ||
| ) | ||
| mock_defaults.get_stopping_condition.return_value = MagicMock() | ||
| mock_defaults.get_output_data_config.return_value = MagicMock() | ||
| yield mock_defaults | ||
|
|
||
|
|
||
| class TestModelTrainerValidationWithPipelineVariable: | ||
| """Tests for ModelTrainer validation with PipelineVariable objects.""" | ||
|
|
||
| def test_training_image_accepts_parameter_string( | ||
| self, mock_session, mock_train_defaults | ||
| ): | ||
| """Test that training_image accepts ParameterString.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
| from sagemaker.train.configs import Compute | ||
|
|
||
| param = ParameterString( | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| name="TrainingImage", default_value="some-image-uri" | ||
| ) | ||
|
|
||
| # Should not raise | ||
| trainer = ModelTrainer( | ||
| training_image=param, | ||
| compute=Compute( | ||
| instance_type="ml.m5.xlarge", instance_count=1 | ||
| ), | ||
| sagemaker_session=mock_session, | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| ) | ||
| assert trainer.training_image is param | ||
|
|
||
| def test_algorithm_name_accepts_parameter_string( | ||
| self, mock_session, mock_train_defaults | ||
| ): | ||
| """Test that algorithm_name accepts ParameterString.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
| from sagemaker.train.configs import Compute | ||
|
|
||
| param = ParameterString( | ||
| name="AlgorithmName", default_value="some-algo" | ||
| ) | ||
|
|
||
| # Should not raise | ||
| trainer = ModelTrainer( | ||
| algorithm_name=param, | ||
| compute=Compute( | ||
| instance_type="ml.m5.xlarge", instance_count=1 | ||
| ), | ||
| sagemaker_session=mock_session, | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| ) | ||
| assert trainer.algorithm_name is param | ||
|
|
||
| def test_environment_values_accept_parameter_string( | ||
| self, mock_session, mock_train_defaults | ||
| ): | ||
| """Test that environment dict values accept ParameterString.""" | ||
|
aviruthen marked this conversation as resolved.
|
||
| from sagemaker.train.model_trainer import ModelTrainer | ||
| from sagemaker.train.configs import Compute | ||
|
|
||
| env_param = ParameterString( | ||
| name="EnvValue", default_value="val" | ||
| ) | ||
|
|
||
| # Should not raise | ||
| trainer = ModelTrainer( | ||
| training_image=_TEST_IMAGE_URI, | ||
| compute=Compute( | ||
| instance_type="ml.m5.xlarge", instance_count=1 | ||
| ), | ||
| sagemaker_session=mock_session, | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| environment={"MY_VAR": env_param}, | ||
| ) | ||
| assert trainer.environment["MY_VAR"] is env_param | ||
|
|
||
| def test_plain_string_values_still_work( | ||
| self, mock_session, mock_train_defaults | ||
| ): | ||
| """Regression test: plain string values continue to work.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
| from sagemaker.train.configs import Compute | ||
|
|
||
| # Should not raise | ||
| trainer = ModelTrainer( | ||
| training_image=_TEST_IMAGE_URI, | ||
| compute=Compute( | ||
| instance_type="ml.m5.xlarge", instance_count=1 | ||
| ), | ||
| sagemaker_session=mock_session, | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| ) | ||
| assert trainer.training_image == _TEST_IMAGE_URI | ||
|
|
||
| def test_validation_accepts_pipeline_variable_image_none_algo(self): | ||
| """Test validation accepts PipelineVariable image with None algorithm.""" | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
|
||
| from sagemaker.train.model_trainer import ModelTrainer | ||
|
|
||
| trainer = ModelTrainer.__new__(ModelTrainer) | ||
| param = ParameterString( | ||
| name="Image", default_value="img" | ||
| ) | ||
| # Should not raise | ||
| trainer._validate_training_image_and_algorithm_name( | ||
| param, None | ||
| ) | ||
|
|
||
| def test_validation_accepts_none_image_pipeline_variable_algo(self): | ||
| """Test validation accepts None image with PipelineVariable algorithm.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
|
|
||
| trainer = ModelTrainer.__new__(ModelTrainer) | ||
| param = ParameterString( | ||
| name="Algo", default_value="algo" | ||
| ) | ||
| # Should not raise | ||
| trainer._validate_training_image_and_algorithm_name( | ||
| None, param | ||
| ) | ||
|
|
||
| def test_validation_rejects_no_image_or_algorithm(self): | ||
| """Test that validation rejects when neither is provided.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
|
|
||
| trainer = ModelTrainer.__new__(ModelTrainer) | ||
| with pytest.raises(ValueError, match="Atleast one of"): | ||
| trainer._validate_training_image_and_algorithm_name( | ||
| None, None | ||
| ) | ||
|
|
||
| def test_validation_rejects_both_image_and_algorithm(self): | ||
| """Test that validation rejects when both are provided.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
|
|
||
| trainer = ModelTrainer.__new__(ModelTrainer) | ||
| with pytest.raises(ValueError, match="Only one of"): | ||
| trainer._validate_training_image_and_algorithm_name( | ||
| "image", "algo" | ||
| ) | ||
|
|
||
| def test_validation_rejects_both_pipeline_variables(self): | ||
| """Test that validation rejects when both are PipelineVariables.""" | ||
| from sagemaker.train.model_trainer import ModelTrainer | ||
|
|
||
| trainer = ModelTrainer.__new__(ModelTrainer) | ||
| img_param = ParameterString( | ||
| name="Image", default_value="img" | ||
| ) | ||
| algo_param = ParameterString( | ||
| name="Algo", default_value="algo" | ||
| ) | ||
| with pytest.raises(ValueError, match="Only one of"): | ||
| trainer._validate_training_image_and_algorithm_name( | ||
| img_param, algo_param | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.