Skip to content

Commit 96f4200

Browse files
author
qxz3ea6
committed
Fix PySparkProcessor V3 ProcessingInput construction
1 parent 215713f commit 96f4200

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

sagemaker-core/src/sagemaker/core/spark/processing.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
from sagemaker.core import image_uris
3838
from sagemaker.core import s3
3939
from sagemaker.core.local.image import _ecr_login_if_needed, _pull_image
40-
from sagemaker.core.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
40+
from sagemaker.core.processing import (
41+
ProcessingInput,
42+
ProcessingOutput,
43+
ProcessingS3Input,
44+
ScriptProcessor,
45+
)
4146
from sagemaker.core.s3 import S3Uploader
4247
from sagemaker.core.helper.session_helper import Session
4348
from sagemaker.core.network import NetworkConfig
@@ -52,6 +57,18 @@
5257
logger = logging.getLogger(__name__)
5358

5459

60+
def _make_s3_processing_input(input_name: str, s3_uri: str, local_path: str) -> ProcessingInput:
    """Create a ProcessingInput whose channel is an S3Prefix-typed S3 input (V3 shape).

    Args:
        input_name: Name of the processing input channel.
        s3_uri: S3 URI the channel reads from.
        local_path: Container-local path the channel is mounted at.

    Returns:
        A ``ProcessingInput`` carrying a ``ProcessingS3Input`` payload.
    """
    # Build the S3 channel first, then wrap it — same fields, same values.
    s3_channel = ProcessingS3Input(
        s3_uri=s3_uri,
        s3_data_type="S3Prefix",
        local_path=local_path,
    )
    return ProcessingInput(input_name=input_name, s3_input=s3_channel)
70+
71+
5572
class _SparkProcessorBase(ScriptProcessor):
5673
"""Handles Amazon SageMaker processing tasks for jobs using Spark.
5774
@@ -404,10 +421,10 @@ def _stage_configuration(self, configuration):
404421
sagemaker_session=self.sagemaker_session,
405422
)
406423

407-
conf_input = ProcessingInput(
408-
source=s3_uri,
409-
destination=f"{self._conf_container_base_path}{self._conf_container_input_name}",
410-
input_name=_SparkProcessorBase._conf_container_input_name,
424+
conf_input = _make_s3_processing_input(
425+
input_name=self._conf_container_input_name,
426+
s3_uri=s3_uri,
427+
local_path=f"{self._conf_container_base_path}{self._conf_container_input_name}",
411428
)
412429
return conf_input
413430

@@ -505,15 +522,16 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
505522
# them to the Spark container and form the spark-submit option from a
506523
# combination of S3 URIs and container's local input path
507524
if use_input_channel:
508-
input_channel = ProcessingInput(
509-
source=input_channel_s3_uri,
510-
destination=f"{self._conf_container_base_path}{input_channel_name}",
525+
local_path = f"{self._conf_container_base_path}{input_channel_name}"
526+
input_channel = _make_s3_processing_input(
511527
input_name=input_channel_name,
528+
s3_uri=input_channel_s3_uri,
529+
local_path=local_path,
512530
)
513531
spark_opt = (
514-
Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
532+
Join(on=",", values=spark_opt_s3_uris + [local_path])
515533
if spark_opt_s3_uris_has_pipeline_var
516-
else ",".join(spark_opt_s3_uris + [input_channel.destination])
534+
else ",".join(spark_opt_s3_uris + [local_path])
517535
)
518536
# If no local files were uploaded, form the spark-submit option from a list of S3 URIs
519537
else:
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from unittest.mock import Mock, patch
15+
16+
import pytest
17+
18+
from sagemaker.core.spark.processing import PySparkProcessor
19+
20+
21+
@pytest.fixture
def mock_session():
    """Return a mocked SageMaker session pinned to us-west-2 and a test bucket."""
    fake = Mock()
    fake.boto_session = Mock()
    fake.boto_session.region_name = "us-west-2"
    fake.boto_region_name = "us-west-2"
    fake.sagemaker_client = Mock()
    fake.default_bucket = Mock(return_value="test-bucket")
    fake.default_bucket_prefix = "sagemaker"
    # expand_role is an identity pass-through so role ARNs survive unchanged.
    fake.expand_role = Mock(side_effect=lambda role: role)
    fake.sagemaker_config = {}
    return fake
33+
34+
35+
def _make_processor(mock_session):
    """Build a PySparkProcessor bound to the mocked session, with a fixed job name."""
    spark_processor = PySparkProcessor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
        image_uri="test-image:latest",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        sagemaker_session=mock_session,
    )
    # Pre-set the job name so staged S3 prefixes are deterministic in tests.
    spark_processor._current_job_name = "test-job"
    return spark_processor
45+
46+
47+
class TestPySparkProcessorV3ProcessingInputs:
    """Checks that staging helpers emit V3-shaped ProcessingInputs (s3_input payloads)."""

    @patch("sagemaker.core.spark.processing.S3Uploader.upload_string_as_file_body")
    def test_stage_configuration_builds_v3_processing_input(self, mock_upload, mock_session):
        processor = _make_processor(mock_session)

        staged = processor._stage_configuration(
            [{"Classification": "spark-defaults", "Properties": {"spark.app.name": "test"}}]
        )

        # One upload of the serialized configuration, then a V3 input pointing at it.
        mock_upload.assert_called_once()
        expected_uri = "s3://test-bucket/sagemaker/test-job/input/conf/configuration.json"
        assert staged.input_name == processor._conf_container_input_name
        assert staged.s3_input.s3_uri == expected_uri
        assert staged.s3_input.local_path == "/opt/ml/processing/input/conf"

    @patch("sagemaker.core.spark.processing.S3Uploader.upload")
    def test_stage_submit_deps_builds_v3_processing_input_for_local_dependencies(
        self, mock_upload, mock_session, tmp_path
    ):
        processor = _make_processor(mock_session)
        dependency = tmp_path / "dep.py"
        dependency.write_text("print('dep')", encoding="utf-8")

        channel, spark_opt = processor._stage_submit_deps(
            [str(dependency)], processor._submit_py_files_input_channel_name
        )

        # Local file is uploaded once; the channel and spark-submit option both
        # reference the container-local mount path.
        mock_upload.assert_called_once()
        assert channel.input_name == processor._submit_py_files_input_channel_name
        assert channel.s3_input.s3_uri == "s3://test-bucket/sagemaker/test-job/input/py-files"
        assert channel.s3_input.local_path == "/opt/ml/processing/input/py-files"
        assert spark_opt == "/opt/ml/processing/input/py-files"

0 commit comments

Comments (0)