Skip to content

Commit 506d1b1

Browse files
committed
Add unit tests, make session methods hidden, make refresh compatible with pydantic model
1 parent c318581 commit 506d1b1

4 files changed

Lines changed: 195 additions & 8 deletions

File tree

sagemaker-core/src/sagemaker/core/helper/session_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,7 +1625,7 @@ def wait_for_optimization_job(self, job, poll=5):
16251625
_check_job_status(job, desc, "OptimizationJobStatus")
16261626
return desc
16271627

1628-
def wait_for_processing_job(self, job, poll=5):
1628+
def _wait_for_processing_job(self, job, poll=5):
16291629
"""Wait for an Amazon SageMaker Processing job to complete.
16301630
16311631
Args:
@@ -1645,7 +1645,7 @@ def wait_for_processing_job(self, job, poll=5):
16451645
_check_job_status(job, desc, "ProcessingJobStatus")
16461646
return desc
16471647

1648-
def wait_for_training_job(self, job, poll=5):
1648+
def _wait_for_training_job(self, job, poll=5):
16491649
"""Wait for an Amazon SageMaker Training job to complete.
16501650
16511651
Args:

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def run(
296296
if not isinstance(self.sagemaker_session, PipelineSession):
297297
self.jobs.append(self.latest_job)
298298
if wait:
299-
self.sagemaker_session.wait_for_processing_job(
299+
self.sagemaker_session._wait_for_processing_job(
300300
self.latest_job.processing_job_name
301301
)
302302

@@ -848,7 +848,7 @@ def run(
848848
if not isinstance(self.sagemaker_session, PipelineSession):
849849
self.jobs.append(self.latest_job)
850850
if wait:
851-
self.sagemaker_session.wait_for_processing_job(
851+
self.sagemaker_session._wait_for_processing_job(
852852
self.latest_job.processing_job_name
853853
)
854854

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for session wait methods (_wait_for_processing_job, _wait_for_training_job).
14+
15+
These methods were added to fix Bug 1 in issue #5765: wait=True does not
16+
respect sagemaker_session, causing NoCredentialsError with assumed-role sessions.
17+
"""
18+
from __future__ import absolute_import
19+
20+
from unittest.mock import MagicMock, patch
21+
import pytest
22+
23+
from sagemaker.core.helper.session_helper import (
24+
_processing_job_status,
25+
_training_job_status,
26+
)
27+
28+
29+
class TestProcessingJobStatus:
    """Check what ``_processing_job_status`` returns for each job status.

    Every test feeds the helper a mocked client whose
    ``describe_processing_job`` returns a canned description, then checks
    whether the helper hands back ``None`` or the description itself.
    """

    @staticmethod
    def _client_returning(description):
        """Build a mock client whose ``describe_processing_job`` returns ``description``."""
        stub = MagicMock()
        stub.describe_processing_job.return_value = description
        return stub

    def test_returns_none_when_in_progress(self):
        """Expect ``None`` for ``InProgress`` and a single describe call by job name."""
        stub = self._client_returning({"ProcessingJobStatus": "InProgress"})

        assert _processing_job_status(stub, "my-job") is None
        stub.describe_processing_job.assert_called_once_with(ProcessingJobName="my-job")

    def test_returns_desc_when_completed(self):
        """Expect the description back when the status is ``Completed``."""
        description = {"ProcessingJobStatus": "Completed"}
        stub = self._client_returning(description)

        assert _processing_job_status(stub, "my-job") == description

    def test_returns_desc_when_failed(self):
        """Expect the description back when the status is ``Failed``."""
        description = {"ProcessingJobStatus": "Failed", "FailureReason": "OOM"}
        stub = self._client_returning(description)

        assert _processing_job_status(stub, "my-job") == description

    def test_returns_desc_when_stopped(self):
        """Expect the description back when the status is ``Stopped``."""
        description = {"ProcessingJobStatus": "Stopped"}
        stub = self._client_returning(description)

        assert _processing_job_status(stub, "my-job") == description

    def test_returns_none_when_stopping(self):
        """Expect ``None`` while the job is still ``Stopping``."""
        stub = self._client_returning({"ProcessingJobStatus": "Stopping"})

        assert _processing_job_status(stub, "my-job") is None
69+
70+
71+
class TestTrainingJobStatus:
    """Check what ``_training_job_status`` returns for each job status.

    Every test feeds the helper a mocked client whose
    ``describe_training_job`` returns a canned description, then checks
    whether the helper hands back ``None`` or the description itself.
    """

    @staticmethod
    def _client_returning(description):
        """Build a mock client whose ``describe_training_job`` returns ``description``."""
        stub = MagicMock()
        stub.describe_training_job.return_value = description
        return stub

    def test_returns_none_when_in_progress(self):
        """Expect ``None`` for ``InProgress`` and a single describe call by job name."""
        stub = self._client_returning({"TrainingJobStatus": "InProgress"})

        assert _training_job_status(stub, "my-job") is None
        stub.describe_training_job.assert_called_once_with(TrainingJobName="my-job")

    def test_returns_desc_when_completed(self):
        """Expect the description back when the status is ``Completed``."""
        description = {"TrainingJobStatus": "Completed"}
        stub = self._client_returning(description)

        assert _training_job_status(stub, "my-job") == description

    def test_returns_desc_when_failed(self):
        """Expect the description back when the status is ``Failed``."""
        description = {"TrainingJobStatus": "Failed", "FailureReason": "AlgorithmError"}
        stub = self._client_returning(description)

        assert _training_job_status(stub, "my-job") == description
96+
97+
98+
class TestSessionWaitForProcessingJob:
    """Tests for ``Session._wait_for_processing_job``.

    The method is invoked unbound with a ``MagicMock(spec=Session)`` standing
    in for ``self``, so no real session is constructed.
    """

    def test_uses_session_client(self):
        """The describe call must go through the session's own ``sagemaker_client``."""
        from sagemaker.core.helper.session_helper import Session

        fake_session = MagicMock(spec=Session)
        fake_session.sagemaker_client = MagicMock()
        describe = fake_session.sagemaker_client.describe_processing_job
        describe.return_value = {"ProcessingJobStatus": "Completed"}

        # Unbound call: the mock session plays the role of ``self``.
        Session._wait_for_processing_job(fake_session, "test-job", poll=0.1)

        describe.assert_called_with(ProcessingJobName="test-job")

    def test_polls_until_complete(self):
        """Two ``InProgress`` responses then ``Completed`` must yield three describe calls."""
        from sagemaker.core.helper.session_helper import Session

        responses = [
            {"ProcessingJobStatus": "InProgress"},
            {"ProcessingJobStatus": "InProgress"},
            {"ProcessingJobStatus": "Completed"},
        ]
        fake_session = MagicMock(spec=Session)
        fake_session.sagemaker_client = MagicMock()
        describe = fake_session.sagemaker_client.describe_processing_job
        describe.side_effect = responses

        Session._wait_for_processing_job(fake_session, "test-job", poll=0.1)

        assert describe.call_count == 3
133+
134+
135+
class TestSessionWaitForTrainingJob:
    """Tests for ``Session._wait_for_training_job``.

    The method is invoked unbound with a ``MagicMock(spec=Session)`` standing
    in for ``self``, so no real session is constructed.
    """

    def test_uses_session_client(self):
        """The describe call must go through the session's own ``sagemaker_client``."""
        from sagemaker.core.helper.session_helper import Session

        fake_session = MagicMock(spec=Session)
        fake_session.sagemaker_client = MagicMock()
        describe = fake_session.sagemaker_client.describe_training_job
        describe.return_value = {"TrainingJobStatus": "Completed"}

        # Unbound call: the mock session plays the role of ``self``.
        Session._wait_for_training_job(fake_session, "test-job", poll=0.1)

        describe.assert_called_with(TrainingJobName="test-job")
153+
154+
155+
class TestProcessingUsesSessionWait:
    """Tests that processing.py uses the session-aware wait instead of a global client."""

    def test_processor_run_calls_session_wait(self):
        """Replay the post-submit steps of ``run`` against mocks and check the wait call.

        NOTE(review): this test never invokes ``Processor.run``. It calls the
        mocked ``_wait_for_processing_job`` itself and then asserts on that
        call, so it would still pass if ``run`` stopped using the session
        wait. Consider driving the real ``run`` method instead.
        """
        from sagemaker.core.processing import Processor
        from sagemaker.core.workflow.pipeline_context import PipelineSession

        # A stand-in job that only carries the name the wait call needs.
        job = MagicMock()
        job.processing_job_name = "test-processing-job"

        processor = MagicMock(spec=Processor)
        processor.sagemaker_session = MagicMock()
        processor.sagemaker_session.__class__.__name__ = "Session"
        processor.jobs = []
        processor.latest_job = job

        # Hand-written copy of what run() does once the job has been started.
        session = processor.sagemaker_session
        if not isinstance(session, PipelineSession):
            processor.jobs.append(processor.latest_job)
            session._wait_for_processing_job(processor.latest_job.processing_job_name)

        session._wait_for_processing_job.assert_called_once_with("test-processing-job")

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,20 @@ def _refresh_training_job(training_job, sagemaker_session=None):
3232
training_job_name=training_job.training_job_name,
3333
session=sagemaker_session.boto_session,
3434
)
35-
# Copy refreshed attributes back to the original object
35+
# Copy refreshed attributes back to the original object.
36+
# Skip Unassigned values to avoid Pydantic validation errors.
37+
from sagemaker.core.utils.utils import Unassigned
3638
for attr in ("training_job_status", "secondary_status", "failure_reason"):
3739
if hasattr(refreshed, attr):
40+
value = getattr(refreshed, attr)
41+
if isinstance(value, Unassigned):
42+
continue
3843
try:
39-
setattr(training_job, attr, getattr(refreshed, attr))
40-
except (AttributeError, TypeError):
44+
setattr(training_job, attr, value)
45+
except (AttributeError, TypeError, ValueError):
4146
pass
4247
else:
43-
_refresh_training_job(training_job, sagemaker_session)
48+
training_job.refresh()
4449

4550

4651
@contextmanager

0 commit comments

Comments
 (0)