diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index b507ae1a93..8e5dc0955e 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -1271,6 +1271,18 @@ def _pack_and_upload_code( entrypoint_s3_uri = s3_payload.replace("sourcedir.tar.gz", "runproc.sh") + # Upload the CodeArtifact-aware install_requirements script alongside the source code + import sagemaker.core.utils.install_requirements as _ir_mod + + install_req_s3_uri = s3_payload.replace("sourcedir.tar.gz", "install_requirements.py") + evaluated_kms_key = kms_key if kms_key else self.output_kms_key + s3.S3Uploader.upload_string_as_file_body( + body=open(_ir_mod.__file__, "r").read(), + desired_s3_uri=install_req_s3_uri, + kms_key=evaluated_kms_key, + sagemaker_session=self.sagemaker_session, + ) + script = os.path.basename(code) evaluated_kms_key = kms_key if kms_key else self.output_kms_key s3_runproc_sh = self._create_and_upload_runproc( @@ -1373,7 +1385,7 @@ def _generate_framework_script(self, user_script: str) -> str: # Some py3 containers has typing, which may breaks pip install pip uninstall --yes typing - pip install -r requirements.txt + python3 /opt/ml/processing/input/code/install_requirements.py requirements.txt fi {entry_point_command} {entry_point} "$@" diff --git a/sagemaker-core/src/sagemaker/core/utils/install_requirements.py b/sagemaker-core/src/sagemaker/core/utils/install_requirements.py new file mode 100644 index 0000000000..9849ac3593 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/utils/install_requirements.py @@ -0,0 +1,199 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
# A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""CodeArtifact-aware pip requirements installer.

Reads ``CA_REPOSITORY_ARN`` from the environment and authenticates with
CodeArtifact before installing packages. Tries boto3 first (matching
``sagemaker-training-toolkit``), falls back to AWS CLI, and hard-fails
when the env var is set but neither mechanism is available.

Can be used as:
    - An importable module:

      - ``configure_pip()`` — returns an authenticated pip index URL (or ``None``).
        Use when you need to build your own pip command with custom flags.
      - ``install_requirements(path)`` — configures pip and runs ``pip install -r``.
        Use when you just want requirements installed.

      ::

          from sagemaker.core.utils.install_requirements import configure_pip, install_requirements

    - A standalone script: ``python install_requirements.py requirements.txt``
"""

from __future__ import absolute_import

import enum
import logging
import os
import re
import subprocess
import sys

logger = logging.getLogger(__name__)

CA_REPOSITORY_ARN_ENV = "CA_REPOSITORY_ARN"

_ARN_RE = re.compile(r"arn:([^:]+):codeartifact:([^:]+):([^:]+):repository/([^/]+)/(.+)")

# Placeholder substituted for the secret part of an index URL in log output.
_REDACTED = "****"


class CodeArtifactAuthMethod(enum.Enum):
    """Authentication method for CodeArtifact pip configuration."""

    BOTO3 = "boto3"
    """Use boto3 only. Fails if boto3 is not available."""

    AWS_CLI = "aws_cli"
    """Use AWS CLI only. Fails if AWS CLI is not available."""

    AUTO = "auto"
    """Try boto3 first, fall back to AWS CLI, hard-fail if neither is available."""


def _parse_arn(arn):
    """Parse a CodeArtifact repository ARN into its components.

    Args:
        arn: Repository ARN of the form
            ``arn:<partition>:codeartifact:<region>:<account>:repository/<domain>/<repo>``.

    Returns:
        Tuple of (region, account, domain, repository).

    Raises:
        ValueError: When ``arn`` does not match the expected format.
    """
    m = _ARN_RE.match(arn)
    if not m:
        raise ValueError(f"Invalid {CA_REPOSITORY_ARN_ENV}: {arn}")
    # The first group is the partition (e.g. "aws"); it is not needed by callers.
    _, region, account, domain, repo = m.groups()
    return region, account, domain, repo


def _build_index_url(endpoint, repo, token):
    """Turn a repository endpoint into an authenticated ``/simple/`` index URL.

    ``repo`` and ``token`` are always treated as literal text: the repository
    name is escaped before being used in a pattern, and the token is inserted
    with plain string formatting rather than as a regex replacement template.

    Args:
        endpoint: Repository endpoint as returned by CodeArtifact.
        repo: Repository name that the endpoint is expected to end with.
        token: CodeArtifact authorization token.

    Returns:
        The endpoint with ``simple/`` appended after a trailing ``<repo>`` segment
        and ``aws:<token>@`` injected after a leading ``https://``. Parts that do
        not match are left untouched.
    """
    simple = re.sub(
        rf"{re.escape(repo)}/?$",
        lambda _match: f"{repo}/simple/",
        endpoint,
    )
    scheme = "https://"
    if simple.startswith(scheme):
        return f"{scheme}aws:{token}@{simple[len(scheme):]}"
    return simple


def _redact_index_url(url):
    """Return ``url`` with any embedded password replaced by a placeholder.

    Used only for log output so that authorization tokens are never written
    to logs. URLs without ``user:password@`` credentials are returned unchanged.
    """
    scheme, sep, rest = url.partition("://")
    credentials, at, location = rest.rpartition("@")
    if not (sep and at):
        return url
    user = credentials.partition(":")[0]
    return f"{scheme}{sep}{user}:{_REDACTED}@{location}"


def _get_index_boto3(region, account, domain, repo):
    """Build an authenticated pip index URL using boto3.

    Raises:
        ImportError: When boto3 is not installed.
    """
    import boto3  # noqa: delay import — may not be installed

    ca = boto3.client("codeartifact", region_name=region)
    token = ca.get_authorization_token(domain=domain, domainOwner=account)["authorizationToken"]
    endpoint = ca.get_repository_endpoint(
        domain=domain, domainOwner=account, repository=repo, format="pypi"
    )["repositoryEndpoint"]
    return _build_index_url(endpoint, repo, token)


def _login_awscli(region, account, domain, repo):
    """Configure pip globally via ``aws codeartifact login``.

    Raises:
        FileNotFoundError: When the ``aws`` executable cannot be found.
        subprocess.CalledProcessError: When the command exits with a non-zero status.
    """
    subprocess.check_call(
        [
            "aws",
            "codeartifact",
            "login",
            "--tool",
            "pip",
            "--domain",
            domain,
            "--domain-owner",
            account,
            "--repository",
            repo,
            "--region",
            region,
        ]
    )


def configure_pip(auth_method=CodeArtifactAuthMethod.AUTO):
    """Configure pip for CodeArtifact if ``CA_REPOSITORY_ARN`` is set.

    Args:
        auth_method: Authentication mechanism to use, either a
            ``CodeArtifactAuthMethod`` member or its string value
            (``"boto3"``, ``"aws_cli"``, ``"auto"``). Defaults to
            ``CodeArtifactAuthMethod.AUTO`` (try boto3 first, fall back to AWS CLI).

    Returns:
        An authenticated pip index URL (str) when boto3 succeeds,
        ``None`` when AWS CLI was used (pip config modified globally),
        or ``None`` when ``CA_REPOSITORY_ARN`` is not set.

    Raises:
        SystemExit: When ``CA_REPOSITORY_ARN`` is set but the requested
            auth method is not available.
        ValueError: When the ARN format is invalid, or when ``auth_method``
            is not a recognised ``CodeArtifactAuthMethod``.

    Any other error raised while authenticating (for example a failed boto3
    API call or a non-zero exit from ``aws codeartifact login``) propagates
    unchanged; there is no fallback to the other mechanism in that case.
    """
    arn = os.environ.get(CA_REPOSITORY_ARN_ENV)
    if not arn:
        return None

    # Accept both enum members and their string values. An unknown value raises
    # ValueError here instead of falling through to a misleading
    # "neither boto3 nor AWS CLI is available" exit below.
    auth_method = CodeArtifactAuthMethod(auth_method)

    region, account, domain, repo = _parse_arn(arn)
    logger.info(
        "Configuring pip for CodeArtifact "
        "(domain=%s, domain_owner=%s, repository=%s, region=%s)",
        domain,
        account,
        repo,
        region,
    )

    if auth_method in (CodeArtifactAuthMethod.BOTO3, CodeArtifactAuthMethod.AUTO):
        try:
            return _get_index_boto3(region, account, domain, repo)
        except ImportError:
            if auth_method == CodeArtifactAuthMethod.BOTO3:
                logger.error("boto3 is not available")
                sys.exit(1)
            logger.info("boto3 not available, trying AWS CLI fallback")

    if auth_method in (CodeArtifactAuthMethod.AWS_CLI, CodeArtifactAuthMethod.AUTO):
        try:
            _login_awscli(region, account, domain, repo)
            return None
        except FileNotFoundError:
            if auth_method == CodeArtifactAuthMethod.AWS_CLI:
                logger.error("AWS CLI is not available")
                sys.exit(1)
            logger.info("AWS CLI not available")

    # Hard fail — CA is configured but we can't authenticate
    logger.error(
        "%s is set but neither boto3 nor AWS CLI is available "
        "to authenticate with CodeArtifact.",
        CA_REPOSITORY_ARN_ENV,
    )
    sys.exit(1)


def install_requirements(
    requirements_file="requirements.txt",
    python_executable=None,
    auth_method=CodeArtifactAuthMethod.AUTO,
):
    """Install pip requirements with optional CodeArtifact authentication.

    Args:
        requirements_file: Path to the requirements file.
        python_executable: Python executable to use for pip. Defaults to ``sys.executable``.
        auth_method: Authentication mechanism for CodeArtifact. Defaults to
            ``CodeArtifactAuthMethod.AUTO``.

    Raises:
        subprocess.CalledProcessError: When ``pip install`` exits with a non-zero status.
        SystemExit, ValueError: See ``configure_pip``.
    """
    python_executable = python_executable or sys.executable
    pip_cmd = [python_executable, "-m", "pip", "install", "-r", requirements_file]
    # Separate copy for logging: the real index URL embeds the authorization
    # token, which must not end up in job logs.
    loggable_cmd = list(pip_cmd)
    index = configure_pip(auth_method=auth_method)
    if index:
        pip_cmd.extend(["-i", index])
        loggable_cmd.extend(["-i", _redact_index_url(index)])
    logger.info("Running: %s", " ".join(loggable_cmd))
    subprocess.check_call(pip_cmd)


def main():
    """CLI entry point.

    Installs the requirements file named by the first command-line argument,
    or ``requirements.txt`` when no argument is given.
    """
    # Without a configured handler the INFO progress messages above are dropped
    # when this file runs as a standalone script. basicConfig is a no-op if the
    # root logger already has handlers.
    logging.basicConfig(level=logging.INFO)
    req_file = sys.argv[1] if len(sys.argv) > 1 else "requirements.txt"
    install_requirements(req_file)


if __name__ == "__main__":
    main()

# --- header of the next file in this patch, preserved from the original text ---
# diff --git a/sagemaker-core/tests/unit/test_install_requirements.py b/sagemaker-core/tests/unit/test_install_requirements.py
# new file mode 100644
# index 0000000000..ac026d6e8c
# --- /dev/null
# +++ b/sagemaker-core/tests/unit/test_install_requirements.py
# @@ -0,0 +1,238 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import subprocess
import sys
from unittest import mock

import pytest

from sagemaker.core.utils.install_requirements import (
    CA_REPOSITORY_ARN_ENV,
    CodeArtifactAuthMethod,
    _parse_arn,
    configure_pip,
    install_requirements,
    main,
)

# NOTE(review): indentation below was reconstructed from a whitespace-collapsed
# diff; confirm `with`-block scoping against the original patch.

# Dotted path used as the prefix for every mock.patch target in the module under test.
_MODULE = "sagemaker.core.utils.install_requirements"

# A well-formed repository ARN and the tuple _parse_arn is expected to return for it.
VALID_ARN = "arn:aws:codeartifact:us-west-2:123456789012:repository/my-domain/my-repo"
PARSED = ("us-west-2", "123456789012", "my-domain", "my-repo")
# Canned values returned by the mocked CodeArtifact client.
FAKE_TOKEN = "fake-auth-token"
FAKE_ENDPOINT = (
    "https://my-domain-123456789012.d.codeartifact.us-west-2.amazonaws.com/pypi/my-repo/"
)
# Index URL expected from FAKE_ENDPOINT + FAKE_TOKEN: credentials injected, "simple/" appended.
EXPECTED_INDEX = (
    f"https://aws:{FAKE_TOKEN}"
    "@my-domain-123456789012.d.codeartifact.us-west-2.amazonaws.com/pypi/my-repo/simple/"
)
# Exact argv expected for the AWS CLI login fallback when VALID_ARN is configured.
EXPECTED_CLI_CMD = [
    "aws", "codeartifact", "login", "--tool", "pip",
    "--domain", "my-domain", "--domain-owner", "123456789012",
    "--repository", "my-repo", "--region", "us-west-2",
]  # fmt: skip


@pytest.fixture()
def ca_env():
    """Set CA_REPOSITORY_ARN in the environment for the duration of a test."""
    with mock.patch.dict("os.environ", {CA_REPOSITORY_ARN_ENV: VALID_ARN}):
        yield


@pytest.fixture()
def mock_boto3_ca():
    """Mock boto3 CodeArtifact client with valid responses.

    Yields a ``(factory, client)`` pair: the patched ``boto3.client`` callable
    and the MagicMock it returns.
    """
    client = mock.MagicMock()
    client.get_authorization_token.return_value = {"authorizationToken": FAKE_TOKEN}
    client.get_repository_endpoint.return_value = {"repositoryEndpoint": FAKE_ENDPOINT}
    with mock.patch("boto3.client", return_value=client) as factory:
        yield factory, client


def _pip_cmd(*extra):
    """Return the pip argv expected for ``reqs.txt``, with ``extra`` arguments appended."""
    return [sys.executable, "-m", "pip", "install", "-r", "reqs.txt", *extra]


# ---------------------------------------------------------------------------
# _parse_arn
# ---------------------------------------------------------------------------
class TestParseArn:
    """Tests for ARN parsing."""

    def test_valid_arn(self):
        """A well-formed ARN parses into (region, account, domain, repo)."""
        assert _parse_arn(VALID_ARN) == PARSED

    def test_invalid_arn(self):
        """A string that is not an ARN raises ValueError naming the env var."""
        with pytest.raises(ValueError, match="Invalid CA_REPOSITORY_ARN"):
            _parse_arn("not-an-arn")

    def test_arn_with_nested_repo(self):
        """Everything after the domain segment, including slashes, is the repo name."""
        region, account, domain, repo = _parse_arn(
            "arn:aws:codeartifact:eu-west-1:111111111111:repository/dom/nested/repo"
        )
        assert (region, account, domain, repo) == (
            "eu-west-1",
            "111111111111",
            "dom",
            "nested/repo",
        )


# ---------------------------------------------------------------------------
# configure_pip — AUTO (default)
# ---------------------------------------------------------------------------
class TestConfigurePipAuto:
    """Tests for configure_pip() called with its default auth method."""

    def test_no_env_var_returns_none(self):
        """With an empty environment configure_pip returns None."""
        with mock.patch.dict("os.environ", {}, clear=True):
            assert configure_pip() is None

    def test_invalid_arn_raises(self, ca_env):
        """A malformed CA_REPOSITORY_ARN value raises ValueError."""
        # Overrides the valid ARN installed by the ca_env fixture.
        with mock.patch.dict("os.environ", {CA_REPOSITORY_ARN_ENV: "garbage"}):
            with pytest.raises(ValueError, match="Invalid"):
                configure_pip()

    def test_boto3_success(self, ca_env, mock_boto3_ca):
        """The boto3 path calls the client with parsed ARN parts and returns the index URL."""
        factory, client = mock_boto3_ca
        result = configure_pip()

        factory.assert_called_once_with("codeartifact", region_name="us-west-2")
        client.get_authorization_token.assert_called_once_with(
            domain="my-domain", domainOwner="123456789012"
        )
        client.get_repository_endpoint.assert_called_once_with(
            domain="my-domain", domainOwner="123456789012", repository="my-repo", format="pypi"
        )
        assert result == EXPECTED_INDEX

    def test_falls_back_to_cli_when_no_boto3(self, ca_env):
        """ImportError from the boto3 path triggers the AWS CLI login and returns None."""
        with mock.patch(f"{_MODULE}._get_index_boto3", side_effect=ImportError):
            with mock.patch("subprocess.check_call") as mock_call:
                result = configure_pip()

        mock_call.assert_called_once_with(EXPECTED_CLI_CMD)
        assert result is None

    def test_hard_fails_when_nothing_available(self, ca_env):
        """SystemExit is raised when boto3 is missing and the AWS CLI is not found."""
        with mock.patch(f"{_MODULE}._get_index_boto3", side_effect=ImportError):
            with mock.patch(f"{_MODULE}._login_awscli", side_effect=FileNotFoundError):
                with pytest.raises(SystemExit):
                    configure_pip()

    def test_boto3_api_error_propagates_not_fallback(self, ca_env):
        """boto3 available but API call fails → raise, don't fall back to CLI."""
        with mock.patch(f"{_MODULE}._get_index_boto3", side_effect=Exception("AccessDenied")):
            with mock.patch(f"{_MODULE}._login_awscli") as mock_cli:
                with pytest.raises(Exception, match="AccessDenied"):
                    configure_pip()
                mock_cli.assert_not_called()


# ---------------------------------------------------------------------------
# configure_pip — BOTO3 only
# ---------------------------------------------------------------------------
class TestConfigurePipBoto3Only:
    """Tests for configure_pip(auth_method=CodeArtifactAuthMethod.BOTO3)."""

    def test_fails_when_unavailable(self, ca_env):
        """ImportError from the boto3 path raises SystemExit."""
        with mock.patch(f"{_MODULE}._get_index_boto3", side_effect=ImportError):
            with pytest.raises(SystemExit):
                configure_pip(auth_method=CodeArtifactAuthMethod.BOTO3)

    def test_does_not_try_cli(self, ca_env, mock_boto3_ca):
        """The AWS CLI login helper is never called in BOTO3 mode."""
        with mock.patch(f"{_MODULE}._login_awscli") as mock_cli:
            configure_pip(auth_method=CodeArtifactAuthMethod.BOTO3)
            mock_cli.assert_not_called()


# ---------------------------------------------------------------------------
# configure_pip — AWS_CLI only
# ---------------------------------------------------------------------------
class TestConfigurePipCliOnly:
    """Tests for configure_pip(auth_method=CodeArtifactAuthMethod.AWS_CLI)."""

    def test_succeeds(self, ca_env):
        """A successful CLI login returns None and invokes subprocess.check_call once."""
        with mock.patch("subprocess.check_call") as mock_call:
            result = configure_pip(auth_method=CodeArtifactAuthMethod.AWS_CLI)
            assert result is None
            mock_call.assert_called_once()

    def test_fails_when_unavailable(self, ca_env):
        """FileNotFoundError from the CLI login helper raises SystemExit."""
        with mock.patch(f"{_MODULE}._login_awscli", side_effect=FileNotFoundError):
            with pytest.raises(SystemExit):
                configure_pip(auth_method=CodeArtifactAuthMethod.AWS_CLI)

    def test_does_not_try_boto3(self, ca_env):
        """The boto3 helper is never called in AWS_CLI mode."""
        with mock.patch(f"{_MODULE}._get_index_boto3") as mock_boto3:
            with mock.patch("subprocess.check_call"):
                configure_pip(auth_method=CodeArtifactAuthMethod.AWS_CLI)
                mock_boto3.assert_not_called()


# ---------------------------------------------------------------------------
# install_requirements
# ---------------------------------------------------------------------------
class TestInstallRequirements:
    """Tests for the pip command built and executed by install_requirements()."""

    def test_without_codeartifact(self):
        """With an empty environment pip runs without an index flag."""
        with mock.patch.dict("os.environ", {}, clear=True):
            with mock.patch("subprocess.check_call") as mock_call:
                install_requirements("reqs.txt")
                mock_call.assert_called_once_with(_pip_cmd())

    def test_with_codeartifact_index(self):
        """An index URL returned by configure_pip is passed to pip via ``-i``."""
        with mock.patch(f"{_MODULE}.configure_pip", return_value=EXPECTED_INDEX):
            with mock.patch("subprocess.check_call") as mock_call:
                install_requirements("reqs.txt")
                mock_call.assert_called_once_with(_pip_cmd("-i", EXPECTED_INDEX))

    def test_with_cli_fallback_no_index_flag(self):
        """When configure_pip returns None no ``-i`` flag is added."""
        with mock.patch(f"{_MODULE}.configure_pip", return_value=None):
            with mock.patch("subprocess.check_call") as mock_call:
                install_requirements("reqs.txt")
                mock_call.assert_called_once_with(_pip_cmd())

    def test_custom_python_executable(self):
        """The python_executable argument replaces sys.executable in the command."""
        with mock.patch.dict("os.environ", {}, clear=True):
            with mock.patch("subprocess.check_call") as mock_call:
                install_requirements("reqs.txt", python_executable="/usr/bin/python3")
                mock_call.assert_called_once_with(
                    ["/usr/bin/python3", "-m", "pip", "install", "-r", "reqs.txt"]
                )

    def test_pip_failure_propagates(self):
        """CalledProcessError from the pip subprocess is not swallowed."""
        with mock.patch.dict("os.environ", {}, clear=True):
            with mock.patch(
                "subprocess.check_call", side_effect=subprocess.CalledProcessError(1, "pip")
            ):
                with pytest.raises(subprocess.CalledProcessError):
                    install_requirements("reqs.txt")

    def test_auth_method_passed_through(self):
        """auth_method is forwarded to configure_pip as a keyword argument."""
        with mock.patch(f"{_MODULE}.configure_pip", return_value=None) as mock_configure:
            with mock.patch("subprocess.check_call"):
                install_requirements("reqs.txt", auth_method=CodeArtifactAuthMethod.BOTO3)
                mock_configure.assert_called_once_with(auth_method=CodeArtifactAuthMethod.BOTO3)


#
--------------------------------------------------------------------------- +# main (CLI entry point) +# --------------------------------------------------------------------------- +class TestMain: + def test_default_requirements_file(self): + with mock.patch(f"{_MODULE}.install_requirements") as mock_install: + with mock.patch("sys.argv", ["install_requirements.py"]): + main() + mock_install.assert_called_once_with("requirements.txt") + + def test_custom_requirements_file(self): + with mock.patch(f"{_MODULE}.install_requirements") as mock_install: + with mock.patch("sys.argv", ["install_requirements.py", "custom.txt"]): + main() + mock_install.assert_called_once_with("custom.txt") diff --git a/sagemaker-core/tests/unit/test_processing.py b/sagemaker-core/tests/unit/test_processing.py index dbe8d5f9ef..59d3331022 100644 --- a/sagemaker-core/tests/unit/test_processing.py +++ b/sagemaker-core/tests/unit/test_processing.py @@ -823,6 +823,8 @@ def test_generate_framework_script(self, mock_session): assert "#!/bin/bash" in script assert "train.py" in script assert "python3" in script + assert "install_requirements.py" in script + assert "pip install -r requirements.txt" not in script def test_create_and_upload_runproc_with_pipeline(self, mock_session): processor = FrameworkProcessor( @@ -1240,7 +1242,7 @@ def test_pack_and_upload_code_with_local_file(self, mock_session): with patch( "sagemaker.core.s3.S3Uploader.upload_string_as_file_body", return_value="s3://bucket/runproc.sh", - ): + ) as mock_upload: result_uri, result_inputs, result_job_name = processor._pack_and_upload_code( code=entry_point, source_dir=None, @@ -1253,6 +1255,15 @@ def test_pack_and_upload_code_with_local_file(self, mock_session): assert result_uri == "s3://bucket/runproc.sh" assert len(result_inputs) == 1 + # Verify both install_requirements.py and runproc.sh were uploaded + upload_uris = [ + call.kwargs.get("desired_s3_uri") or call.args[1] + for call in mock_upload.call_args_list + ] + assert 
any("install_requirements.py" in uri for uri in upload_uris) + assert any("runproc.sh" in uri for uri in upload_uris) + assert mock_upload.call_count == 2 + class TestProcessingInputOutputHelpers: def test_processing_input_with_app_managed(self): diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..48c42c9093 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -637,6 +637,14 @@ def _create_training_job_args( # Copy everything under container_drivers/ to a temporary directory shutil.copytree(SM_DRIVERS_LOCAL_PATH, self._temp_code_dir.name, dirs_exist_ok=True) + # Copy the CodeArtifact-aware install_requirements script from sagemaker-core + # so it's available in the container at /opt/ml/input/data/sm_drivers/scripts/ + import sagemaker.core.utils.install_requirements as _ir_mod + shutil.copy2( + _ir_mod.__file__, + os.path.join(self._temp_code_dir.name, "scripts", "install_requirements.py"), + ) + # If distributed is provided, overwrite code under /drivers if self.distributed: distributed_driver_dir = self.distributed.driver_dir diff --git a/sagemaker-train/src/sagemaker/train/templates.py b/sagemaker-train/src/sagemaker/train/templates.py index c943769618..836471952c 100644 --- a/sagemaker-train/src/sagemaker/train/templates.py +++ b/sagemaker-train/src/sagemaker/train/templates.py @@ -28,7 +28,7 @@ if [ -f requirements.txt ]; then echo "Installing requirements" cat requirements.txt - $SM_PIP_CMD install -r requirements.txt + $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/install_requirements.py requirements.txt else echo "No requirements.txt file found. Skipping installation." 
fi @@ -36,7 +36,7 @@ INSTALL_REQUIREMENTS = """ echo "Installing requirements" -$SM_PIP_CMD install -r {requirements_file} +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/install_requirements.py {requirements_file} """ EXEUCTE_DISTRIBUTED_DRIVER = """