Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 145 additions & 16 deletions sagemaker-core/src/sagemaker/core/processing.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scope this PR down to Bug 1 only (wait=True ignores sagemaker session). Bugs 2, 3, and 4 are being addressed by separate PRs (5772 and 5773).

Revert all changes related to:

  • Bug 2 (code_location) — any changes to _package_code or code upload paths in processing.py
  • Bug 3 (codeartifact_repo_arn) — any changes to FrameworkProcessor.run() signature or runproc script generation
  • Bug 4 (ModelTrainer CodeArtifact) — any changes to templates.py or INSTALL_REQUIREMENTS

Keep only the fix for Bug 1: ProcessingJob.refresh() and TrainingJob.refresh() use a global default client instead of the sagemaker_session passed by the user, causing NoCredentialsError with assumed-role sessions. The fix should ensure that when wait=True is used in FrameworkProcessor.run() and ModelTrainer.train(), the wait/refresh calls use the same boto session that was passed to the processor/trainer.

Do NOT monkey-patch ProcessingJob.wait or TrainingJob.wait globally. Instead, implement the wait at the call site — in processing.py's _start_new method and in trainer_wait.py — by polling describe_processing_job / describe_training_job using the sagemaker_session's client directly, similar to how v2 did it.

The relevant files are sagemaker-core/src/sagemaker/core/processing.py, sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py, and sagemaker-train/src/sagemaker/train/model_trainer.py. Do not worry about CI failures in this iteration, just focus on this de-scoping! Only address other comments given on this PR if they are related to first bug

Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,42 @@ def run(
if not isinstance(self.sagemaker_session, PipelineSession):
self.jobs.append(self.latest_job)
if wait:
self.latest_job.wait(logs=logs)
self._wait_for_job(self.latest_job, logs=logs)

def _wait_for_job(self, processing_job, logs=True):
"""Wait for a processing job using the stored sagemaker_session.

This method uses the sagemaker_session from the Processor instance
instead of the global default client, which fixes NoCredentialsError
when using assumed-role sessions.

Args:
processing_job: ProcessingJob resource object.
logs (bool): Whether to show logs (default: True).
Comment thread
aviruthen marked this conversation as resolved.
"""
job_name = processing_job.processing_job_name
if logs:
logs_for_processing_job(
self.sagemaker_session, job_name, wait=True
)
Comment thread
aviruthen marked this conversation as resolved.
else:
poll = 10
while True:
Comment thread
aviruthen marked this conversation as resolved.
processing_job = ProcessingJob.get(
processing_job_name=job_name,
session=self.sagemaker_session.boto_session,
)
status = processing_job.processing_job_status
if status in ("Completed", "Failed", "Stopped"):
if status == "Failed":
reason = getattr(
processing_job, "failure_reason", "Unknown"
)
raise RuntimeError(
f"Processing job {job_name} failed: {reason}"
)
break
time.sleep(poll)

def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
"""Extend inputs and outputs based on extra parameters"""
Expand Down Expand Up @@ -632,7 +667,8 @@ def submit(request):
transformed = transform(serialized_request, "CreateProcessingJobRequest")
# Remove tags from transformed dict as ProcessingJob resource doesn't accept it
transformed.pop("tags", None)
return ProcessingJob(**transformed)
processing_job = ProcessingJob(**transformed)
Comment thread
aviruthen marked this conversation as resolved.
return processing_job

def _get_process_args(self, inputs, outputs, experiment_config):
"""Gets a dict of arguments for a new Amazon SageMaker processing job."""
Expand Down Expand Up @@ -846,7 +882,7 @@ def run(
if not isinstance(self.sagemaker_session, PipelineSession):
self.jobs.append(self.latest_job)
if wait:
self.latest_job.wait(logs=logs)
self._wait_for_job(self.latest_job, logs=logs)

def _include_code_in_inputs(self, inputs, code, kms_key=None):
"""Converts code to appropriate input and includes in input list.
Expand Down Expand Up @@ -936,6 +972,26 @@ def _handle_user_code_url(self, code, kms_key=None):
)
return user_code_s3_uri

def _get_code_upload_bucket_and_prefix(self):
"""Get the S3 bucket and prefix for code uploads.

If code_location is set (on FrameworkProcessor), parse it to extract
bucket and prefix. Otherwise, use the session's default bucket.

Returns:
tuple: (bucket, prefix) for S3 uploads.
Comment thread
aviruthen marked this conversation as resolved.
Outdated
"""
code_location = getattr(self, "code_location", None)
if code_location:
parsed = urlparse(code_location)
bucket = parsed.netloc
prefix = parsed.path.lstrip("/")
return bucket, prefix
return (
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
)

def _upload_code(self, code, kms_key=None):
"""Uploads a code file or directory specified as a string and returns the S3 URI.

Expand All @@ -950,20 +1006,22 @@ def _upload_code(self, code, kms_key=None):
"""
from sagemaker.core.workflow.utilities import _pipeline_config

bucket, prefix = self._get_code_upload_bucket_and_prefix()

if _pipeline_config and _pipeline_config.code_hash:
desired_s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix,
bucket,
prefix,
_pipeline_config.pipeline_name,
self._CODE_CONTAINER_INPUT_NAME,
_pipeline_config.code_hash,
)
else:
desired_s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix,
bucket,
prefix,
self._current_job_name,
"input",
self._CODE_CONTAINER_INPUT_NAME,
Expand Down Expand Up @@ -1153,11 +1211,12 @@ def _package_code(
item_path = os.path.join(source_dir, item)
tar.add(item_path, arcname=item)

# Upload to S3
# Upload to S3 - honor code_location if set
bucket, prefix = self._get_code_upload_bucket_and_prefix()
s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
bucket,
prefix,
job_name,
"source",
"sourcedir.tar.gz",
Expand All @@ -1174,6 +1233,46 @@ def _package_code(
os.unlink(tmp.name)
Comment thread
aviruthen marked this conversation as resolved.
return s3_uri

Comment thread
aviruthen marked this conversation as resolved.
_CODEARTIFACT_ARN_PATTERN = re.compile(
r"^arn:aws:codeartifact:([a-z0-9-]+):(\d{12}):repository/([a-zA-Z0-9-]+)/([a-zA-Z0-9-]+)$"
)

@staticmethod
def _get_codeartifact_command(codeartifact_repo_arn: str) -> str:
"""Parse a CodeArtifact repository ARN and return the login command.

Args:
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository.
Format: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}

Returns:
str: The bash command to login to CodeArtifact via pip.

Raises:
ValueError: If the ARN format is invalid.
"""
match = FrameworkProcessor._CODEARTIFACT_ARN_PATTERN.match(codeartifact_repo_arn)
if not match:
raise ValueError(
f"Invalid CodeArtifact repository ARN: {codeartifact_repo_arn}. "
"Expected format: "
"arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}"
)
region = match.group(1)
domain_owner = match.group(2)
domain = match.group(3)
repository = match.group(4)

return (
"if ! hash aws 2>/dev/null; then\n"
" echo \"AWS CLI is not installed. Skipping CodeArtifact login.\"\n"
"else\n"
f" aws codeartifact login --tool pip "
f"--domain {domain} --domain-owner {domain_owner} "
f"--repository {repository} --region {region}\n"
"fi"
)

@_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run")
@runnable_by_pipeline
def run(
Expand All @@ -1189,6 +1288,7 @@ def run(
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
kms_key: Optional[str] = None,
Comment thread
aviruthen marked this conversation as resolved.
codeartifact_repo_arn: Optional[str] = None,
Comment thread
aviruthen marked this conversation as resolved.
Outdated
):
"""Runs a processing job.

Expand Down Expand Up @@ -1216,10 +1316,16 @@ def run(
experiment_config (dict[str, str]): Experiment management configuration.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository to use
for pip authentication when installing requirements.txt dependencies
(default: None). Format:
arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
Returns:
None or pipeline step arguments in case the Processor instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
self._codeartifact_repo_arn = codeartifact_repo_arn

s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
code,
source_dir,
Expand Down Expand Up @@ -1346,6 +1452,33 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):

def _generate_framework_script(self, user_script: str) -> str:
"""Generate the framework entrypoint file (as text) for a processing job."""
codeartifact_repo_arn = getattr(self, "_codeartifact_repo_arn", None)

if codeartifact_repo_arn:
codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn)
requirements_block = dedent(
"""\
if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

{codeartifact_login}
pip install -r requirements.txt
fi
"""
).format(codeartifact_login=codeartifact_login)
else:
requirements_block = dedent(
"""\
if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi
"""
)

return dedent(
"""\
#!/bin/bash
Expand All @@ -1369,16 +1502,12 @@ def _generate_framework_script(self, user_script: str) -> str:
exit 1
fi

if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi
{requirements_block}

{entry_point_command} {entry_point} "$@"
"""
).format(
requirements_block=requirements_block,
entry_point_command=" ".join(self.command),
entry_point=user_script,
)
Expand Down
Loading
Loading