Skip to content

Commit 94e95e0

Browse files
committed
fix: address review comments (iteration #1)
1 parent 333e5b0 commit 94e95e0

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,16 @@ def _normalize_outputs(self, outputs=None):
483483
# Generate a name for the ProcessingOutput if it doesn't have one.
484484
if output.output_name is None:
485485
output.output_name = "output-{}".format(count)
486-
if output.s3_output and output.s3_output.s3_uri is not None and is_pipeline_variable(output.s3_output.s3_uri):
486+
if (
487+
output.s3_output
488+
and output.s3_output.s3_uri is not None
489+
and is_pipeline_variable(output.s3_output.s3_uri)
490+
):
487491
normalized_outputs.append(output)
488492
continue
489-
# If s3_output is None or s3_uri is None, auto-generate an S3 URI
493+
# If s3_output is None or s3_uri is None, auto-generate
494+
# an S3 URI (V2 parity: destination=None delegates to
495+
# SageMaker).
490496
if not output.s3_output or output.s3_output.s3_uri is None:
491497
if _pipeline_config:
492498
s3_uri = Join(
@@ -495,7 +501,6 @@ def _normalize_outputs(self, outputs=None):
495501
"s3:/",
496502
self.sagemaker_session.default_bucket(),
497503
*(
498-
# don't include default_bucket_prefix if it is None or ""
499504
[self.sagemaker_session.default_bucket_prefix]
500505
if self.sagemaker_session.default_bucket_prefix
501506
else []
@@ -517,20 +522,32 @@ def _normalize_outputs(self, outputs=None):
517522
output.output_name,
518523
)
519524
if output.s3_output:
525+
# s3_output exists but s3_uri is None
520526
output.s3_output.s3_uri = s3_uri
521527
else:
522-
from sagemaker.core.shapes import ProcessingS3Output as _ProcessingS3Output
528+
# s3_output is None — create a new one with
529+
# sensible defaults.
530+
# Import here to avoid circular import with
531+
# shapes module.
532+
from sagemaker.core.shapes import (
533+
ProcessingS3Output as _ProcessingS3Output,
534+
)
523535
output.s3_output = _ProcessingS3Output(
524536
s3_uri=s3_uri,
525-
local_path=output.s3_output.local_path if output.s3_output else "/opt/ml/processing/output",
537+
local_path="/opt/ml/processing/output",
526538
s3_upload_mode="EndOfJob",
527539
)
528540
normalized_outputs.append(output)
529541
continue
530542
# If the output's s3_uri is not an s3_uri, create one.
531543
parse_result = urlparse(output.s3_output.s3_uri)
532544
if parse_result.scheme != "s3":
533-
if getattr(self.sagemaker_session, "local_mode", False) and parse_result.scheme == "file":
545+
if (
546+
getattr(
547+
self.sagemaker_session, "local_mode", False
548+
)
549+
and parse_result.scheme == "file"
550+
):
534551
normalized_outputs.append(output)
535552
continue
536553
if _pipeline_config:

sagemaker-core/tests/unit/test_processing.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,8 @@ def test_normalize_outputs_with_none_s3_uri_generates_s3_path(self, mock_session
14361436

14371437
def test_normalize_outputs_with_none_s3_uri_and_pipeline_config(self, mock_session):
14381438
"""When s3_uri is None and pipeline_config is set, use pipeline-based path."""
1439+
from sagemaker.core.workflow.functions import Join
1440+
14391441
processor = Processor(
14401442
role="arn:aws:iam::123456789012:role/SageMakerRole",
14411443
image_uri="test-image:latest",
@@ -1459,7 +1461,11 @@ def test_normalize_outputs_with_none_s3_uri_and_pipeline_config(self, mock_sessi
14591461

14601462
assert len(result) == 1
14611463
# The result should be a Join object (pipeline variable) when pipeline_config is set
1462-
assert result[0].s3_output.s3_uri is not None
1464+
assert isinstance(result[0].s3_output.s3_uri, Join)
1465+
# Verify the Join contains expected pipeline-related values
1466+
join_obj = result[0].s3_output.s3_uri
1467+
assert join_obj.on == "/"
1468+
assert "test-pipeline" in join_obj.values
14631469

14641470
def test_normalize_outputs_with_none_s3_uri_auto_generates_name(self, mock_session):
14651471
"""When output_name is None and s3_uri is None, both should be auto-generated."""
@@ -1488,6 +1494,29 @@ def test_normalize_outputs_with_none_s3_uri_auto_generates_name(self, mock_sessi
14881494
assert generated_uri.startswith("s3://")
14891495
assert "output-1" in generated_uri
14901496

1497+
def test_normalize_outputs_with_no_s3_output_at_all(self, mock_session):
1498+
"""When s3_output is None entirely, a new ProcessingS3Output is created."""
1499+
processor = Processor(
1500+
role="arn:aws:iam::123456789012:role/SageMakerRole",
1501+
image_uri="test-image:latest",
1502+
instance_count=1,
1503+
instance_type="ml.m5.xlarge",
1504+
sagemaker_session=mock_session,
1505+
)
1506+
processor._current_job_name = "test-job"
1507+
1508+
outputs = [ProcessingOutput(output_name="my-output")]
1509+
1510+
with patch("sagemaker.core.workflow.utilities._pipeline_config", None):
1511+
result = processor._normalize_outputs(outputs)
1512+
1513+
assert len(result) == 1
1514+
assert result[0].s3_output is not None
1515+
assert result[0].s3_output.s3_uri.startswith("s3://")
1516+
assert result[0].s3_output.local_path == "/opt/ml/processing/output"
1517+
assert result[0].s3_output.s3_upload_mode == "EndOfJob"
1518+
assert "my-output" in result[0].s3_output.s3_uri
1519+
14911520
def test_processing_output_to_request_dict_omits_s3_uri_when_none(self):
14921521
"""Verify _processing_output_to_request_dict omits S3Uri when s3_uri is None."""
14931522
s3_output = ProcessingS3Output(

0 commit comments

Comments
 (0)