Skip to content

Commit c318581

Browse files
committed
fix: wait=True respects sagemaker_session for processing and training jobs (5765)
1 parent 6497a94 commit c318581

4 files changed

Lines changed: 154 additions & 10 deletions

File tree

sagemaker-core/src/sagemaker/core/helper/session_helper.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,46 @@ def wait_for_optimization_job(self, job, poll=5):
16251625
_check_job_status(job, desc, "OptimizationJobStatus")
16261626
return desc
16271627

1628+
def wait_for_processing_job(self, job, poll=5):
1629+
"""Wait for an Amazon SageMaker Processing job to complete.
1630+
1631+
Args:
1632+
job (str): Name of the processing job to wait for.
1633+
poll (int): Polling interval in seconds (default: 5).
1634+
1635+
Returns:
1636+
(dict): Return value from the ``DescribeProcessingJob`` API.
1637+
1638+
Raises:
1639+
exceptions.CapacityError: If the processing job fails with CapacityError.
1640+
exceptions.UnexpectedStatusException: If the processing job fails.
1641+
"""
1642+
desc = _wait_until(
1643+
lambda: _processing_job_status(self.sagemaker_client, job), poll
1644+
)
1645+
_check_job_status(job, desc, "ProcessingJobStatus")
1646+
return desc
1647+
1648+
def wait_for_training_job(self, job, poll=5):
1649+
"""Wait for an Amazon SageMaker Training job to complete.
1650+
1651+
Args:
1652+
job (str): Name of the training job to wait for.
1653+
poll (int): Polling interval in seconds (default: 5).
1654+
1655+
Returns:
1656+
(dict): Return value from the ``DescribeTrainingJob`` API.
1657+
1658+
Raises:
1659+
exceptions.CapacityError: If the training job fails with CapacityError.
1660+
exceptions.UnexpectedStatusException: If the training job fails.
1661+
"""
1662+
desc = _wait_until(
1663+
lambda: _training_job_status(self.sagemaker_client, job), poll
1664+
)
1665+
_check_job_status(job, desc, "TrainingJobStatus")
1666+
return desc
1667+
16281668
def update_inference_component(
16291669
self, inference_component_name, specification=None, runtime_config=None, wait=True
16301670
):
@@ -2896,6 +2936,70 @@ def _optimization_job_status(sagemaker_client, job_name):
28962936
return desc
28972937

28982938

2939+
def _processing_job_status(sagemaker_client, job_name):
2940+
"""Check the status of a processing job.
2941+
2942+
Args:
2943+
sagemaker_client: The boto3 SageMaker client.
2944+
job_name (str): The name of the processing job.
2945+
2946+
Returns:
2947+
dict: The processing job description if complete, None if still in progress.
2948+
"""
2949+
status_codes = {
2950+
"Completed": "!",
2951+
"InProgress": ".",
2952+
"Failed": "*",
2953+
"Stopped": "s",
2954+
"Stopping": "_",
2955+
}
2956+
in_progress_statuses = ["InProgress", "Stopping", "Starting"]
2957+
2958+
desc = sagemaker_client.describe_processing_job(ProcessingJobName=job_name)
2959+
status = desc["ProcessingJobStatus"]
2960+
2961+
status = _STATUS_CODE_TABLE.get(status, status)
2962+
print(status_codes.get(status, "?"), end="")
2963+
sys.stdout.flush()
2964+
2965+
if status in in_progress_statuses:
2966+
return None
2967+
2968+
return desc
2969+
2970+
2971+
def _training_job_status(sagemaker_client, job_name):
2972+
"""Check the status of a training job.
2973+
2974+
Args:
2975+
sagemaker_client: The boto3 SageMaker client.
2976+
job_name (str): The name of the training job.
2977+
2978+
Returns:
2979+
dict: The training job description if complete, None if still in progress.
2980+
"""
2981+
status_codes = {
2982+
"Completed": "!",
2983+
"InProgress": ".",
2984+
"Failed": "*",
2985+
"Stopped": "s",
2986+
"Stopping": "_",
2987+
}
2988+
in_progress_statuses = ["InProgress", "Stopping", "Starting"]
2989+
2990+
desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
2991+
status = desc["TrainingJobStatus"]
2992+
2993+
status = _STATUS_CODE_TABLE.get(status, status)
2994+
print(status_codes.get(status, "?"), end="")
2995+
sys.stdout.flush()
2996+
2997+
if status in in_progress_statuses:
2998+
return None
2999+
3000+
return desc
3001+
3002+
28993003
def container_def(
29003004
image_uri,
29013005
model_data_url=None,

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ def run(
296296
if not isinstance(self.sagemaker_session, PipelineSession):
297297
self.jobs.append(self.latest_job)
298298
if wait:
299-
self.latest_job.wait(logs=logs)
299+
self.sagemaker_session.wait_for_processing_job(
300+
self.latest_job.processing_job_name
301+
)
300302

301303
def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
302304
"""Extend inputs and outputs based on extra parameters"""
@@ -846,7 +848,9 @@ def run(
846848
if not isinstance(self.sagemaker_session, PipelineSession):
847849
self.jobs.append(self.latest_job)
848850
if wait:
849-
self.latest_job.wait(logs=logs)
851+
self.sagemaker_session.wait_for_processing_job(
852+
self.latest_job.processing_job_name
853+
)
850854

851855
def _include_code_in_inputs(self, inputs, code, kms_key=None):
852856
"""Converts code to appropriate input and includes in input list.

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,34 @@
1515
from sagemaker.train.common_utils.mlflow_metrics_util import _MLflowMetricsUtil
1616

1717

18+
def _refresh_training_job(training_job, sagemaker_session=None):
19+
"""Refresh a training job using the session-aware client if available.
20+
21+
When sagemaker_session is provided, re-fetches the training job via
22+
TrainingJob.get() with the user's boto_session, which avoids the
23+
NoCredentialsError that occurs when refresh() uses the global default client.
24+
25+
Args:
26+
training_job (TrainingJob): The training job to refresh.
27+
sagemaker_session: SageMaker session with the correct credentials.
28+
If None, falls back to the default _refresh_training_job(training_job, sagemaker_session).
29+
"""
30+
if sagemaker_session is not None:
31+
refreshed = TrainingJob.get(
32+
training_job_name=training_job.training_job_name,
33+
session=sagemaker_session.boto_session,
34+
)
35+
# Copy refreshed attributes back to the original object
36+
for attr in ("training_job_status", "secondary_status", "failure_reason"):
37+
if hasattr(refreshed, attr):
38+
try:
39+
setattr(training_job, attr, getattr(refreshed, attr))
40+
except (AttributeError, TypeError):
41+
pass
42+
else:
43+
_refresh_training_job(training_job, sagemaker_session)
44+
45+
1846
@contextmanager
1947
def _suppress_info_logging():
2048
"""Context manager to temporarily suppress INFO level logging."""
@@ -218,14 +246,18 @@ def get_mlflow_url(training_job) -> str:
218246
def wait(
219247
training_job: TrainingJob,
220248
poll: int = 5,
221-
timeout: Optional[int] = 3000
249+
timeout: Optional[int] = 3000,
250+
sagemaker_session=None,
222251
) -> None:
223252
"""Wait for training job to complete with progress tracking.
224253
225254
Args:
226255
training_job (TrainingJob): The SageMaker training job to monitor.
227-
poll (int): Polling interval in seconds. Defaults to 3.
228-
timeout (Optional[int]): Maximum wait time in seconds. Defaults to None.
256+
poll (int): Polling interval in seconds. Defaults to 5.
257+
timeout (Optional[int]): Maximum wait time in seconds. Defaults to 3000.
258+
sagemaker_session: SageMaker session to use for describe calls.
259+
If provided, uses the session's sagemaker_client instead of the
260+
global default client, fixing NoCredentialsError with assumed-role sessions.
229261
230262
Raises:
231263
FailedStatusError: If the training job fails.
@@ -277,7 +309,7 @@ def get_cached_mlflow_url():
277309
iteration += 1
278310
time.sleep(0.5)
279311
if iteration >= poll * 2:
280-
training_job.refresh()
312+
_refresh_training_job(training_job, sagemaker_session)
281313
iteration = 0
282314

283315
status = training_job.training_job_status
@@ -360,7 +392,7 @@ def get_cached_mlflow_url():
360392
if not progress_started:
361393
progress_started = True
362394
time.sleep(poll)
363-
training_job.refresh()
395+
_refresh_training_job(training_job, sagemaker_session)
364396

365397
training_progress_pct, training_progress_text = _calculate_training_progress(
366398
training_job.progress_info, metrics_util, mlflow_run_name, training_job
@@ -442,7 +474,7 @@ def get_cached_mlflow_url():
442474
while True:
443475
iteration += 1
444476
time.sleep(poll)
445-
training_job.refresh()
477+
_refresh_training_job(training_job, sagemaker_session)
446478

447479
status = training_job.training_job_status
448480
secondary_status = training_job.secondary_status
@@ -462,7 +494,7 @@ def get_cached_mlflow_url():
462494
if not progress_started:
463495
progress_started = True
464496
time.sleep(20)
465-
training_job.refresh()
497+
_refresh_training_job(training_job, sagemaker_session)
466498

467499
progress_pct, progress_text = _calculate_training_progress(
468500
training_job.progress_info, metrics_util, mlflow_run_name, training_job

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,11 @@ def train(
790790
self._latest_training_job = training_job
791791

792792
if wait:
793-
training_job.wait(logs=logs)
793+
from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait
794+
trainer_wait(
795+
training_job=training_job,
796+
sagemaker_session=self.sagemaker_session,
797+
)
794798
if logs and not wait:
795799
logger.warning(
796800
"Not displaing the training container logs as 'wait' is set to False."

0 commit comments

Comments
 (0)