Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,24 @@
_DEFAULT_DEFINITION_CFG = PipelineDefinitionConfig(use_custom_job_prefix=False)



def _validate_role_arn(role_arn):
"""Validates that role_arn is either None or a string.

Args:
role_arn: The role ARN value to validate.

Raises:
ValueError: If role_arn is not None and not a string.
"""
if role_arn is not None and not isinstance(role_arn, str):
raise ValueError(
"role_arn must be a string or None, but got {}: {}".format(
type(role_arn).__name__, role_arn
)
)


class Pipeline:
"""Pipeline for workflow."""

Expand Down Expand Up @@ -163,9 +181,10 @@ def create(
that is applied to each of the executions of the pipeline. It takes precedence
over the parallelism configuration of the parent pipeline.

Returns:
Returns:
A response dict from the service.
"""
_validate_role_arn(role_arn)
role_arn = resolve_value_from_config(
role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
)
Expand All @@ -179,6 +198,22 @@ def create(
if parallelism_config:
logger.warning("Pipeline parallelism config is not supported in the local mode.")
return self.sagemaker_session.sagemaker_client.create_pipeline(self, description)
A response dict from the service.
"""
role_arn = resolve_value_from_config(
role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
)
if not role_arn:
# Originally IAM role was a required parameter.
# Now we marked that as Optional because we can fetch it from SageMakerConfig
# Because of marking that parameter as optional, we should validate if it is None, even
# after fetching the config.
raise ValueError("An AWS IAM role is required to create a Pipeline.")
if self.sagemaker_session.local_mode:

if parallelism_config:
logger.warning("Pipeline parallelism config is not supported in the local mode.")
return self.sagemaker_session.sagemaker_client.create_pipeline(self, description)
tags = format_tags(tags)
tags = _append_project_tags(tags)
tags = self.sagemaker_session._append_sagemaker_config_tags(tags, PIPELINE_TAGS_PATH)
Expand Down Expand Up @@ -270,6 +305,7 @@ def update(
Returns:
A response dict from the service.
"""
_validate_role_arn(role_arn)
role_arn = resolve_value_from_config(
role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
)
Expand Down Expand Up @@ -307,6 +343,7 @@ def upsert(
Returns:
response dict from service
"""
_validate_role_arn(role_arn)
role_arn = resolve_value_from_config(
role_arn, PIPELINE_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/sagemaker/workflow/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def role_arn():


def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock):

mock_session = copy.deepcopy(sagemaker_session_mock)
mock_session.sagemaker_config = {}
pipeline = Pipeline(
Expand All @@ -69,6 +70,31 @@ def test_pipeline_create_and_update_without_role_arn(sagemaker_session_mock):
pipeline.upsert()


def test_pipeline_create_with_invalid_role_arn_type(sagemaker_session_mock):
"""Test that create(), update(), and upsert() raise ValueError for non-string role_arn."""
mock_session = copy.deepcopy(sagemaker_session_mock)
mock_session.sagemaker_config = {}
pipeline = Pipeline(
name="MyPipeline",
parameters=[],
steps=[],
sagemaker_session=mock_session,
)
invalid_role_arns = [
{"arn": "arn:aws:iam::111111111111:role/SageMakerRole"},
123,
["arn:aws:iam::111111111111:role/SageMakerRole"],
True,
]
for invalid_role in invalid_role_arns:
with pytest.raises(ValueError, match="role_arn must be a string or None"):
pipeline.create(role_arn=invalid_role)
with pytest.raises(ValueError, match="role_arn must be a string or None"):
pipeline.update(role_arn=invalid_role)
with pytest.raises(ValueError, match="role_arn must be a string or None"):
pipeline.upsert(role_arn=invalid_role)


def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock):
# For tests which doesn't verify config file injection, operate with empty config
pipeline_role_arn = "arn:aws:iam::111111111111:role/ConfigRole"
Expand Down
Loading