-
Notifications
You must be signed in to change notification settings - Fork 1.3k
fix: ProcessingS3Output's s3_uri to be an optional field (5559)
#5755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,7 +51,7 @@ | |
| ) | ||
| from sagemaker.core.local.local_session import LocalSession | ||
| from sagemaker.core.helper.session_helper import Session | ||
| from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input | ||
| from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input, ProcessingS3Output | ||
| from sagemaker.core.resources import ProcessingJob | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession | ||
| from sagemaker.core.common_utils import ( | ||
|
|
@@ -483,13 +483,23 @@ def _normalize_outputs(self, outputs=None): | |
| # Generate a name for the ProcessingOutput if it doesn't have one. | ||
| if output.output_name is None: | ||
| output.output_name = "output-{}".format(count) | ||
| # If s3_output is None, create a default one with None s3_uri | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing shapes definition change: The PR description states that
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. $context sagemaker-core/src/sagemaker/core/shapes/shapes.py |
||
| if output.s3_output is None: | ||
| output.s3_output = ProcessingS3Output( | ||
|
aviruthen marked this conversation as resolved.
|
||
| s3_uri=None, | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
| if output.s3_output and is_pipeline_variable(output.s3_output.s3_uri): | ||
|
aviruthen marked this conversation as resolved.
|
||
| normalized_outputs.append(output) | ||
| continue | ||
| # If the output's s3_uri is not an s3_uri, create one. | ||
| parse_result = urlparse(output.s3_output.s3_uri) | ||
| if parse_result.scheme != "s3": | ||
| if getattr(self.sagemaker_session, "local_mode", False) and parse_result.scheme == "file": | ||
| # If the output's s3_uri is None or not an s3_uri, create one. | ||
| if output.s3_output.s3_uri is None: | ||
| parse_result_scheme = "" | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line length likely exceeds 100 characters: if getattr(self.sagemaker_session, "local_mode", False) and parse_result_scheme == "file":This should be wrapped to stay within the 100-character limit. |
||
| else: | ||
| parse_result_scheme = urlparse(output.s3_output.s3_uri).scheme | ||
| if parse_result_scheme != "s3": | ||
| if getattr(self.sagemaker_session, "local_mode", False) and parse_result_scheme == "file": | ||
| normalized_outputs.append(output) | ||
| continue | ||
| if _pipeline_config: | ||
|
|
@@ -1421,11 +1431,13 @@ def _processing_output_to_request_dict(processing_output): | |
| } | ||
|
|
||
| if processing_output.s3_output: | ||
| request_dict["S3Output"] = { | ||
| "S3Uri": processing_output.s3_output.s3_uri, | ||
| s3_output_dict = { | ||
| "LocalPath": processing_output.s3_output.local_path, | ||
| "S3UploadMode": processing_output.s3_output.s3_upload_mode, | ||
| } | ||
| if processing_output.s3_output.s3_uri is not None: | ||
| s3_output_dict["S3Uri"] = processing_output.s3_output.s3_uri | ||
| request_dict["S3Output"] = s3_output_dict | ||
|
|
||
|
aviruthen marked this conversation as resolved.
|
||
| return request_dict | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -493,6 +493,148 @@ def test_multiple_outputs_with_s3_uris_preserved(self, session_local_mode_false) | |
| assert result[1].s3_output.s3_uri == "s3://my-bucket/second" | ||
|
|
||
|
|
||
| class TestProcessingS3OutputOptionalS3Uri: | ||
| """Tests for ProcessingS3Output with optional s3_uri (issue #5559).""" | ||
|
|
||
| def test_processing_s3_output_with_none_s3_uri_creates_successfully(self): | ||
| """Verify ProcessingS3Output can be created with s3_uri=None.""" | ||
| s3_output = ProcessingS3Output( | ||
| s3_uri=None, | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
|
aviruthen marked this conversation as resolved.
|
||
| assert s3_output.s3_uri is None | ||
| assert s3_output.local_path == "/opt/ml/processing/output" | ||
| assert s3_output.s3_upload_mode == "EndOfJob" | ||
|
|
||
| def test_processing_s3_output_without_s3_uri_param_creates_successfully(self): | ||
| """Verify ProcessingS3Output works with default None for s3_uri.""" | ||
| s3_output = ProcessingS3Output( | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
| assert s3_output.s3_uri is None | ||
|
|
||
| def test_normalize_outputs_with_none_s3_uri_generates_s3_path(self, mock_session): | ||
| """When s3_uri is None, _normalize_outputs should auto-generate an S3 URI.""" | ||
|
aviruthen marked this conversation as resolved.
|
||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good test coverage! However, the |
||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_output = ProcessingS3Output( | ||
| s3_uri=None, | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
| outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)] | ||
|
|
||
| with patch("sagemaker.core.workflow.utilities._pipeline_config", None): | ||
| result = processor._normalize_outputs(outputs) | ||
|
|
||
| assert len(result) == 1 | ||
| assert result[0].s3_output.s3_uri is not None | ||
| assert result[0].s3_output.s3_uri.startswith("s3://") | ||
| assert "test-job" in result[0].s3_output.s3_uri | ||
| assert "my-output" in result[0].s3_output.s3_uri | ||
|
|
||
| def test_normalize_outputs_with_none_s3_uri_and_pipeline_config_generates_join(self, mock_session): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line exceeds 100 characters. Please wrap this line: def test_normalize_outputs_with_none_s3_uri_and_pipeline_config_generates_join(
self, mock_session
): |
||
| """When in pipeline context with s3_uri=None, should generate a Join expression.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_output = ProcessingS3Output( | ||
|
aviruthen marked this conversation as resolved.
|
||
| s3_uri=None, | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
| outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)] | ||
|
|
||
| with patch("sagemaker.core.workflow.utilities._pipeline_config") as mock_config: | ||
| mock_config.pipeline_name = "test-pipeline" | ||
| mock_config.step_name = "test-step" | ||
| result = processor._normalize_outputs(outputs) | ||
|
|
||
| assert len(result) == 1 | ||
| # In pipeline context, the s3_uri should be a Join object | ||
| from sagemaker.core.workflow.functions import Join | ||
| assert isinstance(result[0].s3_output.s3_uri, Join) | ||
|
|
||
| def test_normalize_outputs_with_none_s3_output_generates_s3_path(self, mock_session): | ||
| """When s3_output is None, _normalize_outputs should create s3_output and auto-generate URI.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import of |
||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| outputs = [ProcessingOutput(output_name="my-output")] | ||
|
|
||
| with patch("sagemaker.core.workflow.utilities._pipeline_config", None): | ||
| result = processor._normalize_outputs(outputs) | ||
|
|
||
| assert len(result) == 1 | ||
| assert result[0].s3_output is not None | ||
| assert result[0].s3_output.s3_uri is not None | ||
| assert result[0].s3_output.s3_uri.startswith("s3://") | ||
| assert result[0].s3_output.local_path == "/opt/ml/processing/output" | ||
| assert result[0].s3_output.s3_upload_mode == "EndOfJob" | ||
|
|
||
| def test_processing_output_to_request_dict_with_none_s3_uri_omits_key(self): | ||
| """When s3_uri is None, S3Uri should be omitted from the request dict.""" | ||
| s3_output = ProcessingS3Output( | ||
| s3_uri=None, | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
| processing_output = ProcessingOutput(output_name="results", s3_output=s3_output) | ||
|
|
||
| result = _processing_output_to_request_dict(processing_output) | ||
|
|
||
| assert result["OutputName"] == "results" | ||
| assert "S3Output" in result | ||
| assert "S3Uri" not in result["S3Output"] | ||
| assert result["S3Output"]["LocalPath"] == "/opt/ml/processing/output" | ||
| assert result["S3Output"]["S3UploadMode"] == "EndOfJob" | ||
|
|
||
| def test_normalize_outputs_with_explicit_s3_uri_unchanged(self, mock_session): | ||
| """Regression test: explicit s3:// URIs should be preserved.""" | ||
| processor = Processor( | ||
| role="arn:aws:iam::123456789012:role/SageMakerRole", | ||
| image_uri="test-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
| processor._current_job_name = "test-job" | ||
|
|
||
| s3_output = ProcessingS3Output( | ||
| s3_uri="s3://my-bucket/my-output", | ||
| local_path="/opt/ml/processing/output", | ||
| s3_upload_mode="EndOfJob", | ||
| ) | ||
| outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)] | ||
|
|
||
| with patch("sagemaker.core.workflow.utilities._pipeline_config", None): | ||
| result = processor._normalize_outputs(outputs) | ||
|
|
||
| assert len(result) == 1 | ||
| assert result[0].s3_output.s3_uri == "s3://my-bucket/my-output" | ||
|
|
||
|
|
||
| class TestProcessorStartNew: | ||
| def test_start_new_with_pipeline_session(self, mock_session): | ||
| from sagemaker.core.workflow.pipeline_context import PipelineSession | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import line exceeds the 100-character line length limit. Please break it into multiple lines: