Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions sagemaker-core/src/sagemaker/core/processing.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scope this PR down to Bug 1 only (wait=True ignores sagemaker session). Bugs 2, 3, and 4 are being addressed by separate PRs (5772 and 5773).

Revert all changes related to:

  • Bug 2 (code_location) — any changes to _package_code or code upload paths in processing.py
  • Bug 3 (codeartifact_repo_arn) — any changes to FrameworkProcessor.run() signature or runproc script generation
  • Bug 4 (ModelTrainer CodeArtifact) — any changes to templates.py or INSTALL_REQUIREMENTS

Keep only the fix for Bug 1: ProcessingJob.refresh() and TrainingJob.refresh() use a global default client instead of the sagemaker_session passed by the user, causing NoCredentialsError with assumed-role sessions. The fix should ensure that when wait=True is used in FrameworkProcessor.run() and ModelTrainer.train(), the wait/refresh calls use the same boto session that was passed to the processor/trainer.

Do NOT monkey-patch ProcessingJob.wait or TrainingJob.wait globally. Instead, implement the wait at the call site — in processing.py's _start_new method and in trainer_wait.py — by polling describe_processing_job / describe_training_job using the sagemaker_session's client directly, similar to how v2 did it.

The relevant files are sagemaker-core/src/sagemaker/core/processing.py, sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py, and sagemaker-train/src/sagemaker/train/model_trainer.py. Do not worry about CI failures in this iteration; focus only on this de-scoping. Address other comments on this PR only if they relate to Bug 1.

Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,53 @@ def run(
if not isinstance(self.sagemaker_session, PipelineSession):
self.jobs.append(self.latest_job)
if wait:
self.latest_job.wait(logs=logs)
self._wait_for_job(self.latest_job, logs=logs)

def _wait_for_job(self, processing_job, logs=True, timeout=3600, poll_interval=10):
    """Wait for a processing job using the stored sagemaker_session.

    This method uses the sagemaker_session from the Processor instance
    instead of the global default client, which fixes NoCredentialsError
    when using assumed-role sessions.

    Args:
        processing_job: ProcessingJob resource object. Only its
            ``processing_job_name`` is read.
        logs (bool): Whether to show logs (default: True). When True the
            wait is delegated to ``logs_for_processing_job`` and
            ``timeout`` / ``poll_interval`` are not used.
        timeout (int): Maximum time in seconds to wait when ``logs`` is
            False (default: 3600). If None, waits indefinitely. ``0``
            polls the job status exactly once.
        poll_interval (int): Seconds to sleep between status polls when
            ``logs`` is False (default: 10).

    Raises:
        RuntimeError: If the job status is "Failed", or if the job has not
            reached a terminal status within ``timeout`` seconds.
    """
    job_name = processing_job.processing_job_name

    if logs:
        logs_for_processing_job(
            self.sagemaker_session, job_name, wait=True
        )
        return

    start_time = time.time()
    while True:
        # TODO: Ideally sagemaker-core's ProcessingJob.refresh()/wait()
        # should accept a session parameter. Using ProcessingJob.get()
        # with the user's boto_session as a workaround.
        processing_job = ProcessingJob.get(
            processing_job_name=job_name,
            session=self.sagemaker_session.boto_session,
        )
        status = processing_job.processing_job_status

        # Inspect the status before the deadline so that a job which
        # finished during the last sleep is not reported as a timeout.
        if status == "Failed":
            # The attribute may exist but be unset; avoid "failed: None".
            reason = getattr(processing_job, "failure_reason", None) or "Unknown"
            raise RuntimeError(
                f"Processing job {job_name} failed: {reason}"
            )
        if status in ("Completed", "Stopped"):
            return

        delay = poll_interval
        if timeout is not None:
            # `is not None` (rather than truthiness) keeps timeout=0 meaningful.
            remaining = timeout - (time.time() - start_time)
            if remaining <= 0:
                raise RuntimeError(
                    f"Timed out waiting for processing job {job_name} "
                    f"after {timeout} seconds"
                )
            # Never sleep past the deadline.
            delay = min(poll_interval, remaining)
        time.sleep(delay)

def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
"""Extend inputs and outputs based on extra parameters"""
Expand Down Expand Up @@ -632,7 +678,8 @@ def submit(request):
transformed = transform(serialized_request, "CreateProcessingJobRequest")
# Remove tags from transformed dict as ProcessingJob resource doesn't accept it
transformed.pop("tags", None)
return ProcessingJob(**transformed)
processing_job = ProcessingJob(**transformed)
Comment thread
aviruthen marked this conversation as resolved.
return processing_job

def _get_process_args(self, inputs, outputs, experiment_config):
"""Gets a dict of arguments for a new Amazon SageMaker processing job."""
Expand Down Expand Up @@ -846,7 +893,7 @@ def run(
if not isinstance(self.sagemaker_session, PipelineSession):
self.jobs.append(self.latest_job)
if wait:
self.latest_job.wait(logs=logs)
self._wait_for_job(self.latest_job, logs=logs)

def _include_code_in_inputs(self, inputs, code, kms_key=None):
"""Converts code to appropriate input and includes in input list.
Expand Down Expand Up @@ -1153,8 +1200,7 @@ def _package_code(
item_path = os.path.join(source_dir, item)
tar.add(item_path, arcname=item)

# Upload to S3
s3_uri = s3.s3_path_join(
s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
Expand Down Expand Up @@ -1346,6 +1392,17 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):

def _generate_framework_script(self, user_script: str) -> str:
"""Generate the framework entrypoint file (as text) for a processing job."""
requirements_block = dedent(
"""\
if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi
"""
)

return dedent(
"""\
#!/bin/bash
Expand All @@ -1369,16 +1426,12 @@ def _generate_framework_script(self, user_script: str) -> str:
exit 1
fi

if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi
{requirements_block}

{entry_point_command} {entry_point} "$@"
"""
).format(
requirements_block=requirements_block,
entry_point_command=" ".join(self.command),
entry_point=user_script,
)
Expand Down
178 changes: 178 additions & 0 deletions sagemaker-core/tests/unit/test_processing_regressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Tests for v2->v3 regression Bug 1: wait=True ignores sagemaker session."""
import pytest
from unittest.mock import MagicMock, patch, call


class TestBug1ProcessorWaitUsesSession:
    """Bug 1 regression tests: waiting must go through the Processor's own session.

    Every test builds a ``Processor`` around a mocked sagemaker session and
    drives ``Processor._wait_for_job`` directly.
    """

    @staticmethod
    def _build_processor():
        """Return ``(processor, session)`` where ``session`` is the mocked sagemaker session."""
        # Imported lazily so the import happens when a test runs, not at collection.
        from sagemaker.core.processing import Processor

        session = MagicMock()
        session.default_bucket.return_value = "my-bucket"
        session.default_bucket_prefix = ""
        session.expand_role.return_value = (
            "arn:aws:iam::123456789:role/MyRole"
        )
        session.boto_session = MagicMock()

        processor = Processor(
            role="arn:aws:iam::123456789:role/MyRole",
            image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
            instance_count=1,
            instance_type="ml.m5.xlarge",
            sagemaker_session=session,
        )
        return processor, session

    @staticmethod
    def _job_stub():
        """Return a stand-in job object that only carries the job name."""
        job = MagicMock()
        job.processing_job_name = "test-job"
        return job

    def test_processor_wait_for_job_uses_session_no_logs(self):
        """With logs=False, the job is looked up with the session's boto_session."""
        processor, session = self._build_processor()
        job = self._job_stub()

        # ProcessingJob.get reports an already-completed job, so one poll suffices.
        with patch("sagemaker.core.processing.ProcessingJob") as MockPJ:
            completed = MagicMock()
            completed.processing_job_status = "Completed"
            MockPJ.get.return_value = completed

            processor._wait_for_job(job, logs=False)

            MockPJ.get.assert_called_with(
                processing_job_name="test-job",
                session=session.boto_session,
            )

    def test_processor_wait_for_job_uses_session_with_logs(self):
        """With logs=True, the wait is delegated to logs_for_processing_job."""
        processor, session = self._build_processor()
        job = self._job_stub()

        with patch("sagemaker.core.processing.logs_for_processing_job") as mock_logs:
            processor._wait_for_job(job, logs=True)

            mock_logs.assert_called_once_with(
                session, "test-job", wait=True
            )

    def test_processor_wait_for_job_raises_on_failed(self):
        """A "Failed" status surfaces as RuntimeError carrying the failure reason."""
        processor, _ = self._build_processor()
        job = self._job_stub()

        with patch("sagemaker.core.processing.ProcessingJob") as MockPJ:
            failed = MagicMock()
            failed.processing_job_status = "Failed"
            failed.failure_reason = "OutOfMemory"
            MockPJ.get.return_value = failed

            with pytest.raises(RuntimeError, match="failed.*OutOfMemory"):
                processor._wait_for_job(job, logs=False)

    def test_processor_wait_for_job_timeout(self):
        """A job that stays "InProgress" past the timeout raises RuntimeError."""
        processor, _ = self._build_processor()
        job = self._job_stub()

        with patch("sagemaker.core.processing.ProcessingJob") as MockPJ, patch(
            "sagemaker.core.processing.time"
        ) as mock_time:
            in_progress = MagicMock()
            in_progress.processing_job_status = "InProgress"
            MockPJ.get.return_value = in_progress

            # Clock readings: 0 at start, 0 on the first check, then far past the timeout.
            mock_time.time.side_effect = [0, 0, 5000]
            mock_time.sleep = MagicMock()

            with pytest.raises(RuntimeError, match="Timed out"):
                processor._wait_for_job(job, logs=False, timeout=1)

    def test_processor_run_calls_wait_for_job(self):
        """Smoke check that a Processor exposes a callable ``_wait_for_job``.

        NOTE(review): despite the test name, ``Processor.run`` is not invoked
        here; only the presence of the method is asserted.
        """
        processor, session = self._build_processor()
        session.sagemaker_client = MagicMock()

        # Verify _wait_for_job method exists and is callable
        assert hasattr(processor, '_wait_for_job')
        assert callable(processor._wait_for_job)
Loading
Loading