Skip to content

Commit 204ee87

Browse files
committed
fix: address review comments (iteration #1)
1 parent 299d61b commit 204ee87

File tree

2 files changed: 79 additions and 88 deletions

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 42 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,14 @@
1616
data pre-processing, post-processing, feature engineering, data validation, and model evaluation,
1717
and interpretation on Amazon SageMaker.
1818
"""
19-
from __future__ import absolute_import
19+
from __future__ import annotations
2020

2121
import json
2222
import logging
2323
import os
2424
import pathlib
2525
import re
26-
from typing import Dict, List, Optional, Union
26+
from typing import Dict, List, Literal, Optional, Union
2727
import time
2828
from copy import copy
2929
from textwrap import dedent
@@ -84,42 +84,37 @@
8484

8585
logger = logging.getLogger(__name__)
8686

87+
DEFAULT_PROCESSING_INPUT_PATH = "/opt/ml/processing/input"
88+
8789

8890
def processing_input_from_local(
8991
input_name: str,
9092
local_path: str,
9193
destination: str,
92-
s3_data_type: str = "S3Prefix",
93-
s3_input_mode: str = "File",
94+
s3_data_type: Literal["S3Prefix", "ManifestFile"] = "S3Prefix",
95+
s3_input_mode: Literal["File", "Pipe"] = "File",
9496
) -> ProcessingInput:
95-
"""Creates a ProcessingInput from a local file or directory path.
97+
"""Create a ProcessingInput from a local file or directory path.
9698
9799
This is a convenience factory that makes it clear users can pass local
98100
paths as processing job inputs. The local path is stored in
99101
``ProcessingS3Input.s3_uri`` and will be automatically uploaded to S3
100102
when the processor's ``run()`` method is called.
101103
102-
Args:
103-
input_name: The name for this processing input.
104-
local_path: A local file or directory path to use as input.
105-
This will be uploaded to S3 automatically before the
106-
processing job starts.
107-
destination: The container path where the input data will be
108-
made available (e.g. ``/opt/ml/processing/input/data``).
109-
s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
110-
``'ManifestFile'`` (default: ``'S3Prefix'``).
111-
s3_input_mode: The input mode. Valid values: ``'File'``,
112-
``'Pipe'`` (default: ``'File'``).
113-
114-
Returns:
115-
ProcessingInput: A ``ProcessingInput`` object configured with the
116-
local path. The path will be uploaded to S3 during
117-
``Processor.run()``.
104+
:param input_name: The name for this processing input.
105+
:param local_path: A local file or directory path to use as input.
106+
This will be uploaded to S3 automatically before the
107+
processing job starts.
108+
:param destination: The container path where the input data will be
109+
made available (e.g. ``'/opt/ml/processing/input/data'``).
110+
:param s3_data_type: The S3 data type (default: ``'S3Prefix'``).
111+
:param s3_input_mode: The input mode (default: ``'File'``).
112+
:returns: A ``ProcessingInput`` configured with the local path.
113+
The path will be uploaded to S3 during ``Processor.run()``.
114+
:raises ValueError: If ``input_name``, ``local_path``, or ``destination`` is empty.
118115
119-
Raises:
120-
ValueError: If ``input_name`` or ``local_path`` is empty.
116+
Example::
121117
122-
Example:
123118
>>> inp = processing_input_from_local(
124119
... input_name="my-data",
125120
... local_path="/home/user/data/",
@@ -135,6 +130,10 @@ def processing_input_from_local(
135130
raise ValueError(
136131
f"local_path must be a non-empty string, got: {local_path!r}"
137132
)
133+
if not destination:
134+
raise ValueError(
135+
f"destination must be a non-empty string, got: {destination!r}"
136+
)
138137
return ProcessingInput(
139138
input_name=input_name,
140139
s3_input=ProcessingS3Input(
@@ -149,38 +148,33 @@ def processing_input_from_local(
149148
def create_processing_input(
150149
source: str,
151150
destination: str,
152-
input_name: str,
153-
s3_data_type: str = "S3Prefix",
154-
s3_input_mode: str = "File",
151+
input_name: str | None = None,
152+
s3_data_type: Literal["S3Prefix", "ManifestFile"] = "S3Prefix",
153+
s3_input_mode: Literal["File", "Pipe"] = "File",
155154
) -> ProcessingInput:
156-
"""Creates a ProcessingInput from a local path or S3 URI.
155+
"""Create a ProcessingInput from a local path or S3 URI.
157156
158157
This factory provides V2-like ergonomics where users pass a ``source``
159158
parameter that can be either a local file/directory path or an S3 URI.
160159
If ``source`` is a local path (does not start with ``s3://``), it will
161160
be automatically uploaded to S3 when the processor runs.
162161
163-
Args:
164-
source: A local file/directory path or S3 URI. Local paths will
165-
be uploaded to S3 automatically before the processing job
166-
starts.
167-
destination: The container path where the input data will be
168-
made available (e.g. ``/opt/ml/processing/input/data``).
169-
input_name: The name for this processing input.
170-
s3_data_type: The S3 data type. Valid values: ``'S3Prefix'``,
171-
``'ManifestFile'`` (default: ``'S3Prefix'``).
172-
s3_input_mode: The input mode. Valid values: ``'File'``,
173-
``'Pipe'`` (default: ``'File'``).
174-
175-
Returns:
176-
ProcessingInput: A ``ProcessingInput`` object. If ``source`` is a
177-
local path, it is stored in ``ProcessingS3Input.s3_uri`` and
178-
will be uploaded to S3 during ``Processor.run()``.
179-
180-
Raises:
181-
ValueError: If ``source``, ``destination``, or ``input_name`` is empty.
162+
:param source: A local file/directory path or S3 URI. Local paths will
163+
be uploaded to S3 automatically before the processing job starts.
164+
:param destination: The container path where the input data will be
165+
made available (e.g. ``'/opt/ml/processing/input/data'``).
166+
:param input_name: The name for this processing input. If ``None``,
167+
a name will be auto-generated by ``Processor._normalize_inputs()``
168+
(default: ``None``).
169+
:param s3_data_type: The S3 data type (default: ``'S3Prefix'``).
170+
:param s3_input_mode: The input mode (default: ``'File'``).
171+
:returns: A ``ProcessingInput`` object. If ``source`` is a local path,
172+
it is stored in ``ProcessingS3Input.s3_uri`` and will be uploaded
173+
to S3 during ``Processor.run()``.
174+
:raises ValueError: If ``source`` or ``destination`` is empty.
175+
176+
Example::
182177
183-
Example:
184178
>>> # Using a local path
185179
>>> inp = create_processing_input(
186180
... source="/home/user/data/",
@@ -201,10 +195,6 @@ def create_processing_input(
201195
raise ValueError(
202196
f"destination must be a non-empty string, got: {destination!r}"
203197
)
204-
if not input_name:
205-
raise ValueError(
206-
f"input_name must be a non-empty string, got: {input_name!r}"
207-
)
208198
return ProcessingInput(
209199
input_name=input_name,
210200
s3_input=ProcessingS3Input(
@@ -561,22 +551,6 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
561551
if file_input.input_name is None:
562552
file_input.input_name = f"input-{count}"
563553

564-
# Support ad-hoc 'source' attribute for V2-like ergonomics.
565-
# If a user sets file_input.source (e.g. via monkey-patching or
566-
# a subclass), populate s3_input.s3_uri from it so the existing
567-
# upload logic handles it.
568-
_source = getattr(file_input, "source", None)
569-
if _source is not None:
570-
if file_input.s3_input is None:
571-
file_input.s3_input = ProcessingS3Input(
572-
s3_uri=_source,
573-
local_path="/opt/ml/processing/input",
574-
s3_data_type="S3Prefix",
575-
s3_input_mode="File",
576-
)
577-
else:
578-
file_input.s3_input.s3_uri = _source
579-
580554
if file_input.dataset_definition:
581555
normalized_inputs.append(file_input)
582556
continue

sagemaker-core/tests/unit/test_processing.py

Lines changed: 37 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -527,24 +527,36 @@ def test_processing_input_from_local_with_custom_data_type_and_mode(self):
527527
assert result.s3_input.s3_data_type == "ManifestFile"
528528
assert result.s3_input.s3_input_mode == "Pipe"
529529

530-
def test_processing_input_from_local_empty_input_name_raises(self):
531-
"""processing_input_from_local should raise ValueError for empty input_name."""
530+
@pytest.mark.parametrize("input_name", ["", None])
531+
def test_processing_input_from_local_invalid_input_name_raises(self, input_name):
532+
"""processing_input_from_local should raise ValueError for empty or None input_name."""
532533
with pytest.raises(ValueError, match="input_name must be a non-empty string"):
533534
processing_input_from_local(
534-
input_name="",
535+
input_name=input_name,
535536
local_path="/tmp/data",
536537
destination="/opt/ml/processing/input",
537538
)
538539

539-
def test_processing_input_from_local_empty_local_path_raises(self):
540-
"""processing_input_from_local should raise ValueError for empty local_path."""
540+
@pytest.mark.parametrize("local_path", ["", None])
541+
def test_processing_input_from_local_invalid_local_path_raises(self, local_path):
542+
"""processing_input_from_local should raise ValueError for empty or None local_path."""
541543
with pytest.raises(ValueError, match="local_path must be a non-empty string"):
542544
processing_input_from_local(
543545
input_name="data",
544-
local_path="",
546+
local_path=local_path,
545547
destination="/opt/ml/processing/input",
546548
)
547549

550+
@pytest.mark.parametrize("destination", ["", None])
551+
def test_processing_input_from_local_invalid_destination_raises(self, destination):
552+
"""processing_input_from_local should raise ValueError for empty or None destination."""
553+
with pytest.raises(ValueError, match="destination must be a non-empty string"):
554+
processing_input_from_local(
555+
input_name="data",
556+
local_path="/tmp/data",
557+
destination=destination,
558+
)
559+
548560
def test_processing_input_from_local_with_pipeline_config_uses_pipeline_s3_path(
549561
self, mock_session
550562
):
@@ -574,6 +586,9 @@ def test_processing_input_from_local_with_pipeline_config_uses_pipeline_s3_path(
574586
) as mock_upload:
575587
result = processor._normalize_inputs([inp])
576588
mock_upload.assert_called_once()
589+
# Verify the session is passed correctly to the upload call
590+
call_kwargs = mock_upload.call_args
591+
assert call_kwargs[1]["sagemaker_session"] == mock_session
577592
assert result[0].s3_input.s3_uri.startswith("s3://")
578593

579594

@@ -614,32 +629,34 @@ def test_create_processing_input_default_s3_data_type_and_input_mode(self):
614629
assert result.s3_input.s3_data_type == "S3Prefix"
615630
assert result.s3_input.s3_input_mode == "File"
616631

617-
def test_create_processing_input_empty_source_raises(self):
618-
"""create_processing_input should raise ValueError for empty source."""
632+
@pytest.mark.parametrize("source", ["", None])
633+
def test_create_processing_input_invalid_source_raises(self, source):
634+
"""create_processing_input should raise ValueError for empty or None source."""
619635
with pytest.raises(ValueError, match="source must be a non-empty string"):
620636
create_processing_input(
621-
source="",
637+
source=source,
622638
destination="/opt/ml/processing/input",
623639
input_name="data",
624640
)
625641

626-
def test_create_processing_input_empty_destination_raises(self):
627-
"""create_processing_input should raise ValueError for empty destination."""
642+
@pytest.mark.parametrize("destination", ["", None])
643+
def test_create_processing_input_invalid_destination_raises(self, destination):
644+
"""create_processing_input should raise ValueError for empty or None destination."""
628645
with pytest.raises(ValueError, match="destination must be a non-empty string"):
629646
create_processing_input(
630647
source="/tmp/data",
631-
destination="",
648+
destination=destination,
632649
input_name="data",
633650
)
634651

635-
def test_create_processing_input_empty_input_name_raises(self):
636-
"""create_processing_input should raise ValueError for empty input_name."""
637-
with pytest.raises(ValueError, match="input_name must be a non-empty string"):
638-
create_processing_input(
639-
source="/tmp/data",
640-
destination="/opt/ml/processing/input",
641-
input_name="",
642-
)
652+
def test_create_processing_input_with_none_input_name_succeeds(self):
653+
"""create_processing_input should succeed with input_name=None (auto-generated)."""
654+
result = create_processing_input(
655+
source="/tmp/data",
656+
destination="/opt/ml/processing/input",
657+
)
658+
assert result.input_name is None
659+
assert result.s3_input.s3_uri == "/tmp/data"
643660

644661

645662
class TestNormalizeInputsLocalPathUpload:

0 commit comments

Comments
 (0)