1515from sagemaker .train .common_utils .mlflow_metrics_util import _MLflowMetricsUtil
1616
1717
18+ def _refresh_training_job (training_job , sagemaker_session = None ):
19+ """Refresh a training job using the session-aware client if available.
20+
21+ When sagemaker_session is provided, re-fetches the training job via
22+ TrainingJob.get() with the user's boto_session, which avoids the
23+ NoCredentialsError that occurs when refresh() uses the global default client.
24+
25+ Args:
26+ training_job (TrainingJob): The training job to refresh.
27+ sagemaker_session: SageMaker session with the correct credentials.
28+ If None, falls back to the default _refresh_training_job(training_job, sagemaker_session).
29+ """
30+ if sagemaker_session is not None :
31+ refreshed = TrainingJob .get (
32+ training_job_name = training_job .training_job_name ,
33+ session = sagemaker_session .boto_session ,
34+ )
35+ # Copy refreshed attributes back to the original object
36+ for attr in ("training_job_status" , "secondary_status" , "failure_reason" ):
37+ if hasattr (refreshed , attr ):
38+ try :
39+ setattr (training_job , attr , getattr (refreshed , attr ))
40+ except (AttributeError , TypeError ):
41+ pass
42+ else :
43+ _refresh_training_job (training_job , sagemaker_session )
44+
45+
1846@contextmanager
1947def _suppress_info_logging ():
2048 """Context manager to temporarily suppress INFO level logging."""
@@ -218,14 +246,18 @@ def get_mlflow_url(training_job) -> str:
218246def wait (
219247 training_job : TrainingJob ,
220248 poll : int = 5 ,
221- timeout : Optional [int ] = 3000
249+ timeout : Optional [int ] = 3000 ,
250+ sagemaker_session = None ,
222251) -> None :
223252 """Wait for training job to complete with progress tracking.
224253
225254 Args:
226255 training_job (TrainingJob): The SageMaker training job to monitor.
227- poll (int): Polling interval in seconds. Defaults to 3.
228- timeout (Optional[int]): Maximum wait time in seconds. Defaults to None.
256+ poll (int): Polling interval in seconds. Defaults to 5.
257+ timeout (Optional[int]): Maximum wait time in seconds. Defaults to 3000.
258+ sagemaker_session: SageMaker session to use for describe calls.
259+ If provided, uses the session's sagemaker_client instead of the
260+ global default client, fixing NoCredentialsError with assumed-role sessions.
229261
230262 Raises:
231263 FailedStatusError: If the training job fails.
@@ -277,7 +309,7 @@ def get_cached_mlflow_url():
277309 iteration += 1
278310 time .sleep (0.5 )
279311 if iteration >= poll * 2 :
280- training_job . refresh ( )
312+ _refresh_training_job ( training_job , sagemaker_session )
281313 iteration = 0
282314
283315 status = training_job .training_job_status
@@ -360,7 +392,7 @@ def get_cached_mlflow_url():
360392 if not progress_started :
361393 progress_started = True
362394 time .sleep (poll )
363- training_job . refresh ( )
395+ _refresh_training_job ( training_job , sagemaker_session )
364396
365397 training_progress_pct , training_progress_text = _calculate_training_progress (
366398 training_job .progress_info , metrics_util , mlflow_run_name , training_job
@@ -442,7 +474,7 @@ def get_cached_mlflow_url():
442474 while True :
443475 iteration += 1
444476 time .sleep (poll )
445- training_job . refresh ( )
477+ _refresh_training_job ( training_job , sagemaker_session )
446478
447479 status = training_job .training_job_status
448480 secondary_status = training_job .secondary_status
@@ -462,7 +494,7 @@ def get_cached_mlflow_url():
462494 if not progress_started :
463495 progress_started = True
464496 time .sleep (20 )
465- training_job . refresh ( )
497+ _refresh_training_job ( training_job , sagemaker_session )
466498
467499 progress_pct , progress_text = _calculate_training_progress (
468500 training_job .progress_info , metrics_util , mlflow_run_name , training_job
0 commit comments