diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index b507ae1a93..0b97f538bb 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -296,7 +296,53 @@ def run( if not isinstance(self.sagemaker_session, PipelineSession): self.jobs.append(self.latest_job) if wait: - self.latest_job.wait(logs=logs) + self._wait_for_job(self.latest_job, logs=logs) + + def _wait_for_job(self, processing_job, logs=True, timeout=3600): + """Wait for a processing job using the stored sagemaker_session. + + This method uses the sagemaker_session from the Processor instance + instead of the global default client, which fixes NoCredentialsError + when using assumed-role sessions. + + Args: + processing_job: ProcessingJob resource object. + logs (bool): Whether to show logs (default: True). + timeout (int): Maximum time in seconds to wait (default: 3600). + If None, waits indefinitely. + """ + job_name = processing_job.processing_job_name + if logs: + logs_for_processing_job( + self.sagemaker_session, job_name, wait=True + ) + else: + poll = 10 + start_time = time.time() + while True: + if timeout and (time.time() - start_time) > timeout: + raise RuntimeError( + f"Timed out waiting for processing job {job_name} " + f"after {timeout} seconds" + ) + # TODO: Ideally sagemaker-core's ProcessingJob.refresh()/wait() + # should accept a session parameter. Using ProcessingJob.get() + # with the user's boto_session as a workaround. + processing_job = ProcessingJob.get( + processing_job_name=job_name, + session=self.sagemaker_session.boto_session, + ) + status = processing_job.processing_job_status + if status in ("Completed", "Failed", "Stopped"): + if status == "Failed": + reason = getattr( + processing_job, "failure_reason", "Unknown" + ) + raise RuntimeError( + f"Processing job {job_name} failed: {reason}" + ) + break + time.sleep(poll) def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613 """Extend inputs and outputs based on extra parameters""" @@ -632,7 +678,8 @@ def submit(request): transformed = transform(serialized_request, "CreateProcessingJobRequest") # Remove tags from transformed dict as ProcessingJob resource doesn't accept it transformed.pop("tags", None) - return ProcessingJob(**transformed) + processing_job = ProcessingJob(**transformed) + return processing_job def _get_process_args(self, inputs, outputs, experiment_config): """Gets a dict of arguments for a new Amazon SageMaker processing job.""" @@ -846,7 +893,7 @@ def run( if not isinstance(self.sagemaker_session, PipelineSession): self.jobs.append(self.latest_job) if wait: - self.latest_job.wait(logs=logs) + self._wait_for_job(self.latest_job, logs=logs) def _include_code_in_inputs(self, inputs, code, kms_key=None): """Converts code to appropriate input and includes in input list. @@ -1153,8 +1200,7 @@ def _package_code( item_path = os.path.join(source_dir, item) tar.add(item_path, arcname=item) - # Upload to S3 - s3_uri = s3.s3_path_join( + s3_uri = s3.s3_path_join( "s3://", self.sagemaker_session.default_bucket(), self.sagemaker_session.default_bucket_prefix or "", @@ -1346,6 +1392,17 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): def _generate_framework_script(self, user_script: str) -> str: """Generate the framework entrypoint file (as text) for a processing job.""" + requirements_block = dedent( + """\ + if [[ -f 'requirements.txt' ]]; then + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + """ + ) + return dedent( """\ #!/bin/bash @@ -1369,16 +1426,12 @@ def _generate_framework_script(self, user_script: str) -> str: exit 1 fi - if [[ -f 'requirements.txt' ]]; then - # Some py3 containers has typing, which may breaks pip install - pip uninstall --yes typing - - pip install -r requirements.txt - fi + {requirements_block} {entry_point_command} {entry_point} "$@" """ ).format( + requirements_block=requirements_block, entry_point_command=" ".join(self.command), entry_point=user_script, ) diff --git a/sagemaker-core/tests/unit/test_processing_regressions.py b/sagemaker-core/tests/unit/test_processing_regressions.py new file mode 100644 index 0000000000..3fb65397e9 --- /dev/null +++ b/sagemaker-core/tests/unit/test_processing_regressions.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for v2->v3 regression Bug 1: wait=True ignores sagemaker session.""" +import pytest +from unittest.mock import MagicMock, patch, call + + +class TestBug1ProcessorWaitUsesSession: + """Bug 1: wait=True should use sagemaker_session, not global default client.""" + + def test_processor_wait_for_job_uses_session_no_logs(self): + """Test that _wait_for_job uses the Processor's sagemaker_session (no logs).""" + from sagemaker.core.processing import Processor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) + mock_session.boto_session = MagicMock() + + processor = Processor( + role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" + + # Mock ProcessingJob.get to return a completed job + with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: + mock_refreshed = MagicMock() + mock_refreshed.processing_job_status = "Completed" + MockPJ.get.return_value = mock_refreshed + + processor._wait_for_job(mock_job, logs=False) + + MockPJ.get.assert_called_with( + processing_job_name="test-job", + session=mock_session.boto_session, + ) + + def test_processor_wait_for_job_uses_session_with_logs(self): + """Test that _wait_for_job with logs=True uses logs_for_processing_job.""" + from sagemaker.core.processing import Processor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) + mock_session.boto_session = MagicMock() + + processor = Processor( + role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" + + with patch("sagemaker.core.processing.logs_for_processing_job") as mock_logs: + processor._wait_for_job(mock_job, logs=True) + + mock_logs.assert_called_once_with( + mock_session, "test-job", wait=True + ) + + def test_processor_wait_for_job_raises_on_failed(self): + """Test that _wait_for_job raises RuntimeError when job fails.""" + from sagemaker.core.processing import Processor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) + mock_session.boto_session = MagicMock() + + processor = Processor( + role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" + + with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: + mock_refreshed = MagicMock() + mock_refreshed.processing_job_status = "Failed" + mock_refreshed.failure_reason = "OutOfMemory" + MockPJ.get.return_value = mock_refreshed + + with pytest.raises(RuntimeError, match="failed.*OutOfMemory"): + processor._wait_for_job(mock_job, logs=False) + + def test_processor_wait_for_job_timeout(self): + """Test that _wait_for_job raises RuntimeError on timeout.""" + from sagemaker.core.processing import Processor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) + mock_session.boto_session = MagicMock() + + processor = Processor( + role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" + + with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: + mock_refreshed = MagicMock() + mock_refreshed.processing_job_status = "InProgress" + MockPJ.get.return_value = mock_refreshed + + with patch("sagemaker.core.processing.time") as mock_time: + # Simulate timeout: first call returns 0, second returns > timeout + mock_time.time.side_effect = [0, 0, 5000] + mock_time.sleep = MagicMock() + + with pytest.raises(RuntimeError, match="Timed out"): + processor._wait_for_job(mock_job, logs=False, timeout=1) + + def test_processor_run_calls_wait_for_job(self): + """Test that Processor.run with wait=True calls _wait_for_job.""" + from sagemaker.core.processing import Processor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) + mock_session.boto_session = MagicMock() + mock_session.sagemaker_client = MagicMock() + + processor = Processor( + role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + # Verify _wait_for_job method exists and is callable + assert hasattr(processor, '_wait_for_job') + assert callable(processor._wait_for_job) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 59adcdfbfc..e13a5dd989 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -5,15 +5,25 @@ """ import logging +import re import time from contextlib import contextmanager from typing import Optional, Tuple +from sagemaker.core.helper.session_helper import Session from sagemaker.core.resources import TrainingJob from sagemaker.core.utils.exceptions import FailedStatusError, TimeoutExceededError from sagemaker.train.common_utils.mlflow_metrics_util import _MLflowMetricsUtil +logger = logging.getLogger(__name__) + + +def _to_snake_case(name: str) -> str: + """Convert a CamelCase string to snake_case.""" + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + @contextmanager def _suppress_info_logging(): @@ -215,17 +225,62 @@ def get_mlflow_url(training_job) -> str: +def _refresh_training_job( + training_job: TrainingJob, + sagemaker_session: Optional[Session] = None, +) -> None: + """Refresh training job using session-aware client if available. + + Uses the provided sagemaker_session's client to describe the training job + and update its attributes. This avoids using the global default client, + which fixes NoCredentialsError when using assumed-role sessions. + + TODO: Ideally sagemaker-core's TrainingJob.refresh() should accept a + session/client parameter so we don't need to call boto3 directly here. + This workaround should be removed once sagemaker-core supports + session-aware refresh. See: https://github.com/aws/sagemaker-python-sdk/issues/5765 + + Args: + training_job (TrainingJob): The training job to refresh. + sagemaker_session (Optional[Session]): SageMaker session with the + correct credentials. If None, falls back to default refresh. + """ + if sagemaker_session is not None: + response = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job.training_job_name + ) + # Update training_job attributes from the describe response + for key, value in response.items(): + snake_key = _to_snake_case(key) + if hasattr(training_job, snake_key): + try: + setattr(training_job, snake_key, value) + except (AttributeError, TypeError, ValueError) as e: + logger.debug( + "Could not set attribute %s on training job: %s", + snake_key, + e, + ) + else: + training_job.refresh() + + def wait( training_job: TrainingJob, poll: int = 5, - timeout: Optional[int] = 3000 + timeout: Optional[int] = 3000, + sagemaker_session: Optional[Session] = None, ) -> None: """Wait for training job to complete with progress tracking. Args: training_job (TrainingJob): The SageMaker training job to monitor. - poll (int): Polling interval in seconds. Defaults to 3. - timeout (Optional[int]): Maximum wait time in seconds. Defaults to None. + poll (int): Polling interval in seconds. Defaults to 5. + timeout (Optional[int]): Maximum wait time in seconds. Defaults to 3000. + sagemaker_session (Optional[Session]): SageMaker session to use for + describe calls. If provided, uses the session's sagemaker_client + instead of the global default client, which fixes + NoCredentialsError when using assumed-role sessions. Raises: FailedStatusError: If the training job fails. @@ -277,7 +332,7 @@ def get_cached_mlflow_url(): iteration += 1 time.sleep(0.5) if iteration >= poll * 2: - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) iteration = 0 status = training_job.training_job_status @@ -442,7 +497,7 @@ def get_cached_mlflow_url(): while True: iteration += 1 time.sleep(poll) - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) status = training_job.training_job_status secondary_status = training_job.secondary_status diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..fcc3726cdf 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -118,6 +118,7 @@ from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait from sagemaker.train.local.local_container import _LocalContainer @@ -790,7 +791,10 @@ def train( self._latest_training_job = training_job if wait: - training_job.wait(logs=logs) + trainer_wait( + training_job=training_job, + sagemaker_session=self.sagemaker_session, + ) if logs and not wait: logger.warning( "Not displaing the training container logs as 'wait' is set to False." diff --git a/sagemaker-train/tests/unit/test_train_regressions.py b/sagemaker-train/tests/unit/test_train_regressions.py new file mode 100644 index 0000000000..50bc6b9f7f --- /dev/null +++ b/sagemaker-train/tests/unit/test_train_regressions.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for v2->v3 regression Bug 1: wait=True ignores sagemaker session.""" +import inspect +import pytest +from unittest.mock import MagicMock + + +class TestBug1ModelTrainerWait: + """Bug 1: ModelTrainer.train(wait=True) should use sagemaker_session.""" + + def test_wait_function_accepts_sagemaker_session(self): + """Test that the wait function accepts sagemaker_session parameter.""" + from sagemaker.train.common_utils.trainer_wait import wait + + sig = inspect.signature(wait) + assert "sagemaker_session" in sig.parameters + + def test_refresh_training_job_uses_session_client(self): + """Test that _refresh_training_job uses session's sagemaker_client.""" + from sagemaker.train.common_utils.trainer_wait import ( + _refresh_training_job, + ) + + mock_session = MagicMock() + mock_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobStatus": "Completed", + "TrainingJobName": "test-job", + } + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + + _refresh_training_job(mock_job, sagemaker_session=mock_session) + + mock_session.sagemaker_client.describe_training_job.assert_called_once_with( + TrainingJobName="test-job" + ) + + def test_refresh_training_job_without_session_uses_default(self): + """Test that _refresh_training_job falls back to default refresh.""" + from sagemaker.train.common_utils.trainer_wait import ( + _refresh_training_job, + ) + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + + _refresh_training_job(mock_job, sagemaker_session=None) + + mock_job.refresh.assert_called_once() + + def test_to_snake_case(self): + """Test the _to_snake_case helper function.""" + from sagemaker.train.common_utils.trainer_wait import _to_snake_case + + assert _to_snake_case("TrainingJobStatus") == "training_job_status" + assert _to_snake_case("TrainingJobName") == "training_job_name" + assert _to_snake_case("SecondaryStatus") == "secondary_status" + assert _to_snake_case("already_snake") == "already_snake" + + def test_refresh_training_job_updates_attributes(self): + """Test that _refresh_training_job updates job attributes from describe response.""" + from sagemaker.train.common_utils.trainer_wait import ( + _refresh_training_job, + ) + + mock_session = MagicMock() + mock_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobStatus": "Completed", + "TrainingJobName": "test-job", + "SecondaryStatus": "Completed", + } + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + mock_job.training_job_status = "InProgress" + mock_job.secondary_status = "Training" + + _refresh_training_job(mock_job, sagemaker_session=mock_session) + + # Verify attributes were updated via setattr + mock_session.sagemaker_client.describe_training_job.assert_called_once()