Skip to content

Commit 0c9c8fd

Browse files
committed
fix: address review comments (iteration #1)
1 parent da2c76d commit 0c9c8fd

File tree

6 files changed

+279
-253
lines changed

6 files changed

+279
-253
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 66 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -296,7 +296,42 @@ def run(
296296
if not isinstance(self.sagemaker_session, PipelineSession):
297297
self.jobs.append(self.latest_job)
298298
if wait:
299-
self.latest_job.wait(logs=logs)
299+
self._wait_for_job(self.latest_job, logs=logs)
300+
301+
def _wait_for_job(self, processing_job, logs=True, poll=10):
    """Block until a processing job reaches a terminal status.

    Every call goes through the ``sagemaker_session`` stored on this
    Processor instead of the global default client, which fixes
    NoCredentialsError when using assumed-role sessions.

    Args:
        processing_job: ProcessingJob resource object. Only its
            ``processing_job_name`` is read.
        logs (bool): Whether to show logs while waiting (default: True).
        poll (int): Seconds to sleep between status checks when ``logs``
            is False (default: 10).

    Raises:
        RuntimeError: If ``logs`` is False and the job ends with status
            ``Failed``.
    """
    job_name = processing_job.processing_job_name

    if logs:
        # Log streaming waits on the job itself; no extra polling needed.
        logs_for_processing_job(self.sagemaker_session, job_name, wait=True)
        return

    terminal_statuses = ("Completed", "Failed", "Stopped")
    while True:
        # NOTE(review): each iteration fetches a fresh resource into a local
        # name; the ``processing_job`` object passed in is never updated, so
        # the caller's reference keeps its pre-wait status.
        current = ProcessingJob.get(
            processing_job_name=job_name,
            session=self.sagemaker_session.boto_session,
        )
        status = current.processing_job_status
        if status in terminal_statuses:
            if status == "Failed":
                # ``or`` also covers an attribute that exists but is unset,
                # which a getattr default alone would report as "None".
                reason = getattr(current, "failure_reason", None) or "Unknown"
                raise RuntimeError(f"Processing job {job_name} failed: {reason}")
            return
        time.sleep(poll)
300335

301336
def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
302337
"""Extend inputs and outputs based on extra parameters"""
@@ -633,8 +668,6 @@ def submit(request):
633668
# Remove tags from transformed dict as ProcessingJob resource doesn't accept it
634669
transformed.pop("tags", None)
635670
processing_job = ProcessingJob(**transformed)
636-
# Store the sagemaker_session on the job so wait/refresh can use it
637-
processing_job._sagemaker_session = self.sagemaker_session
638671
return processing_job
639672

640673
def _get_process_args(self, inputs, outputs, experiment_config):
@@ -849,7 +882,7 @@ def run(
849882
if not isinstance(self.sagemaker_session, PipelineSession):
850883
self.jobs.append(self.latest_job)
851884
if wait:
852-
self.latest_job.wait(logs=logs)
885+
self._wait_for_job(self.latest_job, logs=logs)
853886

854887
def _include_code_in_inputs(self, inputs, code, kms_key=None):
855888
"""Converts code to appropriate input and includes in input list.
@@ -948,11 +981,11 @@ def _get_code_upload_bucket_and_prefix(self):
948981
Returns:
949982
tuple: (bucket, prefix) for S3 uploads.
950983
"""
951-
code_location = getattr(self, 'code_location', None)
984+
code_location = getattr(self, "code_location", None)
952985
if code_location:
953986
parsed = urlparse(code_location)
954987
bucket = parsed.netloc
955-
prefix = parsed.path.lstrip('/')
988+
prefix = parsed.path.lstrip("/")
956989
return bucket, prefix
957990
return (
958991
self.sagemaker_session.default_bucket(),
@@ -1200,8 +1233,12 @@ def _package_code(
12001233
os.unlink(tmp.name)
12011234
return s3_uri
12021235

1236+
# Matches arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
# and captures (region, account, domain, repository) in that order.
# Repository names may contain "." and "_" in addition to letters, digits and
# "-", so the last group accepts them. Every character class stays free of
# shell metacharacters because the captured values are interpolated into a
# bash command.
_CODEARTIFACT_ARN_PATTERN = re.compile(
    r"^arn:aws:codeartifact:([a-z0-9-]+):(\d{12}):"
    r"repository/([a-zA-Z0-9-]+)/([a-zA-Z0-9._-]+)$"
)
1239+
12031240
@staticmethod
1204-
def _get_codeartifact_command(codeartifact_repo_arn):
1241+
def _get_codeartifact_command(codeartifact_repo_arn: str) -> str:
12051242
"""Parse a CodeArtifact repository ARN and return the login command.
12061243
12071244
Args:
@@ -1210,24 +1247,30 @@ def _get_codeartifact_command(codeartifact_repo_arn):
12101247
12111248
Returns:
12121249
str: The bash command to login to CodeArtifact via pip.
1250+
1251+
Raises:
1252+
ValueError: If the ARN format is invalid.
12131253
"""
1214-
# Parse ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
1215-
parts = codeartifact_repo_arn.split(':')
1216-
region = parts[3]
1217-
domain_owner = parts[4]
1218-
resource = parts[5] # repository/{domain}/{repository}
1219-
resource_parts = resource.split('/')
1220-
domain = resource_parts[1]
1221-
repository = resource_parts[2]
1254+
match = FrameworkProcessor._CODEARTIFACT_ARN_PATTERN.match(codeartifact_repo_arn)
1255+
if not match:
1256+
raise ValueError(
1257+
f"Invalid CodeArtifact repository ARN: {codeartifact_repo_arn}. "
1258+
"Expected format: "
1259+
"arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}"
1260+
)
1261+
region = match.group(1)
1262+
domain_owner = match.group(2)
1263+
domain = match.group(3)
1264+
repository = match.group(4)
12221265

12231266
return (
1224-
f'if ! hash aws 2>/dev/null; then\n'
1225-
f' echo "AWS CLI is not installed. Skipping CodeArtifact login."\n'
1226-
f'else\n'
1227-
f' aws codeartifact login --tool pip '
1228-
f'--domain {domain} --domain-owner {domain_owner} '
1229-
f'--repository {repository} --region {region}\n'
1230-
f'fi'
1267+
"if ! hash aws 2>/dev/null; then\n"
1268+
" echo \"AWS CLI is not installed. Skipping CodeArtifact login.\"\n"
1269+
"else\n"
1270+
f" aws codeartifact login --tool pip "
1271+
f"--domain {domain} --domain-owner {domain_owner} "
1272+
f"--repository {repository} --region {region}\n"
1273+
"fi"
12311274
)
12321275

12331276
@_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run")
@@ -1409,7 +1452,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
14091452

14101453
def _generate_framework_script(self, user_script: str) -> str:
14111454
"""Generate the framework entrypoint file (as text) for a processing job."""
1412-
codeartifact_repo_arn = getattr(self, '_codeartifact_repo_arn', None)
1455+
codeartifact_repo_arn = getattr(self, "_codeartifact_repo_arn", None)
14131456

14141457
if codeartifact_repo_arn:
14151458
codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn)
@@ -1470,68 +1513,6 @@ def _generate_framework_script(self, user_script: str) -> str:
14701513
)
14711514

14721515

1473-
def _wait_for_processing_job(processing_job, logs=True):
1474-
"""Wait for a processing job using the stored sagemaker_session.
1475-
1476-
This function uses the sagemaker_session stored on the processing_job
1477-
(if available) instead of the global default client, which fixes
1478-
NoCredentialsError when using assumed-role sessions.
1479-
1480-
Args:
1481-
processing_job: ProcessingJob resource object with _sagemaker_session attached.
1482-
logs (bool): Whether to show logs (default: True).
1483-
"""
1484-
sagemaker_session = getattr(processing_job, '_sagemaker_session', None)
1485-
job_name = processing_job.processing_job_name
1486-
1487-
if sagemaker_session is not None:
1488-
if logs:
1489-
logs_for_processing_job(sagemaker_session, job_name, wait=True)
1490-
else:
1491-
# Poll using the session's client
1492-
poll = 10
1493-
while True:
1494-
response = sagemaker_session.sagemaker_client.describe_processing_job(
1495-
ProcessingJobName=job_name
1496-
)
1497-
status = response.get('ProcessingJobStatus', 'Unknown')
1498-
if status in ('Completed', 'Failed', 'Stopped'):
1499-
if status == 'Failed':
1500-
reason = response.get('FailureReason', 'Unknown')
1501-
raise RuntimeError(
1502-
f"Processing job {job_name} failed: {reason}"
1503-
)
1504-
break
1505-
time.sleep(poll)
1506-
else:
1507-
# Fallback to the original refresh-based wait
1508-
processing_job.wait(logs=logs)
1509-
1510-
1511-
# Monkey-patch ProcessingJob.wait to use session-aware waiting
1512-
_original_processing_job_wait = getattr(ProcessingJob, 'wait', None)
1513-
1514-
1515-
def _patched_processing_job_wait(self, logs=True):
1516-
"""Session-aware wait for ProcessingJob."""
1517-
if hasattr(self, '_sagemaker_session') and self._sagemaker_session is not None:
1518-
_wait_for_processing_job(self, logs=logs)
1519-
elif _original_processing_job_wait:
1520-
_original_processing_job_wait(self, logs=logs)
1521-
else:
1522-
# Fallback polling
1523-
poll = 10
1524-
while True:
1525-
self.refresh()
1526-
status = self.processing_job_status
1527-
if status in ('Completed', 'Failed', 'Stopped'):
1528-
break
1529-
time.sleep(poll)
1530-
1531-
1532-
ProcessingJob.wait = _patched_processing_job_wait
1533-
1534-
15351516
class FeatureStoreOutput(ApiObject):
15361517
"""Configuration for processing job outputs in Amazon SageMaker Feature Store."""
15371518

0 commit comments

Comments
 (0)