@@ -296,7 +296,42 @@ def run(
296296 if not isinstance (self .sagemaker_session , PipelineSession ):
297297 self .jobs .append (self .latest_job )
298298 if wait :
299- self .latest_job .wait (logs = logs )
299+ self ._wait_for_job (self .latest_job , logs = logs )
300+
301+ def _wait_for_job (self , processing_job , logs = True ):
302+ """Wait for a processing job using the stored sagemaker_session.
303+
304+ This method uses the sagemaker_session from the Processor instance
305+ instead of the global default client, which fixes NoCredentialsError
306+ when using assumed-role sessions.
307+
308+ Args:
309+ processing_job: ProcessingJob resource object.
310+ logs (bool): Whether to show logs (default: True).
311+ """
312+ job_name = processing_job .processing_job_name
313+ if logs :
314+ logs_for_processing_job (
315+ self .sagemaker_session , job_name , wait = True
316+ )
317+ else :
318+ poll = 10
319+ while True :
320+ processing_job = ProcessingJob .get (
321+ processing_job_name = job_name ,
322+ session = self .sagemaker_session .boto_session ,
323+ )
324+ status = processing_job .processing_job_status
325+ if status in ("Completed" , "Failed" , "Stopped" ):
326+ if status == "Failed" :
327+ reason = getattr (
328+ processing_job , "failure_reason" , "Unknown"
329+ )
330+ raise RuntimeError (
331+ f"Processing job { job_name } failed: { reason } "
332+ )
333+ break
334+ time .sleep (poll )
300335
301336 def _extend_processing_args (self , inputs , outputs , ** kwargs ): # pylint: disable=W0613
302337 """Extend inputs and outputs based on extra parameters"""
@@ -633,8 +668,6 @@ def submit(request):
633668 # Remove tags from transformed dict as ProcessingJob resource doesn't accept it
634669 transformed .pop ("tags" , None )
635670 processing_job = ProcessingJob (** transformed )
636- # Store the sagemaker_session on the job so wait/refresh can use it
637- processing_job ._sagemaker_session = self .sagemaker_session
638671 return processing_job
639672
640673 def _get_process_args (self , inputs , outputs , experiment_config ):
@@ -849,7 +882,7 @@ def run(
849882 if not isinstance (self .sagemaker_session , PipelineSession ):
850883 self .jobs .append (self .latest_job )
851884 if wait :
852- self .latest_job . wait ( logs = logs )
885+ self ._wait_for_job ( self . latest_job , logs = logs )
853886
854887 def _include_code_in_inputs (self , inputs , code , kms_key = None ):
855888 """Converts code to appropriate input and includes in input list.
@@ -948,11 +981,11 @@ def _get_code_upload_bucket_and_prefix(self):
948981 Returns:
949982 tuple: (bucket, prefix) for S3 uploads.
950983 """
951- code_location = getattr (self , ' code_location' , None )
984+ code_location = getattr (self , " code_location" , None )
952985 if code_location :
953986 parsed = urlparse (code_location )
954987 bucket = parsed .netloc
955- prefix = parsed .path .lstrip ('/' )
988+ prefix = parsed .path .lstrip ("/" )
956989 return bucket , prefix
957990 return (
958991 self .sagemaker_session .default_bucket (),
@@ -1200,8 +1233,12 @@ def _package_code(
12001233 os .unlink (tmp .name )
12011234 return s3_uri
12021235
1236+ _CODEARTIFACT_ARN_PATTERN = re .compile (
1237+ r"^arn:aws:codeartifact:([a-z0-9-]+):(\d{12}):repository/([a-zA-Z0-9-]+)/([a-zA-Z0-9-]+)$"
1238+ )
1239+
12031240 @staticmethod
1204- def _get_codeartifact_command (codeartifact_repo_arn ) :
1241+ def _get_codeartifact_command (codeartifact_repo_arn : str ) -> str :
12051242 """Parse a CodeArtifact repository ARN and return the login command.
12061243
12071244 Args:
@@ -1210,24 +1247,30 @@ def _get_codeartifact_command(codeartifact_repo_arn):
12101247
12111248 Returns:
12121249 str: The bash command to login to CodeArtifact via pip.
1250+
1251+ Raises:
1252+ ValueError: If the ARN format is invalid.
12131253 """
1214- # Parse ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
1215- parts = codeartifact_repo_arn .split (':' )
1216- region = parts [3 ]
1217- domain_owner = parts [4 ]
1218- resource = parts [5 ] # repository/{domain}/{repository}
1219- resource_parts = resource .split ('/' )
1220- domain = resource_parts [1 ]
1221- repository = resource_parts [2 ]
1254+ match = FrameworkProcessor ._CODEARTIFACT_ARN_PATTERN .match (codeartifact_repo_arn )
1255+ if not match :
1256+ raise ValueError (
1257+ f"Invalid CodeArtifact repository ARN: { codeartifact_repo_arn } . "
1258+ "Expected format: "
1259+ "arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}"
1260+ )
1261+ region = match .group (1 )
1262+ domain_owner = match .group (2 )
1263+ domain = match .group (3 )
1264+ repository = match .group (4 )
12221265
12231266 return (
1224- f' if ! hash aws 2>/dev/null; then\n '
1225- f' echo "AWS CLI is not installed. Skipping CodeArtifact login."\n '
1226- f' else\n '
1227- f' aws codeartifact login --tool pip '
1228- f' --domain { domain } --domain-owner { domain_owner } '
1229- f' --repository { repository } --region { region } \n '
1230- f'fi'
1267+ " if ! hash aws 2>/dev/null; then\n "
1268+ " echo \ " AWS CLI is not installed. Skipping CodeArtifact login.\ "\n "
1269+ " else\n "
1270+ f" aws codeartifact login --tool pip "
1271+ f" --domain { domain } --domain-owner { domain_owner } "
1272+ f" --repository { repository } --region { region } \n "
1273+ "fi"
12311274 )
12321275
12331276 @_telemetry_emitter (feature = Feature .PROCESSING , func_name = "FrameworkProcessor.run" )
@@ -1409,7 +1452,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
14091452
14101453 def _generate_framework_script (self , user_script : str ) -> str :
14111454 """Generate the framework entrypoint file (as text) for a processing job."""
1412- codeartifact_repo_arn = getattr (self , ' _codeartifact_repo_arn' , None )
1455+ codeartifact_repo_arn = getattr (self , " _codeartifact_repo_arn" , None )
14131456
14141457 if codeartifact_repo_arn :
14151458 codeartifact_login = self ._get_codeartifact_command (codeartifact_repo_arn )
@@ -1470,68 +1513,6 @@ def _generate_framework_script(self, user_script: str) -> str:
14701513 )
14711514
14721515
1473- def _wait_for_processing_job (processing_job , logs = True ):
1474- """Wait for a processing job using the stored sagemaker_session.
1475-
1476- This function uses the sagemaker_session stored on the processing_job
1477- (if available) instead of the global default client, which fixes
1478- NoCredentialsError when using assumed-role sessions.
1479-
1480- Args:
1481- processing_job: ProcessingJob resource object with _sagemaker_session attached.
1482- logs (bool): Whether to show logs (default: True).
1483- """
1484- sagemaker_session = getattr (processing_job , '_sagemaker_session' , None )
1485- job_name = processing_job .processing_job_name
1486-
1487- if sagemaker_session is not None :
1488- if logs :
1489- logs_for_processing_job (sagemaker_session , job_name , wait = True )
1490- else :
1491- # Poll using the session's client
1492- poll = 10
1493- while True :
1494- response = sagemaker_session .sagemaker_client .describe_processing_job (
1495- ProcessingJobName = job_name
1496- )
1497- status = response .get ('ProcessingJobStatus' , 'Unknown' )
1498- if status in ('Completed' , 'Failed' , 'Stopped' ):
1499- if status == 'Failed' :
1500- reason = response .get ('FailureReason' , 'Unknown' )
1501- raise RuntimeError (
1502- f"Processing job { job_name } failed: { reason } "
1503- )
1504- break
1505- time .sleep (poll )
1506- else :
1507- # Fallback to the original refresh-based wait
1508- processing_job .wait (logs = logs )
1509-
1510-
1511- # Monkey-patch ProcessingJob.wait to use session-aware waiting
1512- _original_processing_job_wait = getattr (ProcessingJob , 'wait' , None )
1513-
1514-
1515- def _patched_processing_job_wait (self , logs = True ):
1516- """Session-aware wait for ProcessingJob."""
1517- if hasattr (self , '_sagemaker_session' ) and self ._sagemaker_session is not None :
1518- _wait_for_processing_job (self , logs = logs )
1519- elif _original_processing_job_wait :
1520- _original_processing_job_wait (self , logs = logs )
1521- else :
1522- # Fallback polling
1523- poll = 10
1524- while True :
1525- self .refresh ()
1526- status = self .processing_job_status
1527- if status in ('Completed' , 'Failed' , 'Stopped' ):
1528- break
1529- time .sleep (poll )
1530-
1531-
1532- ProcessingJob .wait = _patched_processing_job_wait
1533-
1534-
15351516class FeatureStoreOutput (ApiObject ):
15361517 """Configuration for processing job outputs in Amazon SageMaker Feature Store."""
15371518
0 commit comments