Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 162 additions & 14 deletions sagemaker-core/src/sagemaker/core/processing.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scope this PR down to Bug 1 only (wait=True ignores sagemaker session). Bugs 2, 3, and 4 are being addressed by separate PRs (5772 and 5773).

Revert all changes related to:

  • Bug 2 (code_location) — any changes to _package_code or code upload paths in processing.py
  • Bug 3 (codeartifact_repo_arn) — any changes to FrameworkProcessor.run() signature or runproc script generation
  • Bug 4 (ModelTrainer CodeArtifact) — any changes to templates.py or INSTALL_REQUIREMENTS

Keep only the fix for Bug 1: ProcessingJob.refresh() and TrainingJob.refresh() use a global default client instead of the sagemaker_session passed by the user, causing NoCredentialsError with assumed-role sessions. The fix should ensure that when wait=True is used in FrameworkProcessor.run() and ModelTrainer.train(), the wait/refresh calls use the same boto session that was passed to the processor/trainer.

Do NOT monkey-patch ProcessingJob.wait or TrainingJob.wait globally. Instead, implement the wait at the call site — in processing.py's _start_new method and in trainer_wait.py — by polling describe_processing_job / describe_training_job using the sagemaker_session's client directly, similar to how v2 did it.

The relevant files are sagemaker-core/src/sagemaker/core/processing.py, sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py, and sagemaker-train/src/sagemaker/train/model_trainer.py. Do not worry about CI failures in this iteration, just focus on this de-scoping! Only address other comments given on this PR if they are related to first bug

Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,10 @@ def submit(request):
transformed = transform(serialized_request, "CreateProcessingJobRequest")
# Remove tags from transformed dict as ProcessingJob resource doesn't accept it
transformed.pop("tags", None)
return ProcessingJob(**transformed)
processing_job = ProcessingJob(**transformed)
Comment thread
aviruthen marked this conversation as resolved.
# Store the sagemaker_session on the job so wait/refresh can use it
processing_job._sagemaker_session = self.sagemaker_session
return processing_job

def _get_process_args(self, inputs, outputs, experiment_config):
"""Gets a dict of arguments for a new Amazon SageMaker processing job."""
Expand Down Expand Up @@ -936,6 +939,26 @@ def _handle_user_code_url(self, code, kms_key=None):
)
return user_code_s3_uri

def _get_code_upload_bucket_and_prefix(self):
"""Get the S3 bucket and prefix for code uploads.

If code_location is set (on FrameworkProcessor), parse it to extract
bucket and prefix. Otherwise, use the session's default bucket.

Returns:
tuple: (bucket, prefix) for S3 uploads.
Comment thread
aviruthen marked this conversation as resolved.
Outdated
"""
code_location = getattr(self, 'code_location', None)
if code_location:
parsed = urlparse(code_location)
bucket = parsed.netloc
prefix = parsed.path.lstrip('/')
return bucket, prefix
return (
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
)

def _upload_code(self, code, kms_key=None):
"""Uploads a code file or directory specified as a string and returns the S3 URI.

Expand All @@ -950,20 +973,22 @@ def _upload_code(self, code, kms_key=None):
"""
from sagemaker.core.workflow.utilities import _pipeline_config

bucket, prefix = self._get_code_upload_bucket_and_prefix()

if _pipeline_config and _pipeline_config.code_hash:
desired_s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix,
bucket,
prefix,
_pipeline_config.pipeline_name,
self._CODE_CONTAINER_INPUT_NAME,
_pipeline_config.code_hash,
)
else:
desired_s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix,
bucket,
prefix,
self._current_job_name,
"input",
self._CODE_CONTAINER_INPUT_NAME,
Expand Down Expand Up @@ -1153,11 +1178,12 @@ def _package_code(
item_path = os.path.join(source_dir, item)
tar.add(item_path, arcname=item)

# Upload to S3
# Upload to S3 - honor code_location if set
bucket, prefix = self._get_code_upload_bucket_and_prefix()
s3_uri = s3.s3_path_join(
"s3://",
self.sagemaker_session.default_bucket(),
self.sagemaker_session.default_bucket_prefix or "",
bucket,
prefix,
job_name,
"source",
"sourcedir.tar.gz",
Expand All @@ -1174,6 +1200,36 @@ def _package_code(
os.unlink(tmp.name)
Comment thread
aviruthen marked this conversation as resolved.
return s3_uri

Comment thread
aviruthen marked this conversation as resolved.
@staticmethod
def _get_codeartifact_command(codeartifact_repo_arn):
"""Parse a CodeArtifact repository ARN and return the login command.

Args:
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository.
Format: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}

Returns:
str: The bash command to login to CodeArtifact via pip.
"""
# Parse ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
parts = codeartifact_repo_arn.split(':')
region = parts[3]
domain_owner = parts[4]
resource = parts[5] # repository/{domain}/{repository}
resource_parts = resource.split('/')
domain = resource_parts[1]
repository = resource_parts[2]

return (
f'if ! hash aws 2>/dev/null; then\n'
f' echo "AWS CLI is not installed. Skipping CodeArtifact login."\n'
f'else\n'
f' aws codeartifact login --tool pip '
f'--domain {domain} --domain-owner {domain_owner} '
f'--repository {repository} --region {region}\n'
f'fi'
)

@_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run")
@runnable_by_pipeline
def run(
Expand All @@ -1189,6 +1245,7 @@ def run(
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
kms_key: Optional[str] = None,
Comment thread
aviruthen marked this conversation as resolved.
codeartifact_repo_arn: Optional[str] = None,
Comment thread
aviruthen marked this conversation as resolved.
Outdated
):
"""Runs a processing job.

Expand Down Expand Up @@ -1216,10 +1273,16 @@ def run(
experiment_config (dict[str, str]): Experiment management configuration.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository to use
for pip authentication when installing requirements.txt dependencies
(default: None). Format:
arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}
Returns:
None or pipeline step arguments in case the Processor instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
self._codeartifact_repo_arn = codeartifact_repo_arn

s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
code,
source_dir,
Expand Down Expand Up @@ -1346,6 +1409,33 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):

def _generate_framework_script(self, user_script: str) -> str:
"""Generate the framework entrypoint file (as text) for a processing job."""
codeartifact_repo_arn = getattr(self, '_codeartifact_repo_arn', None)

if codeartifact_repo_arn:
codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn)
requirements_block = dedent(
"""\
if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

{codeartifact_login}
pip install -r requirements.txt
fi
"""
).format(codeartifact_login=codeartifact_login)
else:
requirements_block = dedent(
"""\
if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi
"""
)

return dedent(
"""\
#!/bin/bash
Expand All @@ -1369,21 +1459,79 @@ def _generate_framework_script(self, user_script: str) -> str:
exit 1
fi

if [[ -f 'requirements.txt' ]]; then
# Some py3 containers has typing, which may breaks pip install
pip uninstall --yes typing

pip install -r requirements.txt
fi
{requirements_block}

{entry_point_command} {entry_point} "$@"
"""
).format(
requirements_block=requirements_block,
entry_point_command=" ".join(self.command),
entry_point=user_script,
)


def _wait_for_processing_job(processing_job, logs=True):
"""Wait for a processing job using the stored sagemaker_session.

This function uses the sagemaker_session stored on the processing_job
(if available) instead of the global default client, which fixes
NoCredentialsError when using assumed-role sessions.

Args:
processing_job: ProcessingJob resource object with _sagemaker_session attached.
logs (bool): Whether to show logs (default: True).
"""
sagemaker_session = getattr(processing_job, '_sagemaker_session', None)
job_name = processing_job.processing_job_name

if sagemaker_session is not None:
if logs:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_wait_for_processing_job uses direct boto3-style describe_processing_job call instead of sagemaker-core. Per V3 architecture tenets, all subpackages should use sagemaker-core for AWS API interactions, not call sagemaker_client.describe_processing_job() directly. This is essentially a boto3 call wrapped in a session. Consider using the sagemaker-core ProcessingJob.get() or similar API with the appropriate session/config.

logs_for_processing_job(sagemaker_session, job_name, wait=True)
else:
# Poll using the session's client
poll = 10
while True:
response = sagemaker_session.sagemaker_client.describe_processing_job(
ProcessingJobName=job_name
)
status = response.get('ProcessingJobStatus', 'Unknown')
if status in ('Completed', 'Failed', 'Stopped'):
if status == 'Failed':
reason = response.get('FailureReason', 'Unknown')
raise RuntimeError(
f"Processing job {job_name} failed: {reason}"
)
break
time.sleep(poll)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing import time at the top of the file. The time.sleep(poll) call here requires time to be imported. Verify this import exists at the module level.

else:
# Fallback to the original refresh-based wait
processing_job.wait(logs=logs)


# Monkey-patch ProcessingJob.wait to use session-aware waiting
_original_processing_job_wait = getattr(ProcessingJob, 'wait', None)


def _patched_processing_job_wait(self, logs=True):
"""Session-aware wait for ProcessingJob."""
if hasattr(self, '_sagemaker_session') and self._sagemaker_session is not None:
_wait_for_processing_job(self, logs=logs)
elif _original_processing_job_wait:
_original_processing_job_wait(self, logs=logs)
else:
# Fallback polling
poll = 10
while True:
self.refresh()
status = self.processing_job_status
if status in ('Completed', 'Failed', 'Stopped'):
break
time.sleep(poll)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Monkey-patching ProcessingJob.wait is a fragile anti-pattern. This globally mutates a sagemaker-core resource class at import time, which:

  1. Affects all consumers of ProcessingJob, not just those using this code path
  2. Creates hidden coupling between modules
  3. Is fragile if ProcessingJob is imported before/after this module
  4. Violates the V3 tenet that subpackages should use sagemaker-core properly

Instead, consider:

  • Overriding the wait behavior in the Processor class itself (e.g., processor.latest_job_wait() that uses the stored session)
  • Or wrapping the wait call at the call site in _start_new / wherever wait() is invoked
  • Or contributing a fix to sagemaker-core to accept a session/client parameter in wait()/refresh()


ProcessingJob.wait = _patched_processing_job_wait


class FeatureStoreOutput(ApiObject):
"""Configuration for processing job outputs in Amazon SageMaker Feature Store."""

Expand Down
Loading
Loading