-
Notifications
You must be signed in to change notification settings - Fork 1.3k
fix: [v3] FrameworkProcessor and ModelTrainer: 4 regressions (including dropping Code (5765) #5769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -632,7 +632,10 @@ def submit(request): | |
| transformed = transform(serialized_request, "CreateProcessingJobRequest") | ||
| # Remove tags from transformed dict as ProcessingJob resource doesn't accept it | ||
| transformed.pop("tags", None) | ||
| return ProcessingJob(**transformed) | ||
| processing_job = ProcessingJob(**transformed) | ||
|
aviruthen marked this conversation as resolved.
|
||
| # Store the sagemaker_session on the job so wait/refresh can use it | ||
| processing_job._sagemaker_session = self.sagemaker_session | ||
| return processing_job | ||
|
|
||
| def _get_process_args(self, inputs, outputs, experiment_config): | ||
| """Gets a dict of arguments for a new Amazon SageMaker processing job.""" | ||
|
|
@@ -936,6 +939,26 @@ def _handle_user_code_url(self, code, kms_key=None): | |
| ) | ||
| return user_code_s3_uri | ||
|
|
||
| def _get_code_upload_bucket_and_prefix(self): | ||
| """Get the S3 bucket and prefix for code uploads. | ||
|
|
||
| If code_location is set (on FrameworkProcessor), parse it to extract | ||
| bucket and prefix. Otherwise, use the session's default bucket. | ||
|
|
||
| Returns: | ||
| tuple: (bucket, prefix) for S3 uploads. | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| """ | ||
| code_location = getattr(self, 'code_location', None) | ||
| if code_location: | ||
| parsed = urlparse(code_location) | ||
| bucket = parsed.netloc | ||
| prefix = parsed.path.lstrip('/') | ||
| return bucket, prefix | ||
| return ( | ||
| self.sagemaker_session.default_bucket(), | ||
| self.sagemaker_session.default_bucket_prefix or "", | ||
| ) | ||
|
|
||
| def _upload_code(self, code, kms_key=None): | ||
| """Uploads a code file or directory specified as a string and returns the S3 URI. | ||
|
|
||
|
|
@@ -950,20 +973,22 @@ def _upload_code(self, code, kms_key=None): | |
| """ | ||
| from sagemaker.core.workflow.utilities import _pipeline_config | ||
|
|
||
| bucket, prefix = self._get_code_upload_bucket_and_prefix() | ||
|
|
||
| if _pipeline_config and _pipeline_config.code_hash: | ||
| desired_s3_uri = s3.s3_path_join( | ||
| "s3://", | ||
| self.sagemaker_session.default_bucket(), | ||
| self.sagemaker_session.default_bucket_prefix, | ||
| bucket, | ||
| prefix, | ||
| _pipeline_config.pipeline_name, | ||
| self._CODE_CONTAINER_INPUT_NAME, | ||
| _pipeline_config.code_hash, | ||
| ) | ||
| else: | ||
| desired_s3_uri = s3.s3_path_join( | ||
| "s3://", | ||
| self.sagemaker_session.default_bucket(), | ||
| self.sagemaker_session.default_bucket_prefix, | ||
| bucket, | ||
| prefix, | ||
| self._current_job_name, | ||
| "input", | ||
| self._CODE_CONTAINER_INPUT_NAME, | ||
|
|
@@ -1153,11 +1178,12 @@ def _package_code( | |
| item_path = os.path.join(source_dir, item) | ||
| tar.add(item_path, arcname=item) | ||
|
|
||
| # Upload to S3 | ||
| # Upload to S3 - honor code_location if set | ||
| bucket, prefix = self._get_code_upload_bucket_and_prefix() | ||
| s3_uri = s3.s3_path_join( | ||
| "s3://", | ||
| self.sagemaker_session.default_bucket(), | ||
| self.sagemaker_session.default_bucket_prefix or "", | ||
| bucket, | ||
| prefix, | ||
| job_name, | ||
| "source", | ||
| "sourcedir.tar.gz", | ||
|
|
@@ -1174,6 +1200,36 @@ def _package_code( | |
| os.unlink(tmp.name) | ||
|
aviruthen marked this conversation as resolved.
|
||
| return s3_uri | ||
|
|
||
|
aviruthen marked this conversation as resolved.
|
||
| @staticmethod | ||
| def _get_codeartifact_command(codeartifact_repo_arn): | ||
| """Parse a CodeArtifact repository ARN and return the login command. | ||
|
|
||
| Args: | ||
| codeartifact_repo_arn (str): The ARN of the CodeArtifact repository. | ||
| Format: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} | ||
|
|
||
| Returns: | ||
| str: The bash command to login to CodeArtifact via pip. | ||
| """ | ||
| # Parse ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} | ||
| parts = codeartifact_repo_arn.split(':') | ||
| region = parts[3] | ||
| domain_owner = parts[4] | ||
| resource = parts[5] # repository/{domain}/{repository} | ||
| resource_parts = resource.split('/') | ||
| domain = resource_parts[1] | ||
| repository = resource_parts[2] | ||
|
|
||
| return ( | ||
| f'if ! hash aws 2>/dev/null; then\n' | ||
| f' echo "AWS CLI is not installed. Skipping CodeArtifact login."\n' | ||
| f'else\n' | ||
| f' aws codeartifact login --tool pip ' | ||
| f'--domain {domain} --domain-owner {domain_owner} ' | ||
| f'--repository {repository} --region {region}\n' | ||
| f'fi' | ||
| ) | ||
|
|
||
| @_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run") | ||
| @runnable_by_pipeline | ||
| def run( | ||
|
|
@@ -1189,6 +1245,7 @@ def run( | |
| job_name: Optional[str] = None, | ||
| experiment_config: Optional[Dict[str, str]] = None, | ||
| kms_key: Optional[str] = None, | ||
|
aviruthen marked this conversation as resolved.
|
||
| codeartifact_repo_arn: Optional[str] = None, | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| ): | ||
| """Runs a processing job. | ||
|
|
||
|
|
@@ -1216,10 +1273,16 @@ def run( | |
| experiment_config (dict[str, str]): Experiment management configuration. | ||
| kms_key (str): The ARN of the KMS key that is used to encrypt the | ||
| user code file (default: None). | ||
| codeartifact_repo_arn (str): The ARN of the CodeArtifact repository to use | ||
| for pip authentication when installing requirements.txt dependencies | ||
| (default: None). Format: | ||
| arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} | ||
| Returns: | ||
| None or pipeline step arguments in case the Processor instance is built with | ||
| :class:`~sagemaker.workflow.pipeline_context.PipelineSession` | ||
| """ | ||
| self._codeartifact_repo_arn = codeartifact_repo_arn | ||
|
|
||
| s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( | ||
| code, | ||
| source_dir, | ||
|
|
@@ -1346,6 +1409,33 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): | |
|
|
||
| def _generate_framework_script(self, user_script: str) -> str: | ||
| """Generate the framework entrypoint file (as text) for a processing job.""" | ||
| codeartifact_repo_arn = getattr(self, '_codeartifact_repo_arn', None) | ||
|
|
||
| if codeartifact_repo_arn: | ||
| codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn) | ||
| requirements_block = dedent( | ||
| """\ | ||
| if [[ -f 'requirements.txt' ]]; then | ||
| # Some py3 containers has typing, which may breaks pip install | ||
| pip uninstall --yes typing | ||
|
|
||
| {codeartifact_login} | ||
| pip install -r requirements.txt | ||
| fi | ||
| """ | ||
| ).format(codeartifact_login=codeartifact_login) | ||
| else: | ||
| requirements_block = dedent( | ||
| """\ | ||
| if [[ -f 'requirements.txt' ]]; then | ||
| # Some py3 containers has typing, which may breaks pip install | ||
| pip uninstall --yes typing | ||
|
|
||
| pip install -r requirements.txt | ||
| fi | ||
| """ | ||
| ) | ||
|
|
||
| return dedent( | ||
| """\ | ||
| #!/bin/bash | ||
|
|
@@ -1369,21 +1459,79 @@ def _generate_framework_script(self, user_script: str) -> str: | |
| exit 1 | ||
| fi | ||
|
|
||
| if [[ -f 'requirements.txt' ]]; then | ||
| # Some py3 containers has typing, which may breaks pip install | ||
| pip uninstall --yes typing | ||
|
|
||
| pip install -r requirements.txt | ||
| fi | ||
| {requirements_block} | ||
|
|
||
| {entry_point_command} {entry_point} "$@" | ||
| """ | ||
| ).format( | ||
| requirements_block=requirements_block, | ||
| entry_point_command=" ".join(self.command), | ||
| entry_point=user_script, | ||
| ) | ||
|
|
||
|
|
||
| def _wait_for_processing_job(processing_job, logs=True): | ||
| """Wait for a processing job using the stored sagemaker_session. | ||
|
|
||
| This function uses the sagemaker_session stored on the processing_job | ||
| (if available) instead of the global default client, which fixes | ||
| NoCredentialsError when using assumed-role sessions. | ||
|
|
||
| Args: | ||
| processing_job: ProcessingJob resource object with _sagemaker_session attached. | ||
| logs (bool): Whether to show logs (default: True). | ||
| """ | ||
| sagemaker_session = getattr(processing_job, '_sagemaker_session', None) | ||
| job_name = processing_job.processing_job_name | ||
|
|
||
| if sagemaker_session is not None: | ||
| if logs: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| logs_for_processing_job(sagemaker_session, job_name, wait=True) | ||
| else: | ||
| # Poll using the session's client | ||
| poll = 10 | ||
| while True: | ||
| response = sagemaker_session.sagemaker_client.describe_processing_job( | ||
| ProcessingJobName=job_name | ||
| ) | ||
| status = response.get('ProcessingJobStatus', 'Unknown') | ||
| if status in ('Completed', 'Failed', 'Stopped'): | ||
| if status == 'Failed': | ||
| reason = response.get('FailureReason', 'Unknown') | ||
| raise RuntimeError( | ||
| f"Processing job {job_name} failed: {reason}" | ||
| ) | ||
| break | ||
| time.sleep(poll) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing |
||
| else: | ||
| # Fallback to the original refresh-based wait | ||
| processing_job.wait(logs=logs) | ||
|
|
||
|
|
||
| # Monkey-patch ProcessingJob.wait to use session-aware waiting | ||
| _original_processing_job_wait = getattr(ProcessingJob, 'wait', None) | ||
|
|
||
|
|
||
| def _patched_processing_job_wait(self, logs=True): | ||
| """Session-aware wait for ProcessingJob.""" | ||
| if hasattr(self, '_sagemaker_session') and self._sagemaker_session is not None: | ||
| _wait_for_processing_job(self, logs=logs) | ||
| elif _original_processing_job_wait: | ||
| _original_processing_job_wait(self, logs=logs) | ||
| else: | ||
| # Fallback polling | ||
| poll = 10 | ||
| while True: | ||
| self.refresh() | ||
| status = self.processing_job_status | ||
| if status in ('Completed', 'Failed', 'Stopped'): | ||
| break | ||
| time.sleep(poll) | ||
|
|
||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Critical: Monkey-patching
Instead, consider:
|
||
|
|
||
| ProcessingJob.wait = _patched_processing_job_wait | ||
|
|
||
|
|
||
| class FeatureStoreOutput(ApiObject): | ||
| """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Scope this PR down to Bug 1 only (wait=True ignores sagemaker session). Bugs 2, 3, and 4 are being addressed by separate PRs (5772 and 5773).
Revert all changes related to Bugs 2, 3, and 4.
Keep only the fix for Bug 1: ProcessingJob.refresh() and TrainingJob.refresh() use a global default client instead of the sagemaker_session passed by the user, causing NoCredentialsError with assumed-role sessions. The fix should ensure that when wait=True is used in FrameworkProcessor.run() and ModelTrainer.train(), the wait/refresh calls use the same boto session that was passed to the processor/trainer.
Do NOT monkey-patch ProcessingJob.wait or TrainingJob.wait globally. Instead, implement the wait at the call site — in processing.py's _start_new method and in trainer_wait.py — by polling describe_processing_job / describe_training_job using the sagemaker_session's client directly, similar to how v2 did it.
The relevant files are sagemaker-core/src/sagemaker/core/processing.py, sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py, and sagemaker-train/src/sagemaker/train/model_trainer.py. Do not worry about CI failures in this iteration; just focus on this de-scoping! Only address other comments given on this PR if they are related to the first bug.