fix: Support local source for sagemaker.core.shapes.ProcessingInput (5672)#31
Conversation
mufaddal-rohawala
left a comment
🤖 AI Code Review
This PR adds a processing_input_from_local() factory function to support local file paths for ProcessingInput in SageMaker v3, addressing a migration gap from v2. The implementation is generally sound with good test coverage, but there are several issues: the function name could be more idiomatic, there's a type annotation issue with the source parameter accepting None but being typed as str, and the destination parameter mapping to local_path on ProcessingS3Input is confusing and potentially incorrect.
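For context, a minimal sketch of what such a factory might look like, using stand-in dataclasses for the auto-generated sagemaker-core shapes (everything here besides the name processing_input_from_local is an assumption for illustration, not the PR's actual implementation):

```python
import os
from dataclasses import dataclass
from urllib.parse import urlparse


# Stand-ins for the auto-generated sagemaker-core shapes (illustrative only).
@dataclass
class ProcessingS3Input:
    s3_uri: str
    local_path: str
    s3_data_type: str = "S3Prefix"
    s3_input_mode: str = "File"


@dataclass
class ProcessingInput:
    input_name: str
    s3_input: ProcessingS3Input


def processing_input_from_local(source, destination, input_name,
                                s3_data_type="S3Prefix", s3_input_mode="File"):
    """Build a ProcessingInput from a local path or S3 URI (sketch)."""
    if not source:
        raise ValueError(f"source must be a valid local path or S3 URI, got: {source!r}")
    # Local paths must exist now; S3 URIs are validated later at upload time.
    if urlparse(source).scheme != "s3" and not os.path.exists(source):
        raise ValueError(f"Input source path does not exist: {source!r}.")
    return ProcessingInput(
        input_name=input_name,
        s3_input=ProcessingS3Input(
            s3_uri=source,  # a local path here is uploaded by _normalize_inputs()
            local_path=destination,
            s3_data_type=s3_data_type,
            s3_input_mode=s3_input_mode,
        ),
    )
```

The real function operates on the generated shape classes; the sketch only shows the validate-then-construct flow the review comments below discuss.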
logger = logging.getLogger(__name__)


def processing_input_from_local(
The source parameter is typed as str but the function explicitly handles None (line 155). The type annotation should be str | None to match the actual behavior, or better yet, keep it as str and remove the None handling since None is not a meaningful input.
def processing_input_from_local(
    source: str | None,
    ...

Alternatively, if you want to keep str as the type, remove the None check and let it fail naturally with a TypeError; that's more Pythonic for a required parameter.
if not os.path.exists(local_path):
    raise ValueError(
        f"Input source path does not exist: {source!r}. "
        f"Please provide a valid local path or S3 URI."
The mapping of destination → local_path on ProcessingS3Input is semantically confusing and potentially incorrect. In the ProcessingS3Input shape, local_path typically refers to the container path where data is made available, while s3_uri is the S3 location. However, here you're setting local_path=destination which is the container path — this seems correct for the SageMaker API, but the parameter name local_path on ProcessingS3Input is misleading in this context. Please add a comment clarifying this mapping:
s3_input_kwargs = {
    "s3_uri": source,
    # local_path in ProcessingS3Input maps to the container destination path
    "local_path": destination,
    "s3_data_type": s3_data_type,
    "s3_input_mode": s3_input_mode,
}

"""
if not source:
    raise ValueError(
        f"source must be a valid local path or S3 URI, got: {source!r}"
The f-string formatting here uses !r which is good, but the line exceeds 100 characters. While CI will catch this, consider reformatting:
raise ValueError(
    f"source must be a valid local path or S3 URI, got: {source!r}"
)

Actually, looking more carefully, this is already formatted this way, so disregard the formatting concern. Note, however, that this validation doesn't distinguish between None and an empty string in the error message. Consider separate messages for better debugging.
local_source = url2pathname(parse_result.path)
if not os.path.exists(local_source):
    raise ValueError(
        f"Input source path does not exist: {file_input.s3_input.s3_uri!r}. "
The logger.info() call at line 562 references desired_s3_uri, which is assigned in two different branches: one for a truthy _pipeline_config and one in the else. If the call sits inside only one of those branches, the log message is never emitted on the other path. Move the logger.info() after both branches, while still inside the if parse_result.scheme != "s3": block, so it logs in both cases:
# After both if/else branches for _pipeline_config
logger.info(
    "Uploading local input '%s' from %s to %s",
    file_input.input_name,
    file_input.s3_input.s3_uri,
    desired_s3_uri,
)

Wait: looking at the diff more carefully, the logger.info is placed right before s3_uri = s3.S3Uploader.upload(...), which is after both branches, so this should be correct. Please verify that the indentation level is right and that the call sits inside the if parse_result.scheme != "s3": block but after both the pipeline and non-pipeline branches.
# Check if source is a local path (not an S3 URI)
parse_result = urlparse(source)
if parse_result.scheme not in ("s3", "http", "https"):
The URL scheme check parse_result.scheme not in ("s3", "http", "https") doesn't account for other valid remote schemes that might be used (e.g., gs:// for GCS if ever supported). More importantly, on Windows, absolute paths like C:\data\file.csv will have parse_result.scheme == 'c', which would correctly fall into the local path branch. However, this should be documented or tested. Consider adding a comment about this behavior.
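To illustrate the scheme behavior in isolation (a quick urlparse check, not code from the PR):

```python
from urllib.parse import urlparse

# A POSIX path has no scheme, so it falls into the local-path branch.
assert urlparse("/home/user/data.csv").scheme == ""

# S3 URIs are treated as remote.
assert urlparse("s3://bucket/prefix/data.csv").scheme == "s3"

# On a Windows-style absolute path, the drive letter parses as the scheme,
# which still lands in the local branch since "c" is not in the remote set.
assert urlparse("C:\\data\\file.csv").scheme == "c"
```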
logger = logging.getLogger(__name__)


def processing_input_from_local(
Per SDK V3 conventions, user-facing data/config classes should use Pydantic BaseModel. While a factory function is a reasonable approach here (since ProcessingInput is auto-generated), consider whether a thin Pydantic wrapper class like LocalProcessingInput(BaseModel) with a .to_processing_input() method would be more consistent with V3 patterns. This is a design suggestion, not a blocker — the factory function approach is pragmatic given that ProcessingInput is auto-generated.
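For illustration, such a wrapper might look like the following sketch. A stdlib dataclass stands in for pydantic.BaseModel to keep the example dependency-free, and the shape classes are simplified stand-ins; apart from the suggested LocalProcessingInput name, everything here is an assumption:

```python
from dataclasses import dataclass


# Simplified stand-ins for the auto-generated shapes (illustrative only).
@dataclass
class ProcessingS3Input:
    s3_uri: str
    local_path: str


@dataclass
class ProcessingInput:
    input_name: str
    s3_input: ProcessingS3Input


@dataclass
class LocalProcessingInput:
    """V2-style ergonomics: hold the local source, convert on demand."""
    source: str
    destination: str
    input_name: str

    def to_processing_input(self) -> ProcessingInput:
        # The local source rides in s3_uri and would be uploaded later
        # by _normalize_inputs() when the processing job runs.
        return ProcessingInput(
            input_name=self.input_name,
            s3_input=ProcessingS3Input(
                s3_uri=self.source,
                local_path=self.destination,
            ),
        )
```

In the real SDK the dataclass would be a pydantic BaseModel, which also gives field validation for free.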
assert isinstance(result, ProcessingInput)
assert result.input_name == "my-data"
assert result.s3_input.s3_uri == temp_file
assert result.s3_input.local_path == "/opt/ml/processing/input/data"
Good use of tempfile.NamedTemporaryFile with proper cleanup in finally. However, consider using tempfile.TemporaryDirectory with a file created inside it as a fixture to reduce boilerplate across tests. Also, delete=False on Windows can leave files behind if the test crashes before finally. A pytest.fixture with tmp_path would be cleaner:
def test_processing_input_from_local_with_file_path(self, tmp_path):
    temp_file = tmp_path / "data.csv"
    temp_file.write_text("col1,col2\n1,2\n")
    result = processing_input_from_local(
        source=str(temp_file),
        destination="/opt/ml/processing/input/data",
        input_name="my-data",
    )
    ...

def test_processing_input_from_local_used_in_normalize_inputs(self, mock_session):
    """ProcessingInput from processing_input_from_local should work with _normalize_inputs."""
    processor = Processor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",
This test patches sagemaker.core.s3.S3Uploader.upload but the actual code in _normalize_inputs calls s3.S3Uploader.upload where s3 is imported as from sagemaker.core import s3. The patch target should be sagemaker.core.processing.s3.S3Uploader.upload (patching where it's used, not where it's defined) to ensure the mock is applied correctly. Please verify this works — if the import in processing.py is from sagemaker.core import s3, then the correct patch target is sagemaker.core.processing.s3.S3Uploader.upload.
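The "patch where it's used" rule can be demonstrated in isolation (module names here are made up for the demo):

```python
import sys
import types
from unittest import mock

# Fake "library" module that defines upload().
lib = types.ModuleType("demo_lib")
lib.upload = lambda path: f"uploaded {path}"
sys.modules["demo_lib"] = lib

# Fake "client" module that binds the name at import time:
# from demo_lib import upload
client = types.ModuleType("demo_client")
exec(
    "from demo_lib import upload\n"
    "def run(path):\n"
    "    return upload(path)\n",
    client.__dict__,
)
sys.modules["demo_client"] = client

# Patching where it's defined misses the client's own reference...
with mock.patch("demo_lib.upload", return_value="mocked"):
    assert client.run("x") == "uploaded x"

# ...while patching where it's used takes effect.
with mock.patch("demo_client.upload", return_value="mocked"):
    assert client.run("x") == "mocked"
```

One nuance: when the client instead does `from pkg import s3` and calls `s3.S3Uploader.upload`, both patch targets resolve to the same class attribute via the shared module object, so either target works. The distinction above matters for `from module import name` bindings, which is why the patch target in this test should still be verified against the actual import style in processing.py.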
    s3_uri=tmpdir,
    local_path="/opt/ml/processing/input",
    s3_data_type="S3Prefix",
    s3_input_mode="File",
The f-string in the assertion f"s3://test-bucket/sagemaker/test-job/input/local-input" hardcodes the expected S3 path structure. This is brittle — if the S3 path construction logic changes in _normalize_inputs, this test will break. Consider using mock_logger.info.assert_called() and then inspecting the call args more flexibly, or at minimum add a comment explaining the expected path structure.
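One way to keep the assertion resilient (illustrative only, with a plain Mock standing in for the patched logger and a made-up format string):

```python
from unittest import mock

logger = mock.Mock()

# Simulate the call _normalize_inputs() would make (arguments are illustrative).
logger.info(
    "Uploading local input '%s' from %s to %s",
    "local-input",
    "/tmp/data",
    "s3://test-bucket/sagemaker/test-job/input/local-input",
)

# Instead of asserting one hard-coded S3 path, inspect the call args loosely.
logger.info.assert_called_once()
args = logger.info.call_args.args
assert args[1] == "local-input"          # input name
assert args[3].endswith("/local-input")  # S3 destination ends with the input name
```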
Description
Support local source for ProcessingInput (GitHub Issue 5672)
Problem
When migrating from SageMaker v2 to v3, sagemaker.core.shapes.ProcessingInput only accepts S3 paths, requiring users to manually upload local files to S3 before creating processing jobs. In v2, sagemaker.processing.ProcessingInput had a source parameter that accepted local paths and handled the upload automatically.
Solution
This PR adds a convenience factory function processing_input_from_local() that provides V2-like ergonomics for creating ProcessingInput objects from local file/directory paths.
Changes
sagemaker-core/src/sagemaker/core/processing.py
- processing_input_from_local() factory function: creates a ProcessingInput from a local file/directory path or S3 URI. Validates that local paths exist and constructs the proper ProcessingInput with ProcessingS3Input. The local path is automatically uploaded to S3 during _normalize_inputs() when the processing job runs.
- _normalize_inputs(): before attempting to upload, validates that the local path exists and raises a clear ValueError if not.
- _normalize_inputs(): logs an info message when a local path is detected and will be uploaded to S3, improving visibility of the automatic upload behavior.
sagemaker-core/tests/unit/test_processing.py
- TestProcessingInputFromLocal test class with comprehensive tests for the factory function
- TestNormalizeInputsLocalPathValidation test class for the enhanced validation and logging behavior of _normalize_inputs(), including logging verification
Related Issue
Related issue: 5672
Changes Made
The GitHub issue requests that sagemaker.core.shapes.ProcessingInput support local file paths as input, similar to V2's sagemaker.processing.ProcessingInput, which had a source parameter. The current V3 ProcessingInput shape (from sagemaker-core auto-generated shapes) requires constructing a ProcessingS3Input with the local path awkwardly placed in the s3_uri field. While the _normalize_inputs() method in Processor already handles uploading local paths found in s3_input.s3_uri, this is unintuitive and poorly documented. The fix should add a convenience source parameter to the ProcessingInput used in processing.py (or a helper/factory function) so users can pass local paths directly without manually constructing ProcessingS3Input objects with local paths in the s3_uri field. Since ProcessingInput is an auto-generated shape from sagemaker-core, the best approach is to add a wrapper/helper in sagemaker-core/src/sagemaker/core/processing.py that accepts a local source parameter and constructs the proper ProcessingInput with ProcessingS3Input.
AI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
prefix: description format