-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: [v3] FrameworkProcessor and ModelTrainer: 4 regressions (including dropping Code (5765) #5769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
sagemaker-bot
wants to merge
3
commits into
aws:master
from
sagemaker-bot:fix/v3-frameworkprocessor-and-modeltrainer-4-5765
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
sagemaker-core/tests/unit/test_processing_regressions.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
|
||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Tests for v2->v3 regression Bug 1: wait=True ignores sagemaker session.""" | ||
| import pytest | ||
| from unittest.mock import MagicMock, patch, call | ||
|
|
||
|
|
||
| class TestBug1ProcessorWaitUsesSession: | ||
| """Bug 1: wait=True should use sagemaker_session, not global default client.""" | ||
|
|
||
| def test_processor_wait_for_job_uses_session_no_logs(self): | ||
| """Test that _wait_for_job uses the Processor's sagemaker_session (no logs).""" | ||
| from sagemaker.core.processing import Processor | ||
|
|
||
| mock_session = MagicMock() | ||
| mock_session.default_bucket.return_value = "my-bucket" | ||
| mock_session.default_bucket_prefix = "" | ||
| mock_session.expand_role.return_value = ( | ||
| "arn:aws:iam::123456789:role/MyRole" | ||
| ) | ||
| mock_session.boto_session = MagicMock() | ||
|
|
||
| processor = Processor( | ||
| role="arn:aws:iam::123456789:role/MyRole", | ||
| image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
|
|
||
| mock_job = MagicMock() | ||
| mock_job.processing_job_name = "test-job" | ||
|
|
||
| # Mock ProcessingJob.get to return a completed job | ||
| with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: | ||
| mock_refreshed = MagicMock() | ||
| mock_refreshed.processing_job_status = "Completed" | ||
| MockPJ.get.return_value = mock_refreshed | ||
|
|
||
| processor._wait_for_job(mock_job, logs=False) | ||
|
|
||
| MockPJ.get.assert_called_with( | ||
| processing_job_name="test-job", | ||
| session=mock_session.boto_session, | ||
| ) | ||
|
|
||
| def test_processor_wait_for_job_uses_session_with_logs(self): | ||
| """Test that _wait_for_job with logs=True uses logs_for_processing_job.""" | ||
| from sagemaker.core.processing import Processor | ||
|
|
||
| mock_session = MagicMock() | ||
| mock_session.default_bucket.return_value = "my-bucket" | ||
| mock_session.default_bucket_prefix = "" | ||
| mock_session.expand_role.return_value = ( | ||
| "arn:aws:iam::123456789:role/MyRole" | ||
| ) | ||
| mock_session.boto_session = MagicMock() | ||
|
|
||
| processor = Processor( | ||
| role="arn:aws:iam::123456789:role/MyRole", | ||
| image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
|
|
||
| mock_job = MagicMock() | ||
| mock_job.processing_job_name = "test-job" | ||
|
|
||
| with patch("sagemaker.core.processing.logs_for_processing_job") as mock_logs: | ||
| processor._wait_for_job(mock_job, logs=True) | ||
|
|
||
| mock_logs.assert_called_once_with( | ||
| mock_session, "test-job", wait=True | ||
| ) | ||
|
|
||
| def test_processor_wait_for_job_raises_on_failed(self): | ||
| """Test that _wait_for_job raises RuntimeError when job fails.""" | ||
| from sagemaker.core.processing import Processor | ||
|
|
||
| mock_session = MagicMock() | ||
| mock_session.default_bucket.return_value = "my-bucket" | ||
| mock_session.default_bucket_prefix = "" | ||
| mock_session.expand_role.return_value = ( | ||
| "arn:aws:iam::123456789:role/MyRole" | ||
| ) | ||
| mock_session.boto_session = MagicMock() | ||
|
|
||
| processor = Processor( | ||
| role="arn:aws:iam::123456789:role/MyRole", | ||
| image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
|
|
||
| mock_job = MagicMock() | ||
| mock_job.processing_job_name = "test-job" | ||
|
|
||
| with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: | ||
| mock_refreshed = MagicMock() | ||
| mock_refreshed.processing_job_status = "Failed" | ||
| mock_refreshed.failure_reason = "OutOfMemory" | ||
| MockPJ.get.return_value = mock_refreshed | ||
|
|
||
| with pytest.raises(RuntimeError, match="failed.*OutOfMemory"): | ||
| processor._wait_for_job(mock_job, logs=False) | ||
|
|
||
| def test_processor_wait_for_job_timeout(self): | ||
| """Test that _wait_for_job raises RuntimeError on timeout.""" | ||
| from sagemaker.core.processing import Processor | ||
|
|
||
| mock_session = MagicMock() | ||
| mock_session.default_bucket.return_value = "my-bucket" | ||
| mock_session.default_bucket_prefix = "" | ||
| mock_session.expand_role.return_value = ( | ||
| "arn:aws:iam::123456789:role/MyRole" | ||
| ) | ||
| mock_session.boto_session = MagicMock() | ||
|
|
||
| processor = Processor( | ||
| role="arn:aws:iam::123456789:role/MyRole", | ||
| image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
|
aviruthen marked this conversation as resolved.
|
||
| ) | ||
|
|
||
| mock_job = MagicMock() | ||
| mock_job.processing_job_name = "test-job" | ||
|
|
||
| with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: | ||
| mock_refreshed = MagicMock() | ||
| mock_refreshed.processing_job_status = "InProgress" | ||
| MockPJ.get.return_value = mock_refreshed | ||
|
|
||
| with patch("sagemaker.core.processing.time") as mock_time: | ||
| # Simulate timeout: first call returns 0, second returns > timeout | ||
| mock_time.time.side_effect = [0, 0, 5000] | ||
| mock_time.sleep = MagicMock() | ||
|
|
||
| with pytest.raises(RuntimeError, match="Timed out"): | ||
| processor._wait_for_job(mock_job, logs=False, timeout=1) | ||
|
|
||
| def test_processor_run_calls_wait_for_job(self): | ||
| """Test that Processor.run with wait=True calls _wait_for_job.""" | ||
| from sagemaker.core.processing import Processor | ||
|
|
||
| mock_session = MagicMock() | ||
| mock_session.default_bucket.return_value = "my-bucket" | ||
| mock_session.default_bucket_prefix = "" | ||
| mock_session.expand_role.return_value = ( | ||
| "arn:aws:iam::123456789:role/MyRole" | ||
| ) | ||
| mock_session.boto_session = MagicMock() | ||
| mock_session.sagemaker_client = MagicMock() | ||
|
|
||
| processor = Processor( | ||
| role="arn:aws:iam::123456789:role/MyRole", | ||
| image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", | ||
| instance_count=1, | ||
| instance_type="ml.m5.xlarge", | ||
| sagemaker_session=mock_session, | ||
| ) | ||
|
|
||
| # Verify _wait_for_job method exists and is callable | ||
| assert hasattr(processor, '_wait_for_job') | ||
| assert callable(processor._wait_for_job) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Scope this PR down to Bug 1 only (wait=True ignores sagemaker session). Bugs 2, 3, and 4 are being addressed by separate PRs (5772 and 5773).
Revert all changes related to:
Keep only the fix for Bug 1: ProcessingJob.refresh() and TrainingJob.refresh() use a global default client instead of the sagemaker_session passed by the user, causing NoCredentialsError with assumed-role sessions. The fix should ensure that when wait=True is used in FrameworkProcessor.run() and ModelTrainer.train(), the wait/refresh calls use the same boto session that was passed to the processor/trainer.
Do NOT monkey-patch ProcessingJob.wait or TrainingJob.wait globally. Instead, implement the wait at the call site — in processing.py's _start_new method and in trainer_wait.py — by polling describe_processing_job / describe_training_job using the sagemaker_session's client directly, similar to how v2 did it.
The relevant files are sagemaker-core/src/sagemaker/core/processing.py, sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py, and sagemaker-train/src/sagemaker/train/model_trainer.py. Do not worry about CI failures in this iteration, just focus on this de-scoping! Only address other comments given on this PR if they are related to first bug