Skip to content

Commit 275ee1c

Browse files
committed
feat: add CodeArtifact support for ModelTrainer and FrameworkProcessor requirements.txt installation
SDK v3's `ModelTrainer` and `FrameworkProcessor` override the container entrypoint with SDK-generated scripts (`sm_train.sh`, `runproc.sh`). This bypasses the container's own entrypoint, in which `sagemaker-training-toolkit` handled `CA_REPOSITORY_ARN`-based CodeArtifact authentication, and so broke CodeArtifact support for both training (Bug 4) and processing (Bug 3), as reported in #5765. This commit implements the stopgap solution proposed in a comment on that issue [#5765 (comment)]: a self-contained `install_requirements.py` script that the SDK uploads to the container alongside its generated entrypoint scripts.
- Add `install_requirements.py` in sagemaker-core — it reads `CA_REPOSITORY_ARN` from the container environment and is a no-op if the variable is unset.
- Try `boto3` first (matching sagemaker-training-toolkit), fall back to the AWS CLI, and hard-fail if neither is available.
- Wire into `ModelTrainer`: copy the script into `sm_drivers/scripts/` and update the `INSTALL_REQUIREMENTS` templates to call it instead of a bare `pip install`.
- Wire into `FrameworkProcessor`: upload the script as a sibling file alongside `runproc.sh` and update the generated script to call it.
1 parent 98683ac commit 275ee1c

6 files changed

Lines changed: 472 additions & 4 deletions

File tree

sagemaker-core/src/sagemaker/core/processing.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,6 +1271,18 @@ def _pack_and_upload_code(
12711271

12721272
entrypoint_s3_uri = s3_payload.replace("sourcedir.tar.gz", "runproc.sh")
12731273

1274+
# Upload the CodeArtifact-aware install_requirements script alongside the source code
1275+
import sagemaker.core.utils.install_requirements as _ir_mod
1276+
1277+
install_req_s3_uri = s3_payload.replace("sourcedir.tar.gz", "install_requirements.py")
1278+
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
1279+
s3.S3Uploader.upload_string_as_file_body(
1280+
body=open(_ir_mod.__file__, "r").read(),
1281+
desired_s3_uri=install_req_s3_uri,
1282+
kms_key=evaluated_kms_key,
1283+
sagemaker_session=self.sagemaker_session,
1284+
)
1285+
12741286
script = os.path.basename(code)
12751287
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
12761288
s3_runproc_sh = self._create_and_upload_runproc(
@@ -1373,7 +1385,7 @@ def _generate_framework_script(self, user_script: str) -> str:
13731385
# Some py3 containers has typing, which may breaks pip install
13741386
pip uninstall --yes typing
13751387
1376-
pip install -r requirements.txt
1388+
python3 /opt/ml/processing/input/code/install_requirements.py requirements.txt
13771389
fi
13781390
13791391
{entry_point_command} {entry_point} "$@"
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""CodeArtifact-aware pip requirements installer.
14+
15+
Reads ``CA_REPOSITORY_ARN`` from the environment and authenticates with
16+
CodeArtifact before installing packages. Tries boto3 first (matching
17+
``sagemaker-training-toolkit``), falls back to AWS CLI, and hard-fails
18+
when the env var is set but neither mechanism is available.
19+
20+
Can be used as:
21+
- An importable module:
22+
23+
- ``configure_pip()`` — returns an authenticated pip index URL (or ``None``).
24+
Use when you need to build your own pip command with custom flags.
25+
- ``install_requirements(path)`` — configures pip and runs ``pip install -r``.
26+
Use when you just want requirements installed.
27+
28+
::
29+
30+
from sagemaker.core.utils.install_requirements import configure_pip, install_requirements
31+
32+
- A standalone script: ``python install_requirements.py requirements.txt``
33+
"""
34+
35+
from __future__ import absolute_import
36+
37+
import enum
38+
import logging
39+
import os
40+
import re
41+
import subprocess
42+
import sys
43+
44+
logger = logging.getLogger(__name__)

# Name of the environment variable that carries the CodeArtifact repository ARN.
CA_REPOSITORY_ARN_ENV = "CA_REPOSITORY_ARN"

# Matches a CodeArtifact repository ARN. The five capture groups are, in order:
# the field between "arn:" and ":codeartifact", region, account, domain, repository.
_ARN_RE = re.compile(
    r"arn:([^:]+)"
    r":codeartifact"
    r":([^:]+)"
    r":([^:]+)"
    r":repository/([^/]+)/(.+)"
)
49+
50+
51+
class CodeArtifactAuthMethod(enum.Enum):
    """Selects how pip gets authenticated against CodeArtifact.

    Members:
        BOTO3: Use boto3 only. Fails if boto3 is not available.
        AWS_CLI: Use AWS CLI only. Fails if AWS CLI is not available.
        AUTO: Try boto3 first, fall back to AWS CLI, hard-fail if neither
            is available.
    """

    # boto3 only — no fallback when boto3 cannot be imported.
    BOTO3 = "boto3"

    # AWS CLI only — no fallback when the ``aws`` executable is missing.
    AWS_CLI = "aws_cli"

    # boto3 first, then AWS CLI; hard-fail when neither can be used.
    AUTO = "auto"
62+
63+
64+
def _parse_arn(arn):
    """Split a CodeArtifact repository ARN into the parts needed to authenticate.

    Args:
        arn: The repository ARN string, as read from ``CA_REPOSITORY_ARN``.

    Returns:
        Tuple of (region, account, domain, repository).

    Raises:
        ValueError: When ``arn`` does not match the expected repository ARN layout.
    """
    matched = _ARN_RE.match(arn)
    if matched is None:
        raise ValueError(f"Invalid {CA_REPOSITORY_ARN_ENV}: {arn}")
    # The first captured field (between "arn:" and ":codeartifact") is not
    # needed by any caller, so only groups 2-5 are handed back.
    return matched.group(2, 3, 4, 5)
75+
76+
77+
def _get_index_boto3(region, account, domain, repo):
    """Build an authenticated pip index URL using boto3.

    Fetches a CodeArtifact authorization token and the repository's PyPI
    endpoint, then combines them into a URL of the form
    ``https://aws:<token>@<endpoint-host>/.../<repo>/simple/``.

    Args:
        region: AWS region the CodeArtifact client is created for.
        account: Account id passed as ``domainOwner``.
        domain: CodeArtifact domain name.
        repo: CodeArtifact repository name.

    Returns:
        The index URL (str) with the credentials embedded.

    Raises:
        ImportError: When boto3 is not installed (callers use this to fall back).
    """
    import boto3  # noqa: delay import — may not be installed

    ca = boto3.client("codeartifact", region_name=region)
    token = ca.get_authorization_token(domain=domain, domainOwner=account)["authorizationToken"]
    endpoint = ca.get_repository_endpoint(
        domain=domain, domainOwner=account, repository=repo, format="pypi"
    )["repositoryEndpoint"]

    # Plain string operations instead of re.sub: the repository name and the
    # token are data, not regex syntax. Interpolating them into a pattern or a
    # replacement template misbehaves on characters such as "." or "\".
    trimmed = endpoint[:-1] if endpoint.endswith("/") else endpoint
    if trimmed.endswith(repo):
        # pip expects the "simple" index path below the repository endpoint.
        endpoint = f"{trimmed}/simple/"

    scheme = "https://"
    if endpoint.startswith(scheme):
        # Inject the credentials once, directly after the URL scheme.
        endpoint = f"{scheme}aws:{token}@{endpoint[len(scheme):]}"
    return endpoint
91+
92+
93+
def _login_awscli(region, account, domain, repo):
    """Configure pip globally via ``aws codeartifact login``.

    Args:
        region: Value passed to ``--region``.
        account: Value passed to ``--domain-owner``.
        domain: Value passed to ``--domain``.
        repo: Value passed to ``--repository``.

    Raises:
        FileNotFoundError: When the ``aws`` executable cannot be found.
        subprocess.CalledProcessError: When the login command exits non-zero.
    """
    login_cmd = ["aws", "codeartifact", "login", "--tool", "pip"]
    login_cmd += ["--domain", domain]
    login_cmd += ["--domain-owner", account]
    login_cmd += ["--repository", repo]
    login_cmd += ["--region", region]
    subprocess.check_call(login_cmd)
112+
113+
114+
def configure_pip(auth_method=CodeArtifactAuthMethod.AUTO):
    """Configure pip for CodeArtifact if ``CA_REPOSITORY_ARN`` is set.

    Args:
        auth_method: Authentication mechanism to use. Either a
            ``CodeArtifactAuthMethod`` member or its string value
            (``"boto3"``, ``"aws_cli"``, ``"auto"``). Defaults to
            ``CodeArtifactAuthMethod.AUTO`` (try boto3 first, fall back to AWS CLI).

    Returns:
        An authenticated pip index URL (str) when boto3 succeeds,
        ``None`` when AWS CLI was used (pip config modified globally),
        or ``None`` when ``CA_REPOSITORY_ARN`` is not set.

    Raises:
        SystemExit: When ``CA_REPOSITORY_ARN`` is set but the requested
            auth method is not available.
        ValueError: When the ARN format is invalid, or when ``auth_method``
            is not a recognised ``CodeArtifactAuthMethod``.
    """
    arn = os.environ.get(CA_REPOSITORY_ARN_ENV)
    if not arn:
        # Nothing to configure; auth_method is deliberately not validated here
        # so the unset case stays a pure no-op.
        return None

    # Accept both enum members and their string values. An unknown value now
    # raises a clear ValueError instead of matching no branch below and ending
    # in the misleading "neither boto3 nor AWS CLI is available" exit.
    auth_method = CodeArtifactAuthMethod(auth_method)

    region, account, domain, repo = _parse_arn(arn)
    logger.info(
        "Configuring pip for CodeArtifact "
        "(domain=%s, domain_owner=%s, repository=%s, region=%s)",
        domain,
        account,
        repo,
        region,
    )

    if auth_method in (CodeArtifactAuthMethod.BOTO3, CodeArtifactAuthMethod.AUTO):
        try:
            return _get_index_boto3(region, account, domain, repo)
        except ImportError:
            # Only a missing boto3 triggers the fallback; API errors propagate.
            if auth_method == CodeArtifactAuthMethod.BOTO3:
                logger.error("boto3 is not available")
                sys.exit(1)
            logger.info("boto3 not available, trying AWS CLI fallback")

    if auth_method in (CodeArtifactAuthMethod.AWS_CLI, CodeArtifactAuthMethod.AUTO):
        try:
            _login_awscli(region, account, domain, repo)
            return None
        except FileNotFoundError:
            # Raised when the ``aws`` executable itself cannot be found.
            if auth_method == CodeArtifactAuthMethod.AWS_CLI:
                logger.error("AWS CLI is not available")
                sys.exit(1)
            logger.info("AWS CLI not available")

    # Hard fail — CA is configured but we can't authenticate
    logger.error(
        "%s is set but neither boto3 nor AWS CLI is available "
        "to authenticate with CodeArtifact.",
        CA_REPOSITORY_ARN_ENV,
    )
    sys.exit(1)
171+
172+
173+
def install_requirements(
    requirements_file="requirements.txt", python_executable=None, auth_method=CodeArtifactAuthMethod.AUTO
):
    """Install pip requirements with optional CodeArtifact authentication.

    Runs ``<python> -m pip install -r <requirements_file>``. When
    ``configure_pip`` returns an index URL, it is passed to pip with ``-i``.

    Args:
        requirements_file: Path to the requirements file.
        python_executable: Python executable to use for pip. Defaults to ``sys.executable``.
        auth_method: Authentication mechanism for CodeArtifact. Defaults to ``CodeArtifactAuthMethod.AUTO``.

    Raises:
        subprocess.CalledProcessError: When the pip command exits non-zero.
    """
    python_executable = python_executable or sys.executable
    pip_cmd = [python_executable, "-m", "pip", "install", "-r", requirements_file]
    logged_cmd = list(pip_cmd)

    index = configure_pip(auth_method=auth_method)
    if index:
        pip_cmd.extend(["-i", index])
        # The index URL embeds the CodeArtifact authorization token
        # (https://aws:<token>@...). Mask the password part in the log line so
        # the credential never reaches job logs; the executed command is unchanged.
        logged_cmd.extend(["-i", re.sub(r"://([^:/@]+):[^@/]+@", r"://\1:****@", index)])

    logger.info("Running: %s", " ".join(logged_cmd))
    subprocess.check_call(pip_cmd)
190+
191+
192+
def main():
    """CLI entry point.

    Installs the requirements file named by the first command-line argument,
    or ``requirements.txt`` when no argument is given.
    """
    cli_args = sys.argv[1:]
    target = cli_args[0] if cli_args else "requirements.txt"
    install_requirements(target)


# Run the installer only when executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)