Skip to content

Commit 299d61b

Browse files
committed
fix: Support local source for sagemaker.core.shapes.ProcessingInput (5672)
1 parent 6a1ba54 commit 299d61b

File tree

2 files changed

+399
-1
lines changed

2 files changed

+399
-1
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,137 @@
8585
logger = logging.getLogger(__name__)
8686

8787

88+
def processing_input_from_local(
    input_name: str,
    local_path: str,
    destination: str,
    s3_data_type: str = "S3Prefix",
    s3_input_mode: str = "File",
) -> ProcessingInput:
    """Creates a ProcessingInput from a local file or directory path.

    This is a convenience factory that makes it clear users can pass local
    paths as processing job inputs. The local path is stored in
    ``ProcessingS3Input.s3_uri`` and will be automatically uploaded to S3
    when the processor's ``run()`` method is called.

    Args:
        input_name: The name for this processing input.
        local_path: A local file or directory path to use as input.
            This will be uploaded to S3 automatically before the
            processing job starts.
        destination: The container path where the input data will be
            made available (e.g. ``/opt/ml/processing/input/data``).
        s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
            ``'ManifestFile'`` (default: ``'S3Prefix'``).
        s3_input_mode: The input mode. Valid values: ``'File'``,
            ``'Pipe'`` (default: ``'File'``).

    Returns:
        ProcessingInput: A ``ProcessingInput`` object configured with the
            local path. The path will be uploaded to S3 during
            ``Processor.run()``.

    Raises:
        ValueError: If ``input_name``, ``local_path``, or ``destination``
            is empty.

    Example:
        >>> inp = processing_input_from_local(
        ...     input_name="my-data",
        ...     local_path="/home/user/data/",
        ...     destination="/opt/ml/processing/input/data",
        ... )
        >>> processor.run(inputs=[inp])
    """
    if not input_name:
        raise ValueError(
            f"input_name must be a non-empty string, got: {input_name!r}"
        )
    if not local_path:
        raise ValueError(
            f"local_path must be a non-empty string, got: {local_path!r}"
        )
    # Validate destination eagerly, consistent with create_processing_input:
    # an empty destination would otherwise produce a ProcessingS3Input with
    # an empty container local_path and only fail later, at job-setup time.
    if not destination:
        raise ValueError(
            f"destination must be a non-empty string, got: {destination!r}"
        )
    # The local path rides in s3_uri; Processor.run()'s normalization logic
    # detects non-"s3://" URIs and uploads them to S3 before job start.
    return ProcessingInput(
        input_name=input_name,
        s3_input=ProcessingS3Input(
            s3_uri=local_path,
            local_path=destination,
            s3_data_type=s3_data_type,
            s3_input_mode=s3_input_mode,
        ),
    )
147+
148+
149+
def create_processing_input(
    source: str,
    destination: str,
    input_name: str,
    s3_data_type: str = "S3Prefix",
    s3_input_mode: str = "File",
) -> ProcessingInput:
    """Creates a ProcessingInput from a local path or S3 URI.

    This factory provides V2-like ergonomics where users pass a ``source``
    parameter that can be either a local file/directory path or an S3 URI.
    If ``source`` is a local path (does not start with ``s3://``), it will
    be automatically uploaded to S3 when the processor runs.

    Args:
        source: A local file/directory path or S3 URI. Local paths will
            be uploaded to S3 automatically before the processing job
            starts.
        destination: The container path where the input data will be
            made available (e.g. ``/opt/ml/processing/input/data``).
        input_name: The name for this processing input.
        s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
            ``'ManifestFile'`` (default: ``'S3Prefix'``).
        s3_input_mode: The input mode. Valid values: ``'File'``,
            ``'Pipe'`` (default: ``'File'``).

    Returns:
        ProcessingInput: A ``ProcessingInput`` object. If ``source`` is a
            local path, it is stored in ``ProcessingS3Input.s3_uri`` and
            will be uploaded to S3 during ``Processor.run()``.

    Raises:
        ValueError: If ``source``, ``destination``, or ``input_name`` is empty.

    Example:
        >>> # Using a local path
        >>> inp = create_processing_input(
        ...     source="/home/user/data/",
        ...     destination="/opt/ml/processing/input/data",
        ...     input_name="my-data",
        ... )
        >>> # Using an S3 URI
        >>> inp = create_processing_input(
        ...     source="s3://my-bucket/data/",
        ...     destination="/opt/ml/processing/input/data",
        ...     input_name="my-data",
        ... )
        >>> processor.run(inputs=[inp])
    """
    # Table-driven non-empty checks; order matters so the first failing
    # argument (source, then destination, then input_name) is reported.
    required = (
        ("source", source),
        ("destination", destination),
        ("input_name", input_name),
    )
    for arg_name, arg_value in required:
        if not arg_value:
            raise ValueError(
                f"{arg_name} must be a non-empty string, got: {arg_value!r}"
            )

    # Either an s3:// URI (used as-is) or a local path (uploaded to S3
    # by Processor.run() before the job starts).
    s3_input = ProcessingS3Input(
        s3_uri=source,
        local_path=destination,
        s3_data_type=s3_data_type,
        s3_input_mode=s3_input_mode,
    )
    return ProcessingInput(input_name=input_name, s3_input=s3_input)
217+
218+
88219
class Processor(object):
89220
"""Handles Amazon SageMaker Processing tasks."""
90221

@@ -238,6 +369,14 @@ def run(
238369
inputs (list[:class:`~sagemaker.core.shapes.ProcessingInput`]): Input files for
239370
the processing job. These must be provided as
240371
:class:`~sagemaker.core.shapes.ProcessingInput` objects (default: None).
372+
373+
.. note::
374+
``ProcessingS3Input.s3_uri`` can accept local file paths in addition
375+
to S3 URIs. Local paths will be automatically uploaded to S3 before
376+
the processing job starts. For clearer intent when using local paths,
377+
consider using the :func:`processing_input_from_local` or
378+
:func:`create_processing_input` convenience factories.
379+
241380
outputs (list[:class:`~sagemaker.core.shapes.ProcessingOutput`]): Outputs for
242381
the processing job. These can be specified as either path strings or
243382
:class:`~sagemaker.core.shapes.ProcessingOutput` objects (default: None).
@@ -401,6 +540,13 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
401540
402541
Raises:
403542
TypeError: if the inputs are not ``ProcessingInput`` objects.
543+
544+
Note:
545+
``ProcessingS3Input.s3_uri`` can accept local file or directory paths
546+
in addition to S3 URIs. Local paths will be automatically uploaded
547+
to S3 before the processing job starts. For clearer intent when
548+
using local paths, consider using :func:`processing_input_from_local`
549+
or :func:`create_processing_input`.
404550
"""
405551
from sagemaker.core.workflow.utilities import _pipeline_config
406552

@@ -413,7 +559,23 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
413559
raise TypeError("Your inputs must be provided as ProcessingInput objects.")
414560
# Generate a name for the ProcessingInput if it doesn't have one.
415561
if file_input.input_name is None:
416-
file_input.input_name = "input-{}".format(count)
562+
file_input.input_name = f"input-{count}"
563+
564+
# Support ad-hoc 'source' attribute for V2-like ergonomics.
565+
# If a user sets file_input.source (e.g. via monkey-patching or
566+
# a subclass), populate s3_input.s3_uri from it so the existing
567+
# upload logic handles it.
568+
_source = getattr(file_input, "source", None)
569+
if _source is not None:
570+
if file_input.s3_input is None:
571+
file_input.s3_input = ProcessingS3Input(
572+
s3_uri=_source,
573+
local_path="/opt/ml/processing/input",
574+
s3_data_type="S3Prefix",
575+
s3_input_mode="File",
576+
)
577+
else:
578+
file_input.s3_input.s3_uri = _source
417579

418580
if file_input.dataset_definition:
419581
normalized_inputs.append(file_input)

0 commit comments

Comments
 (0)