Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,46 @@ def wait_for_optimization_job(self, job, poll=5):
_check_job_status(job, desc, "OptimizationJobStatus")
return desc

def _wait_for_processing_job(self, job, poll=5):
    """Wait for an Amazon SageMaker Processing job to complete.

    Polling goes through this session's own ``sagemaker_client``.

    Args:
        job (str): Name of the processing job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): Return value from the ``DescribeProcessingJob`` API.

    Raises:
        exceptions.CapacityError: If the processing job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the processing job fails.
    """

    def _poll_once():
        # Yields None while the job is still running, else the description.
        return _processing_job_status(self.sagemaker_client, job)

    description = _wait_until(_poll_once, poll)
    _check_job_status(job, description, "ProcessingJobStatus")
    return description

def _wait_for_training_job(self, job, poll=5):
    """Wait for an Amazon SageMaker Training job to complete.

    Polling goes through this session's own ``sagemaker_client``.

    Args:
        job (str): Name of the training job to wait for.
        poll (int): Polling interval in seconds (default: 5).

    Returns:
        (dict): Return value from the ``DescribeTrainingJob`` API.

    Raises:
        exceptions.CapacityError: If the training job fails with CapacityError.
        exceptions.UnexpectedStatusException: If the training job fails.
    """

    def _poll_once():
        # Yields None while the job is still running, else the description.
        return _training_job_status(self.sagemaker_client, job)

    description = _wait_until(_poll_once, poll)
    _check_job_status(job, description, "TrainingJobStatus")
    return description

def update_inference_component(
self, inference_component_name, specification=None, runtime_config=None, wait=True
):
Expand Down Expand Up @@ -2896,6 +2936,70 @@ def _optimization_job_status(sagemaker_client, job_name):
return desc


def _processing_job_status(sagemaker_client, job_name):
    """Describe a processing job once and report whether it has finished.

    A one-character progress marker for the current status is printed to
    stdout without a newline (``?`` when the status has no marker) and the
    stream is flushed.

    Args:
        sagemaker_client: The boto3 SageMaker client.
        job_name (str): The name of the processing job.

    Returns:
        dict: The processing job description if complete, None if still in progress.
    """
    description = sagemaker_client.describe_processing_job(ProcessingJobName=job_name)
    raw_status = description["ProcessingJobStatus"]
    # Map through the shared table; values it does not know pass through as-is.
    job_status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    progress_markers = {
        "Completed": "!",
        "InProgress": ".",
        "Failed": "*",
        "Stopped": "s",
        "Stopping": "_",
    }
    print(progress_markers.get(job_status, "?"), end="")
    sys.stdout.flush()

    still_running = job_status in ("InProgress", "Stopping", "Starting")
    return None if still_running else description


def _training_job_status(sagemaker_client, job_name):
    """Describe a training job once and report whether it has finished.

    A one-character progress marker for the current status is printed to
    stdout without a newline (``?`` when the status has no marker) and the
    stream is flushed.

    Args:
        sagemaker_client: The boto3 SageMaker client.
        job_name (str): The name of the training job.

    Returns:
        dict: The training job description if complete, None if still in progress.
    """
    description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    raw_status = description["TrainingJobStatus"]
    # Map through the shared table; values it does not know pass through as-is.
    job_status = _STATUS_CODE_TABLE.get(raw_status, raw_status)

    progress_markers = {
        "Completed": "!",
        "InProgress": ".",
        "Failed": "*",
        "Stopped": "s",
        "Stopping": "_",
    }
    print(progress_markers.get(job_status, "?"), end="")
    sys.stdout.flush()

    still_running = job_status in ("InProgress", "Stopping", "Starting")
    return None if still_running else description


def container_def(
image_uri,
model_data_url=None,
Expand Down
8 changes: 6 additions & 2 deletions sagemaker-core/src/sagemaker/core/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ def run(
if not isinstance(self.sagemaker_session, PipelineSession):
self.jobs.append(self.latest_job)
if wait:
self.latest_job.wait(logs=logs)
self.sagemaker_session._wait_for_processing_job(
self.latest_job.processing_job_name
)

def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
"""Extend inputs and outputs based on extra parameters"""
Expand Down Expand Up @@ -846,7 +848,9 @@ def run(
if not isinstance(self.sagemaker_session, PipelineSession):
self.jobs.append(self.latest_job)
if wait:
self.latest_job.wait(logs=logs)
self.sagemaker_session._wait_for_processing_job(
self.latest_job.processing_job_name
)

def _include_code_in_inputs(self, inputs, code, kms_key=None):
"""Converts code to appropriate input and includes in input list.
Expand Down
3 changes: 2 additions & 1 deletion sagemaker-core/tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ def test_run_with_wait(self, mock_session):
)

mock_job = Mock()
mock_job.processing_job_name = "test-processing-job"
mock_job.wait = Mock()

with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f:
Expand All @@ -1049,7 +1050,7 @@ def test_run_with_wait(self, mock_session):
"sagemaker.core.s3.S3Uploader.upload", return_value="s3://bucket/code.py"
):
processor.run(code=temp_file, wait=True, logs=False)
mock_job.wait.assert_called_once()
mock_session._wait_for_processing_job.assert_called_once()
finally:
if os.path.exists(temp_file):
os.unlink(temp_file)
Expand Down
182 changes: 182 additions & 0 deletions sagemaker-core/tests/unit/test_session_wait_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for session wait methods (_wait_for_processing_job, _wait_for_training_job).

These methods were added to fix Bug 1 in issue #5765: wait=True does not
respect sagemaker_session, causing NoCredentialsError with assumed-role sessions.
"""
from __future__ import absolute_import

from unittest.mock import MagicMock, patch
import pytest

from sagemaker.core.helper.session_helper import (
_processing_job_status,
_training_job_status,
)


class TestProcessingJobStatus:
    """Tests for the _processing_job_status helper function."""

    @staticmethod
    def _client_returning(description):
        """Build a mock client whose ``describe_processing_job`` yields ``description``."""
        mock_client = MagicMock()
        mock_client.describe_processing_job.return_value = description
        return mock_client

    def test_returns_none_when_in_progress(self):
        mock_client = self._client_returning({"ProcessingJobStatus": "InProgress"})
        outcome = _processing_job_status(mock_client, "my-job")
        assert outcome is None
        mock_client.describe_processing_job.assert_called_once_with(ProcessingJobName="my-job")

    def test_returns_desc_when_completed(self):
        expected = {"ProcessingJobStatus": "Completed"}
        mock_client = self._client_returning(expected)
        assert _processing_job_status(mock_client, "my-job") == expected

    def test_returns_desc_when_failed(self):
        expected = {"ProcessingJobStatus": "Failed", "FailureReason": "OOM"}
        mock_client = self._client_returning(expected)
        assert _processing_job_status(mock_client, "my-job") == expected

    def test_returns_desc_when_stopped(self):
        expected = {"ProcessingJobStatus": "Stopped"}
        mock_client = self._client_returning(expected)
        assert _processing_job_status(mock_client, "my-job") == expected

    def test_returns_none_when_stopping(self):
        mock_client = self._client_returning({"ProcessingJobStatus": "Stopping"})
        outcome = _processing_job_status(mock_client, "my-job")
        assert outcome is None


class TestTrainingJobStatus:
    """Tests for the _training_job_status helper function."""

    @staticmethod
    def _client_returning(description):
        """Build a mock client whose ``describe_training_job`` yields ``description``."""
        mock_client = MagicMock()
        mock_client.describe_training_job.return_value = description
        return mock_client

    def test_returns_none_when_in_progress(self):
        mock_client = self._client_returning({"TrainingJobStatus": "InProgress"})
        outcome = _training_job_status(mock_client, "my-job")
        assert outcome is None
        mock_client.describe_training_job.assert_called_once_with(TrainingJobName="my-job")

    def test_returns_desc_when_completed(self):
        expected = {"TrainingJobStatus": "Completed"}
        mock_client = self._client_returning(expected)
        assert _training_job_status(mock_client, "my-job") == expected

    def test_returns_desc_when_failed(self):
        expected = {"TrainingJobStatus": "Failed", "FailureReason": "AlgorithmError"}
        mock_client = self._client_returning(expected)
        assert _training_job_status(mock_client, "my-job") == expected


class TestSessionWaitForProcessingJob:
    """Tests for Session._wait_for_processing_job."""

    def test_uses_session_client(self):
        """The describe call must go through the session's own sagemaker_client."""
        from sagemaker.core.helper.session_helper import Session

        fake_session = MagicMock(spec=Session)
        fake_session.sagemaker_client = MagicMock()
        describe = fake_session.sagemaker_client.describe_processing_job
        describe.return_value = {"ProcessingJobStatus": "Completed"}

        # Invoke the real method body with the mock standing in for ``self``.
        Session._wait_for_processing_job(fake_session, "test-job", poll=0.1)

        describe.assert_called_with(ProcessingJobName="test-job")

    def test_polls_until_complete(self):
        """Describe is called once per poll until a finished status is returned."""
        from sagemaker.core.helper.session_helper import Session

        fake_session = MagicMock(spec=Session)
        fake_session.sagemaker_client = MagicMock()
        describe = fake_session.sagemaker_client.describe_processing_job
        # Two in-progress responses followed by the final one.
        describe.side_effect = [
            {"ProcessingJobStatus": "InProgress"},
            {"ProcessingJobStatus": "InProgress"},
            {"ProcessingJobStatus": "Completed"},
        ]

        Session._wait_for_processing_job(fake_session, "test-job", poll=0.1)

        assert describe.call_count == 3


class TestSessionWaitForTrainingJob:
    """Tests for Session._wait_for_training_job."""

    def test_uses_session_client(self):
        """The describe call must go through the session's own sagemaker_client."""
        from sagemaker.core.helper.session_helper import Session

        fake_session = MagicMock(spec=Session)
        fake_session.sagemaker_client = MagicMock()
        describe = fake_session.sagemaker_client.describe_training_job
        describe.return_value = {"TrainingJobStatus": "Completed"}

        # Invoke the real method body with the mock standing in for ``self``.
        Session._wait_for_training_job(fake_session, "test-job", poll=0.1)

        describe.assert_called_with(TrainingJobName="test-job")


class TestProcessingUsesSessionWait:
    """Tests that processing.py uses session-aware wait instead of global client."""

    def test_processor_run_calls_session_wait(self):
        """Verify Processor.run with wait=True calls _wait_for_processing_job.

        NOTE(review): this test does not invoke ``Processor.run``. It replays
        the steps ``run()`` is expected to take after ``_start_new`` against
        mocks, so it only checks the mock wiring, not the production code.
        """
        from sagemaker.core.processing import Processor
        from sagemaker.core.workflow.pipeline_context import PipelineSession

        # Stand-in for the job object that run() would have started.
        job = MagicMock()
        job.processing_job_name = "test-processing-job"

        proc = MagicMock(spec=Processor)
        proc.sagemaker_session = MagicMock()
        proc.sagemaker_session.__class__.__name__ = "Session"
        proc.jobs = []
        proc.latest_job = job

        # Simulate what run() does after _start_new
        session = proc.sagemaker_session
        if not isinstance(session, PipelineSession):
            proc.jobs.append(proc.latest_job)
            session._wait_for_processing_job(proc.latest_job.processing_job_name)

        session._wait_for_processing_job.assert_called_once_with("test-processing-job")
Loading
Loading