Skip to content

Commit 37b43b7

Browse files
shrijeetjjamesfyu
andauthored
feat: Add SageMaker token generator to sagemaker-core (#5868)
* feat: Add SageMaker token generator to sagemaker-core Embed the aws-sagemaker-token-generator library into sagemaker.core so users can generate SageMaker bearer tokens without installing a separate wheel. Usage: from sagemaker.core.aws_sagemaker_token_generator import provide_token token = provide_token(region='us-east-1') * test: Add unit tests for SageMaker token generator Add unit tests for the token_generator module covering token format, base64 encoding, version info, multiple regions, session tokens, credential validation, expiry handling, and API consistency. Tests are modeled after the Bedrock token generator test suite: https://github.com/aws/aws-bedrock-token-generator-python/blob/main/tests/test_token_generator.py --------- Co-authored-by: James Yu <jamesfyu@amazon.com>
1 parent 6308e6d commit 37b43b7

3 files changed

Lines changed: 311 additions & 0 deletions

File tree

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""AWS SageMaker Token Generator.
2+
3+
A lightweight module for generating short-term bearer tokens for AWS SageMaker
4+
API authentication. Provides the ``generate_token`` helper and the lower-level
5+
``SageMakerTokenGenerator`` class.
6+
7+
Example::
8+
9+
>>> from sagemaker.core.token_generator import generate_token
10+
>>> token = generate_token(region="us-east-1")
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import os
16+
from datetime import timedelta
17+
18+
from botocore.credentials import CredentialProvider
19+
from botocore.session import Session
20+
21+
from sagemaker.core.token_generator.token_generator import (
22+
TOKEN_DURATION,
23+
SageMakerTokenGenerator,
24+
_generate_token,
25+
)
26+
27+
__all__ = ["SageMakerTokenGenerator", "generate_token"]
28+
29+
30+
def generate_token(
31+
region: str | None = None,
32+
aws_credentials_provider: CredentialProvider | None = None,
33+
expiry: timedelta = timedelta(hours=12),
34+
) -> str:
35+
"""Generate a short-lived AWS SageMaker bearer token.
36+
37+
Args:
38+
region (str): AWS region. Falls back to the ``AWS_REGION``
39+
environment variable when not provided.
40+
aws_credentials_provider (CredentialProvider): Optional credential
41+
provider. Uses the default AWS credential chain when omitted.
42+
expiry (timedelta): Token lifetime. Must be between 1 second and
43+
12 hours inclusive. Defaults to 12 hours.
44+
45+
Returns:
46+
str: A bearer token string.
47+
48+
Raises:
49+
ValueError: If *region* is missing or *expiry* is out of range.
50+
RuntimeError: If no valid AWS credentials are found.
51+
"""
52+
region = region or os.environ.get("AWS_REGION")
53+
if not region:
54+
raise ValueError("Region must be provided or set via the AWS_REGION environment variable.")
55+
56+
if expiry.total_seconds() <= 0 or expiry.total_seconds() > TOKEN_DURATION:
57+
raise ValueError(
58+
"Token expiry must be greater than zero and less than or equal to 12 hours"
59+
)
60+
61+
credentials = (
62+
aws_credentials_provider.load() if aws_credentials_provider else Session().get_credentials()
63+
)
64+
65+
if credentials is None:
66+
raise RuntimeError(
67+
"No AWS credentials found. Check your environment or credential provider."
68+
)
69+
70+
return _generate_token(credentials, region, int(expiry.total_seconds()))
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""SageMaker Token Generator core signing logic.
2+
3+
Generates short-term bearer tokens for AWS SageMaker API authentication
4+
using SigV4 signed pre-signed URLs.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import base64
10+
11+
from botocore.auth import SigV4QueryAuth
12+
from botocore.awsrequest import AWSRequest
13+
from botocore.credentials import Credentials
14+
15+
DEFAULT_HOST: str = "sagemaker.amazonaws.com"
16+
DEFAULT_URL: str = f"https://{DEFAULT_HOST}/"
17+
SERVICE_NAME: str = "sagemaker"
18+
AUTH_PREFIX: str = "sagemaker-api-key-"
19+
TOKEN_VERSION: str = "&Version=1"
20+
TOKEN_DURATION: int = 43200 # 12 hours in seconds
21+
22+
23+
def _generate_token(credentials: Credentials, region: str, expires: int) -> str:
24+
"""Build a presigned bearer token.
25+
26+
Args:
27+
credentials (Credentials): AWS credentials.
28+
region (str): AWS region.
29+
expires (int): Expiry time in seconds.
30+
31+
Returns:
32+
str: A base64-encoded bearer token string.
33+
"""
34+
request = AWSRequest(
35+
method="POST",
36+
url=DEFAULT_URL,
37+
headers={"host": DEFAULT_HOST},
38+
params={"Action": "CallWithBearerToken"},
39+
)
40+
41+
auth = SigV4QueryAuth(credentials, SERVICE_NAME, region, expires=expires)
42+
auth.add_auth(request)
43+
44+
presigned_url = request.url.replace("https://", "") + TOKEN_VERSION
45+
encoded_token = base64.b64encode(presigned_url.encode("utf-8")).decode("utf-8")
46+
47+
return f"{AUTH_PREFIX}{encoded_token}"
48+
49+
50+
class SageMakerTokenGenerator:
51+
"""Generate short-lived AWS SageMaker bearer tokens."""
52+
53+
def get_token(self, credentials: Credentials, region: str) -> str:
54+
"""Generate a token using provided credentials and region.
55+
56+
Args:
57+
credentials (Credentials): AWS credentials to sign the request.
58+
region (str): AWS region.
59+
60+
Returns:
61+
str: A bearer token string.
62+
63+
Raises:
64+
ValueError: If inputs are invalid.
65+
"""
66+
if not credentials:
67+
raise ValueError("Credentials cannot be None")
68+
if not region or not isinstance(region, str):
69+
raise ValueError("Region must be a non-empty string")
70+
71+
return _generate_token(credentials, region, TOKEN_DURATION)
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for sagemaker.core.token_generator module."""
14+
15+
from __future__ import absolute_import
16+
17+
import base64
18+
from datetime import timedelta
19+
20+
import pytest
21+
from unittest.mock import Mock
22+
23+
from botocore.credentials import Credentials
24+
25+
from sagemaker.core.token_generator import generate_token, SageMakerTokenGenerator
26+
from sagemaker.core.token_generator.token_generator import AUTH_PREFIX
27+
28+
29+
class TestSageMakerTokenGenerator:
30+
"""Tests for the SageMakerTokenGenerator class."""
31+
32+
@pytest.fixture(autouse=True)
33+
def setup(self):
34+
"""Setup test credentials and token generator instance."""
35+
self.token_generator = SageMakerTokenGenerator()
36+
self.credentials = Credentials(
37+
access_key="AKIAIOSFODNN7EXAMPLE",
38+
secret_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
39+
)
40+
self.region = "us-west-2"
41+
42+
def test_get_token_returns_non_null_token(self):
43+
"""Test that get_token returns a non-null token."""
44+
token = self.token_generator.get_token(self.credentials, self.region)
45+
46+
assert token is not None
47+
assert len(token) > 0
48+
49+
def test_get_token_starts_with_correct_prefix(self):
50+
"""Test that the token starts with the correct prefix."""
51+
token = self.token_generator.get_token(self.credentials, self.region)
52+
53+
assert token.startswith(AUTH_PREFIX)
54+
55+
def test_get_token_with_different_regions(self):
56+
"""Test token generation with different regions."""
57+
regions = ["us-east-1", "us-west-2", "eu-west-1", "ap-northeast-1"]
58+
59+
for region in regions:
60+
token = self.token_generator.get_token(self.credentials, region)
61+
62+
assert token is not None, f"Token should not be null for region: {region}"
63+
assert token.startswith(
64+
AUTH_PREFIX
65+
), f"Token should start with the correct prefix for region: {region}"
66+
67+
def test_get_token_is_base64_encoded(self):
68+
"""Test that the token is properly Base64 encoded."""
69+
token = self.token_generator.get_token(self.credentials, self.region)
70+
71+
token_without_prefix = token[len(AUTH_PREFIX) :]
72+
decoded = base64.b64decode(token_without_prefix)
73+
assert decoded is not None
74+
75+
def test_get_token_contains_version_info(self):
76+
"""Test that the decoded token contains version information."""
77+
token = self.token_generator.get_token(self.credentials, self.region)
78+
79+
token_without_prefix = token[len(AUTH_PREFIX) :]
80+
decoded_string = base64.b64decode(token_without_prefix).decode("utf-8")
81+
assert "&Version=1" in decoded_string
82+
83+
def test_get_token_different_credentials_produce_different_tokens(self):
84+
"""Test that different credentials produce different tokens."""
85+
credentials1 = Credentials(
86+
access_key="AKIAIOSFODNN7EXAMPLE",
87+
secret_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
88+
)
89+
credentials2 = Credentials(
90+
access_key="AKIAI44QH8DHBEXAMPLE",
91+
secret_key="je7MtGbClwBF/2Zp9Utk/h3yCo8nvbEXAMPLEKEY",
92+
)
93+
94+
token1 = self.token_generator.get_token(credentials1, self.region)
95+
token2 = self.token_generator.get_token(credentials2, self.region)
96+
97+
assert token1 != token2
98+
99+
def test_get_token_with_session_token(self):
100+
"""Test token generation with session token (temporary credentials)."""
101+
credentials_with_token = Credentials(
102+
access_key="AKIAIOSFODNN7EXAMPLE",
103+
secret_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
104+
token="AQoDYXdzEJr...<remainder of security token>",
105+
)
106+
107+
token = self.token_generator.get_token(credentials_with_token, self.region)
108+
109+
assert token is not None
110+
assert token.startswith(AUTH_PREFIX)
111+
112+
def test_get_token_no_credentials_raises_error(self):
113+
"""Test that get_token raises ValueError when credentials are None."""
114+
with pytest.raises(ValueError, match="Credentials cannot be None"):
115+
self.token_generator.get_token(None, self.region)
116+
117+
def test_get_token_no_region_raises_error(self):
118+
"""Test that get_token raises ValueError when region is None or empty."""
119+
with pytest.raises(ValueError, match="Region must be a non-empty string"):
120+
self.token_generator.get_token(self.credentials, None)
121+
122+
with pytest.raises(ValueError, match="Region must be a non-empty string"):
123+
self.token_generator.get_token(self.credentials, "")
124+
125+
def test_get_token_contains_correct_expiry(self):
126+
"""Test that the decoded token has the correct expiry duration (12 hours)."""
127+
token = self.token_generator.get_token(self.credentials, self.region)
128+
129+
token_without_prefix = token[len(AUTH_PREFIX) :]
130+
decoded_string = base64.b64decode(token_without_prefix).decode("utf-8")
131+
assert "X-Amz-Expires=43200" in decoded_string
132+
133+
def test_get_token_vs_generate_token_consistency(self):
134+
"""Test that get_token and generate_token produce identical tokens for same inputs."""
135+
mock_provider = Mock()
136+
mock_provider.load.return_value = self.credentials
137+
138+
token1 = self.token_generator.get_token(self.credentials, self.region)
139+
140+
token2 = generate_token(
141+
region=self.region,
142+
aws_credentials_provider=mock_provider,
143+
expiry=timedelta(hours=12),
144+
)
145+
146+
assert token1 == token2
147+
assert token1.startswith(AUTH_PREFIX)
148+
assert token2.startswith(AUTH_PREFIX)
149+
assert len(token1) == len(token2)
150+
151+
def test_generate_token_with_custom_expiry_produces_different_token(self):
152+
"""Test that different expiry durations produce different tokens."""
153+
mock_provider = Mock()
154+
mock_provider.load.return_value = self.credentials
155+
156+
token_default = generate_token(
157+
region=self.region,
158+
aws_credentials_provider=mock_provider,
159+
expiry=timedelta(hours=12),
160+
)
161+
162+
token_custom = generate_token(
163+
region=self.region,
164+
aws_credentials_provider=mock_provider,
165+
expiry=timedelta(hours=6),
166+
)
167+
168+
assert token_default != token_custom
169+
assert token_default.startswith(AUTH_PREFIX)
170+
assert token_custom.startswith(AUTH_PREFIX)

0 commit comments

Comments
 (0)