Skip to content

Commit 7862f5e

Browse files
committed
fix: address review comments (iteration #2)
1 parent 0c9c8fd commit 7862f5e

File tree

5 files changed

+137
-263
lines changed

5 files changed

+137
-263
lines changed

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 29 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -298,7 +298,7 @@ def run(
298298
if wait:
299299
self._wait_for_job(self.latest_job, logs=logs)
300300

301-
def _wait_for_job(self, processing_job, logs=True):
301+
def _wait_for_job(self, processing_job, logs=True, timeout=3600):
302302
"""Wait for a processing job using the stored sagemaker_session.
303303
304304
This method uses the sagemaker_session from the Processor instance
@@ -308,6 +308,8 @@ def _wait_for_job(self, processing_job, logs=True):
308308
Args:
309309
processing_job: ProcessingJob resource object.
310310
logs (bool): Whether to show logs (default: True).
311+
timeout (int): Maximum time in seconds to wait (default: 3600).
312+
If None, waits indefinitely.
311313
"""
312314
job_name = processing_job.processing_job_name
313315
if logs:
@@ -316,7 +318,16 @@ def _wait_for_job(self, processing_job, logs=True):
316318
)
317319
else:
318320
poll = 10
321+
start_time = time.time()
319322
while True:
323+
if timeout and (time.time() - start_time) > timeout:
324+
raise RuntimeError(
325+
f"Timed out waiting for processing job {job_name} "
326+
f"after {timeout} seconds"
327+
)
328+
# TODO: Ideally sagemaker-core's ProcessingJob.refresh()/wait()
329+
# should accept a session parameter. Using ProcessingJob.get()
330+
# with the user's boto_session as a workaround.
320331
processing_job = ProcessingJob.get(
321332
processing_job_name=job_name,
322333
session=self.sagemaker_session.boto_session,
@@ -972,26 +983,6 @@ def _handle_user_code_url(self, code, kms_key=None):
972983
)
973984
return user_code_s3_uri
974985

975-
def _get_code_upload_bucket_and_prefix(self):
976-
"""Get the S3 bucket and prefix for code uploads.
977-
978-
If code_location is set (on FrameworkProcessor), parse it to extract
979-
bucket and prefix. Otherwise, use the session's default bucket.
980-
981-
Returns:
982-
tuple: (bucket, prefix) for S3 uploads.
983-
"""
984-
code_location = getattr(self, "code_location", None)
985-
if code_location:
986-
parsed = urlparse(code_location)
987-
bucket = parsed.netloc
988-
prefix = parsed.path.lstrip("/")
989-
return bucket, prefix
990-
return (
991-
self.sagemaker_session.default_bucket(),
992-
self.sagemaker_session.default_bucket_prefix or "",
993-
)
994-
995986
def _upload_code(self, code, kms_key=None):
996987
"""Uploads a code file or directory specified as a string and returns the S3 URI.
997988
@@ -1006,22 +997,20 @@ def _upload_code(self, code, kms_key=None):
1006997
"""
1007998
from sagemaker.core.workflow.utilities import _pipeline_config
1008999

1009-
bucket, prefix = self._get_code_upload_bucket_and_prefix()
1010-
10111000
if _pipeline_config and _pipeline_config.code_hash:
10121001
desired_s3_uri = s3.s3_path_join(
10131002
"s3://",
1014-
bucket,
1015-
prefix,
1003+
self.sagemaker_session.default_bucket(),
1004+
self.sagemaker_session.default_bucket_prefix,
10161005
_pipeline_config.pipeline_name,
10171006
self._CODE_CONTAINER_INPUT_NAME,
10181007
_pipeline_config.code_hash,
10191008
)
10201009
else:
10211010
desired_s3_uri = s3.s3_path_join(
10221011
"s3://",
1023-
bucket,
1024-
prefix,
1012+
self.sagemaker_session.default_bucket(),
1013+
self.sagemaker_session.default_bucket_prefix,
10251014
self._current_job_name,
10261015
"input",
10271016
self._CODE_CONTAINER_INPUT_NAME,
@@ -1211,12 +1200,10 @@ def _package_code(
12111200
item_path = os.path.join(source_dir, item)
12121201
tar.add(item_path, arcname=item)
12131202

1214-
# Upload to S3 - honor code_location if set
1215-
bucket, prefix = self._get_code_upload_bucket_and_prefix()
1216-
s3_uri = s3.s3_path_join(
1203+
s3_uri = s3.s3_path_join(
12171204
"s3://",
1218-
bucket,
1219-
prefix,
1205+
self.sagemaker_session.default_bucket(),
1206+
self.sagemaker_session.default_bucket_prefix or "",
12201207
job_name,
12211208
"source",
12221209
"sourcedir.tar.gz",
@@ -1233,46 +1220,6 @@ def _package_code(
12331220
os.unlink(tmp.name)
12341221
return s3_uri
12351222

1236-
_CODEARTIFACT_ARN_PATTERN = re.compile(
1237-
r"^arn:aws:codeartifact:([a-z0-9-]+):(\d{12}):repository/([a-zA-Z0-9-]+)/([a-zA-Z0-9-]+)$"
1238-
)
1239-
1240-
@staticmethod
1241-
def _get_codeartifact_command(codeartifact_repo_arn: str) -> str:
1242-
"""Parse a CodeArtifact repository ARN and return the login command.
1243-
1244-
Args:
1245-
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository.
1246-
Format: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
1247-
1248-
Returns:
1249-
str: The bash command to login to CodeArtifact via pip.
1250-
1251-
Raises:
1252-
ValueError: If the ARN format is invalid.
1253-
"""
1254-
match = FrameworkProcessor._CODEARTIFACT_ARN_PATTERN.match(codeartifact_repo_arn)
1255-
if not match:
1256-
raise ValueError(
1257-
f"Invalid CodeArtifact repository ARN: {codeartifact_repo_arn}. "
1258-
"Expected format: "
1259-
"arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}"
1260-
)
1261-
region = match.group(1)
1262-
domain_owner = match.group(2)
1263-
domain = match.group(3)
1264-
repository = match.group(4)
1265-
1266-
return (
1267-
"if ! hash aws 2>/dev/null; then\n"
1268-
" echo \"AWS CLI is not installed. Skipping CodeArtifact login.\"\n"
1269-
"else\n"
1270-
f" aws codeartifact login --tool pip "
1271-
f"--domain {domain} --domain-owner {domain_owner} "
1272-
f"--repository {repository} --region {region}\n"
1273-
"fi"
1274-
)
1275-
12761223
@_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run")
12771224
@runnable_by_pipeline
12781225
def run(
@@ -1288,7 +1235,6 @@ def run(
12881235
job_name: Optional[str] = None,
12891236
experiment_config: Optional[Dict[str, str]] = None,
12901237
kms_key: Optional[str] = None,
1291-
codeartifact_repo_arn: Optional[str] = None,
12921238
):
12931239
"""Runs a processing job.
12941240
@@ -1316,16 +1262,10 @@ def run(
13161262
experiment_config (dict[str, str]): Experiment management configuration.
13171263
kms_key (str): The ARN of the KMS key that is used to encrypt the
13181264
user code file (default: None).
1319-
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository to use
1320-
for pip authentication when installing requirements.txt dependencies
1321-
(default: None). Format:
1322-
arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
13231265
Returns:
13241266
None or pipeline step arguments in case the Processor instance is built with
13251267
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
13261268
"""
1327-
self._codeartifact_repo_arn = codeartifact_repo_arn
1328-
13291269
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
13301270
code,
13311271
source_dir,
@@ -1452,32 +1392,16 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
14521392

14531393
def _generate_framework_script(self, user_script: str) -> str:
14541394
"""Generate the framework entrypoint file (as text) for a processing job."""
1455-
codeartifact_repo_arn = getattr(self, "_codeartifact_repo_arn", None)
1456-
1457-
if codeartifact_repo_arn:
1458-
codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn)
1459-
requirements_block = dedent(
1460-
"""\
1461-
if [[ -f 'requirements.txt' ]]; then
1462-
# Some py3 containers has typing, which may breaks pip install
1463-
pip uninstall --yes typing
1464-
1465-
{codeartifact_login}
1466-
pip install -r requirements.txt
1467-
fi
1468-
"""
1469-
).format(codeartifact_login=codeartifact_login)
1470-
else:
1471-
requirements_block = dedent(
1472-
"""\
1473-
if [[ -f 'requirements.txt' ]]; then
1474-
# Some py3 containers has typing, which may breaks pip install
1475-
pip uninstall --yes typing
1476-
1477-
pip install -r requirements.txt
1478-
fi
1479-
"""
1480-
)
1395+
requirements_block = dedent(
1396+
"""\
1397+
if [[ -f 'requirements.txt' ]]; then
1398+
# Some py3 containers has typing, which may breaks pip install
1399+
pip uninstall --yes typing
1400+
1401+
pip install -r requirements.txt
1402+
fi
1403+
"""
1404+
)
14811405

14821406
return dedent(
14831407
"""\

0 commit comments

Comments
 (0)