Skip to content

Commit 4acc234

Browse files
committed
fix: ProcessingS3Output's s3_uri to be an optional field (#5559)
1 parent daf19b0 commit 4acc234

File tree

2 files changed

+161
-7
lines changed

2 files changed

+161
-7
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
)
5252
from sagemaker.core.local.local_session import LocalSession
5353
from sagemaker.core.helper.session_helper import Session
54-
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input
54+
from sagemaker.core.shapes import ProcessingInput, ProcessingOutput, ProcessingS3Input, ProcessingS3Output
5555
from sagemaker.core.resources import ProcessingJob
5656
from sagemaker.core.workflow.pipeline_context import PipelineSession
5757
from sagemaker.core.common_utils import (
@@ -483,13 +483,23 @@ def _normalize_outputs(self, outputs=None):
483483
# Generate a name for the ProcessingOutput if it doesn't have one.
484484
if output.output_name is None:
485485
output.output_name = "output-{}".format(count)
486+
# If s3_output is None, create a default one whose s3_uri is None so that a URI is generated for it below.
487+
if output.s3_output is None:
488+
output.s3_output = ProcessingS3Output(
489+
s3_uri=None,
490+
local_path="/opt/ml/processing/output",
491+
s3_upload_mode="EndOfJob",
492+
)
486493
if output.s3_output and is_pipeline_variable(output.s3_output.s3_uri):
487494
normalized_outputs.append(output)
488495
continue
489-
# If the output's s3_uri is not an s3_uri, create one.
490-
parse_result = urlparse(output.s3_output.s3_uri)
491-
if parse_result.scheme != "s3":
492-
if getattr(self.sagemaker_session, "local_mode", False) and parse_result.scheme == "file":
496+
# If the output's s3_uri is None or does not use the "s3" scheme, create one.
497+
if output.s3_output.s3_uri is None:
498+
parse_result_scheme = ""
499+
else:
500+
parse_result_scheme = urlparse(output.s3_output.s3_uri).scheme
501+
if parse_result_scheme != "s3":
502+
if getattr(self.sagemaker_session, "local_mode", False) and parse_result_scheme == "file":
493503
normalized_outputs.append(output)
494504
continue
495505
if _pipeline_config:
@@ -1421,11 +1431,13 @@ def _processing_output_to_request_dict(processing_output):
14211431
}
14221432

14231433
if processing_output.s3_output:
1424-
request_dict["S3Output"] = {
1425-
"S3Uri": processing_output.s3_output.s3_uri,
1434+
s3_output_dict = {
14261435
"LocalPath": processing_output.s3_output.local_path,
14271436
"S3UploadMode": processing_output.s3_output.s3_upload_mode,
14281437
}
1438+
if processing_output.s3_output.s3_uri is not None:
1439+
s3_output_dict["S3Uri"] = processing_output.s3_output.s3_uri
1440+
request_dict["S3Output"] = s3_output_dict
14291441

14301442
return request_dict
14311443

sagemaker-core/tests/unit/test_processing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,148 @@ def test_multiple_outputs_with_s3_uris_preserved(self, session_local_mode_false)
493493
assert result[1].s3_output.s3_uri == "s3://my-bucket/second"
494494

495495

496+
class TestProcessingS3OutputOptionalS3Uri:
497+
"""Tests for ProcessingS3Output with optional s3_uri (issue #5559)."""
498+
499+
def test_processing_s3_output_with_none_s3_uri_creates_successfully(self):
500+
"""Verify ProcessingS3Output can be created with s3_uri=None."""
501+
s3_output = ProcessingS3Output(
502+
s3_uri=None,
503+
local_path="/opt/ml/processing/output",
504+
s3_upload_mode="EndOfJob",
505+
)
506+
assert s3_output.s3_uri is None
507+
assert s3_output.local_path == "/opt/ml/processing/output"
508+
assert s3_output.s3_upload_mode == "EndOfJob"
509+
510+
def test_processing_s3_output_without_s3_uri_param_creates_successfully(self):
511+
"""Verify ProcessingS3Output works with default None for s3_uri."""
512+
s3_output = ProcessingS3Output(
513+
local_path="/opt/ml/processing/output",
514+
s3_upload_mode="EndOfJob",
515+
)
516+
assert s3_output.s3_uri is None
517+
518+
def test_normalize_outputs_with_none_s3_uri_generates_s3_path(self, mock_session):
519+
"""When s3_uri is None, _normalize_outputs should auto-generate an S3 URI."""
520+
processor = Processor(
521+
role="arn:aws:iam::123456789012:role/SageMakerRole",
522+
image_uri="test-image:latest",
523+
instance_count=1,
524+
instance_type="ml.m5.xlarge",
525+
sagemaker_session=mock_session,
526+
)
527+
processor._current_job_name = "test-job"
528+
529+
s3_output = ProcessingS3Output(
530+
s3_uri=None,
531+
local_path="/opt/ml/processing/output",
532+
s3_upload_mode="EndOfJob",
533+
)
534+
outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)]
535+
536+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
537+
result = processor._normalize_outputs(outputs)
538+
539+
assert len(result) == 1
540+
assert result[0].s3_output.s3_uri is not None
541+
assert result[0].s3_output.s3_uri.startswith("s3://")
542+
assert "test-job" in result[0].s3_output.s3_uri
543+
assert "my-output" in result[0].s3_output.s3_uri
544+
545+
def test_normalize_outputs_with_none_s3_uri_and_pipeline_config_generates_join(self, mock_session):
546+
"""When in pipeline context with s3_uri=None, should generate a Join expression."""
547+
processor = Processor(
548+
role="arn:aws:iam::123456789012:role/SageMakerRole",
549+
image_uri="test-image:latest",
550+
instance_count=1,
551+
instance_type="ml.m5.xlarge",
552+
sagemaker_session=mock_session,
553+
)
554+
processor._current_job_name = "test-job"
555+
556+
s3_output = ProcessingS3Output(
557+
s3_uri=None,
558+
local_path="/opt/ml/processing/output",
559+
s3_upload_mode="EndOfJob",
560+
)
561+
outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)]
562+
563+
with patch("sagemaker.core.workflow.utilities._pipeline_config") as mock_config:
564+
mock_config.pipeline_name = "test-pipeline"
565+
mock_config.step_name = "test-step"
566+
result = processor._normalize_outputs(outputs)
567+
568+
assert len(result) == 1
569+
# In pipeline context, the s3_uri should be a Join object
570+
from sagemaker.core.workflow.functions import Join
571+
assert isinstance(result[0].s3_output.s3_uri, Join)
572+
573+
def test_normalize_outputs_with_none_s3_output_generates_s3_path(self, mock_session):
574+
"""When s3_output is None, _normalize_outputs should create s3_output and auto-generate URI."""
575+
processor = Processor(
576+
role="arn:aws:iam::123456789012:role/SageMakerRole",
577+
image_uri="test-image:latest",
578+
instance_count=1,
579+
instance_type="ml.m5.xlarge",
580+
sagemaker_session=mock_session,
581+
)
582+
processor._current_job_name = "test-job"
583+
584+
outputs = [ProcessingOutput(output_name="my-output")]
585+
586+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
587+
result = processor._normalize_outputs(outputs)
588+
589+
assert len(result) == 1
590+
assert result[0].s3_output is not None
591+
assert result[0].s3_output.s3_uri is not None
592+
assert result[0].s3_output.s3_uri.startswith("s3://")
593+
assert result[0].s3_output.local_path == "/opt/ml/processing/output"
594+
assert result[0].s3_output.s3_upload_mode == "EndOfJob"
595+
596+
def test_processing_output_to_request_dict_with_none_s3_uri_omits_key(self):
597+
"""When s3_uri is None, S3Uri should be omitted from the request dict."""
598+
s3_output = ProcessingS3Output(
599+
s3_uri=None,
600+
local_path="/opt/ml/processing/output",
601+
s3_upload_mode="EndOfJob",
602+
)
603+
processing_output = ProcessingOutput(output_name="results", s3_output=s3_output)
604+
605+
result = _processing_output_to_request_dict(processing_output)
606+
607+
assert result["OutputName"] == "results"
608+
assert "S3Output" in result
609+
assert "S3Uri" not in result["S3Output"]
610+
assert result["S3Output"]["LocalPath"] == "/opt/ml/processing/output"
611+
assert result["S3Output"]["S3UploadMode"] == "EndOfJob"
612+
613+
def test_normalize_outputs_with_explicit_s3_uri_unchanged(self, mock_session):
614+
"""Regression test: explicit s3:// URIs should be preserved."""
615+
processor = Processor(
616+
role="arn:aws:iam::123456789012:role/SageMakerRole",
617+
image_uri="test-image:latest",
618+
instance_count=1,
619+
instance_type="ml.m5.xlarge",
620+
sagemaker_session=mock_session,
621+
)
622+
processor._current_job_name = "test-job"
623+
624+
s3_output = ProcessingS3Output(
625+
s3_uri="s3://my-bucket/my-output",
626+
local_path="/opt/ml/processing/output",
627+
s3_upload_mode="EndOfJob",
628+
)
629+
outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)]
630+
631+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
632+
result = processor._normalize_outputs(outputs)
633+
634+
assert len(result) == 1
635+
assert result[0].s3_output.s3_uri == "s3://my-bucket/my-output"
636+
637+
496638
class TestProcessorStartNew:
497639
def test_start_new_with_pipeline_session(self, mock_session):
498640
from sagemaker.core.workflow.pipeline_context import PipelineSession

0 commit comments

Comments
 (0)