Skip to content

Commit a40c856

Browse files
ryantanakaRyan Tanakamollyheamazon
authored
feature: add telemetry attribution module for SDK usage provenance (#5661)
* feature: add telemetry attribution module for SDK usage provenance * feature: add TrainingJob ARN to telemetry for training jobs and fixed bug with telemetry not being sent for *Trainer.train() if sagemaker_session is not provided * adding createdBy metadata to user agent string if attribution env var has been set to aid in resource attribution * fix: removed unused patch on builtins.open in test_create_with_byoc which was not being used and causing unintended patches to open calls elsewhere --------- Co-authored-by: Ryan Tanaka <rrtanaka@amazon.com> Co-authored-by: Molly He <mollyhe@amazon.com>
1 parent c07247e commit a40c856

File tree

10 files changed

+487
-22
lines changed

10 files changed

+487
-22
lines changed

sagemaker-core/src/sagemaker/core/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@
1515
# Partner App
1616
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
1717

18+
# Attribution
19+
from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401
20+
1821
# Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
1922
# They are not re-exported from core to avoid circular dependencies
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Attribution module for tracking the provenance of SDK usage."""
14+
from __future__ import absolute_import
15+
import os
16+
from enum import Enum
17+
18+
# Name of the environment variable that holds the attribution value written by set_attribution().
_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY"
19+
20+
21+
class Attribution(Enum):
    """Known sources that SDK usage can be attributed to.

    The value of each member is the identifier string recorded for that source.
    """

    # Identifier of the SageMaker agent plugin.
    SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai"
25+
26+
27+
def set_attribution(attribution: Attribution):
    """Record ``attribution`` as the source of SDK usage for this process.

    The member's value is stored in the environment variable named by
    ``_CREATED_BY_ENV_VAR``. Intended to be called at the top of scripts
    generated by an agent or integration.

    Args:
        attribution (Attribution): The attribution source to record.

    Raises:
        TypeError: If ``attribution`` is not an ``Attribution`` member.
    """
    if isinstance(attribution, Attribution):
        os.environ[_CREATED_BY_ENV_VAR] = attribution.value
        return
    # Anything else (including the raw string value of a member) is rejected.
    raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}")
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Resource creation module for tracking ARNs of resources created via SDK calls."""
14+
from __future__ import absolute_import
15+
16+
# Maps class name (string) to the attribute name holding the resource ARN.
17+
# String-based keys avoid cross-package imports and circular dependencies.
18+
_RESOURCE_ARN_ATTRIBUTES = {
    "TrainingJob": "training_job_arn",
}


def get_resource_arn(response):
    """Extract the ARN from an SDK response object if available.

    The response's class, followed by its base classes in MRO order, is
    matched by *name* against ``_RESOURCE_ARN_ATTRIBUTES``. Matching on names
    avoids cross-package imports, and walking the MRO means instances of a
    subclass of a registered resource class are recognised as well.

    Args:
        response: The return value of a _telemetry_emitter-decorated function.

    Returns:
        str: The ARN string if available, otherwise None.
    """
    if response is None:
        return None

    # The first class in the MRO with a registered name wins, so an exact
    # match on the response's own class always takes precedence.
    arn_attr = next(
        (
            _RESOURCE_ARN_ATTRIBUTES[klass.__name__]
            for klass in type(response).__mro__
            if klass.__name__ in _RESOURCE_ARN_ATTRIBUTES
        ),
        None,
    )
    if not arn_attr:
        return None

    arn = getattr(response, arn_attr, None)

    # Guard against the Unassigned sentinel used in resources.py; it is
    # compared by type name so this module does not have to import it.
    if not arn or type(arn).__name__ == "Unassigned":
        return None

    return str(arn)

sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@
1313
"""Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
import platform
1718
import sys
1819
from time import perf_counter
1920
from typing import List
2021
import functools
2122
import requests
23+
from urllib.parse import quote
2224

2325
import boto3
2426
from sagemaker.core.helper.session_helper import Session
27+
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
28+
from sagemaker.core.telemetry.resource_creation import get_resource_arn
2529
from sagemaker.core.common_utils import resolve_value_from_config
2630
from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH
2731
from sagemaker.core.telemetry.constants import (
@@ -81,7 +85,7 @@ def wrapper(*args, **kwargs):
8185
sagemaker_session = None
8286
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
8387
# Get the sagemaker_session from the instance method args
84-
sagemaker_session = args[0].sagemaker_session
88+
sagemaker_session = args[0].sagemaker_session or _get_default_sagemaker_session()
8589
elif len(args) > 0 and hasattr(args[0], "_sagemaker_session"):
8690
# Get the sagemaker_session from the instance method args (private attribute)
8791
sagemaker_session = args[0]._sagemaker_session
@@ -137,13 +141,23 @@ def wrapper(*args, **kwargs):
137141
if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
138142
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"
139143

144+
# Add created_by from environment variable if available
145+
created_by = os.environ.get(_CREATED_BY_ENV_VAR, "")
146+
if created_by:
147+
extra += f"&x-createdBy={quote(created_by, safe='')}"
148+
140149
start_timer = perf_counter()
141150
try:
142151
# Call the original function
143152
response = func(*args, **kwargs)
144153
stop_timer = perf_counter()
145154
elapsed = stop_timer - start_timer
146155
extra += f"&x-latency={round(elapsed, 2)}"
156+
# For specified response types (e.g., TrainingJob), obtain the ARN of the
157+
# resource created if present so that it can be included.
158+
resource_arn = get_resource_arn(response)
159+
if resource_arn:
160+
extra += f"&x-resourceArn={resource_arn}"
147161
if not telemetry_opt_out_flag:
148162
_send_telemetry_request(
149163
STATUS_TO_CODE[str(Status.SUCCESS)],

sagemaker-core/src/sagemaker/core/utils/user_agent.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,31 @@
1717

1818
import importlib_metadata
1919

20+
from string import ascii_letters, digits
21+
22+
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
23+
2024
SagemakerCore_PREFIX = "AWS-SageMakerCore"
25+
26+
_USERAGENT_ALLOWED_CHARACTERS = ascii_letters + digits + "!$%&'*+-.^_`|~,"


def sanitize_user_agent_string_component(raw_str, allow_hash=False):
    """Return ``raw_str`` with every disallowed character replaced by '-'.

    A character is kept when it appears in ``_USERAGENT_ALLOWED_CHARACTERS``,
    or when it is '#' and ``allow_hash`` is set.

    Args:
        raw_str (str): The input string to sanitize.
        allow_hash (bool): Whether '#' is considered an allowed character.

    Returns:
        str: The sanitized string.
    """
    pieces = []
    for char in raw_str:
        keep = char in _USERAGENT_ALLOWED_CHARACTERS or (allow_hash and char == "#")
        pieces.append(char if keep else "-")
    return "".join(pieces)
43+
44+
2145
STUDIO_PREFIX = "AWS-SageMaker-Studio"
2246
NOTEBOOK_PREFIX = "AWS-SageMaker-Notebook-Instance"
2347

@@ -74,4 +98,9 @@ def get_user_agent_extra_suffix() -> str:
7498
if studio_app_type:
7599
suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type)
76100

101+
# Add created_by metadata if attribution has been set
102+
created_by = os.environ.get(_CREATED_BY_ENV_VAR)
103+
if created_by:
104+
suffix = "{} md/{}#{}".format(suffix, "createdBy", sanitize_user_agent_string_component(created_by))
105+
77106
return suffix

sagemaker-core/tests/unit/generated/test_user_agent.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,12 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License"). You
4-
# may not use this file except in compliance with the License. A copy of
5-
# the License is located at
6-
#
7-
# http://aws.amazon.com/apache2.0/
8-
#
9-
# or in the "license" file accompanying this file. This file is
10-
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11-
# ANY KIND, either express or implied. See the License for the specific
12-
# language governing permissions and limitations under the License.
131
from __future__ import absolute_import
142

153
import json
4+
import os
165
from mock import patch, mock_open
176

7+
import pytest
188

9+
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
1910
from sagemaker.core.utils.user_agent import (
2011
SagemakerCore_PREFIX,
2112
SagemakerCore_VERSION,
@@ -24,8 +15,15 @@
2415
process_notebook_metadata_file,
2516
process_studio_metadata_file,
2617
get_user_agent_extra_suffix,
18+
sanitize_user_agent_string_component,
2719
)
28-
from sagemaker.core.utils.user_agent import SagemakerCore_PREFIX
20+
21+
22+
@pytest.fixture(autouse=True)
def clean_env():
    """Run every test with the attribution environment variable unset.

    The variable is removed before the test as well as after it, so a value
    already exported in the developer's shell cannot leak into a test, and
    that pre-existing value is put back once the test has finished.
    """
    previous = os.environ.pop(_CREATED_BY_ENV_VAR, None)
    yield
    # Drop whatever the test set, then restore the value found at start-up.
    os.environ.pop(_CREATED_BY_ENV_VAR, None)
    if previous is not None:
        os.environ[_CREATED_BY_ENV_VAR] = previous
2927

3028

3129
# Test process_notebook_metadata_file function
@@ -58,6 +56,27 @@ def test_process_studio_metadata_file_not_exists(tmp_path):
5856
assert process_studio_metadata_file() is None
5957

6058

59+
# Test sanitize_user_agent_string_component function
60+
def test_sanitize_replaces_slash_with_dash():
    """A '/' is not an allowed character and becomes '-'."""
    sanitized = sanitize_user_agent_string_component("awslabs/agent-plugins/sagemaker-ai")
    assert sanitized == "awslabs-agent-plugins-sagemaker-ai"
62+
63+
64+
def test_sanitize_allows_alphanumeric():
    """Letters and digits pass through unchanged."""
    sanitized = sanitize_user_agent_string_component("abc123")
    assert sanitized == "abc123"
66+
67+
68+
def test_sanitize_replaces_hash_when_not_allowed():
    """With the default allow_hash=False, '#' becomes '-'."""
    sanitized = sanitize_user_agent_string_component("foo#bar")
    assert sanitized == "foo-bar"
70+
71+
72+
def test_sanitize_allows_hash_when_permitted():
    """With allow_hash=True, '#' is kept."""
    sanitized = sanitize_user_agent_string_component("foo#bar", allow_hash=True)
    assert sanitized == "foo#bar"
74+
75+
76+
def test_sanitize_replaces_space_with_dash():
    """A space is not an allowed character and becomes '-'."""
    sanitized = sanitize_user_agent_string_component("foo bar")
    assert sanitized == "foo-bar"
78+
79+
6180
# Test get_user_agent_extra_suffix function
6281
def test_get_user_agent_extra_suffix():
6382
assert get_user_agent_extra_suffix() == f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION}"
@@ -78,3 +97,20 @@ def test_get_user_agent_extra_suffix():
7897
get_user_agent_extra_suffix()
7998
== f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION} md/{STUDIO_PREFIX}#studio_type"
8099
)
100+
101+
102+
def test_get_user_agent_extra_suffix_without_created_by():
    """No createdBy metadata is emitted when the attribution variable is unset."""
    # patch.dict snapshots os.environ and restores it on exit, so removing the
    # variable here makes the test independent of the caller's shell
    # environment without leaving any trace behind.
    with patch.dict(os.environ):
        os.environ.pop(_CREATED_BY_ENV_VAR, None)
        suffix = get_user_agent_extra_suffix()
    assert "createdBy" not in suffix
105+
106+
107+
def test_get_user_agent_extra_suffix_with_created_by():
    """The attribution value appears, sanitized, as md/createdBy#... in the suffix."""
    os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai"
    expected_fragment = "md/createdBy#awslabs-agent-plugins-sagemaker-ai"
    assert expected_fragment in get_user_agent_extra_suffix()
111+
112+
113+
def test_get_user_agent_extra_suffix_created_by_sanitized():
    """Spaces, slashes and parentheses in the value are each replaced by '-'."""
    os.environ[_CREATED_BY_ENV_VAR] = "my agent/v1.0 (test)"
    expected_fragment = "md/createdBy#my-agent-v1.0--test-"
    assert expected_fragment in get_user_agent_extra_suffix()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import os
15+
import pytest
16+
from sagemaker.core.telemetry.attribution import (
17+
_CREATED_BY_ENV_VAR,
18+
Attribution,
19+
set_attribution,
20+
)
21+
22+
23+
@pytest.fixture(autouse=True)
def clean_env():
    """Run every test with the attribution environment variable unset.

    The variable is removed before the test as well as after it, so a value
    already exported in the developer's shell cannot leak into a test, and
    that pre-existing value is put back once the test has finished.
    """
    previous = os.environ.pop(_CREATED_BY_ENV_VAR, None)
    yield
    # Drop whatever the test set, then restore the value found at start-up.
    os.environ.pop(_CREATED_BY_ENV_VAR, None)
    if previous is not None:
        os.environ[_CREATED_BY_ENV_VAR] = previous
28+
29+
30+
def test_set_attribution_sagemaker_agent_plugin():
    """set_attribution stores the member's value in the environment variable."""
    member = Attribution.SAGEMAKER_AGENT_PLUGIN
    set_attribution(member)
    assert os.environ[_CREATED_BY_ENV_VAR] == member.value
33+
34+
35+
def test_set_attribution_invalid_type_raises():
    """A plain string, even one equal to a member's value, is rejected."""
    raw_value = "awslabs/agent-plugins/sagemaker-ai"
    with pytest.raises(TypeError):
        set_attribution(raw_value)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from unittest.mock import MagicMock
16+
from sagemaker.core.utils.utils import Unassigned
17+
from sagemaker.core.telemetry.resource_creation import _RESOURCE_ARN_ATTRIBUTES, get_resource_arn
18+
19+
20+
# Each entry: (class_name, arn_attr, arn_value)
21+
_RESOURCE_TEST_CASES = [
    # (class_name, arn_attr, arn_value)
    (
        "TrainingJob",
        "training_job_arn",
        "arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job",
    ),
]
28+
29+
30+
def test_get_resource_arn_none_response():
    """A None response yields no ARN."""
    result = get_resource_arn(None)
    assert result is None
32+
33+
34+
def test_get_resource_arn_unknown_type():
    """Responses whose type is not registered yield no ARN."""
    for unregistered in ("some string", 42):
        assert get_resource_arn(unregistered) is None
37+
38+
39+
@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_valid_arn(class_name, arn_attr, arn_value):
    """A registered resource with an ARN set returns that ARN."""
    resource = MagicMock(**{arn_attr: arn_value})
    # Each MagicMock has its own class, so renaming it affects only this mock.
    type(resource).__name__ = class_name
    assert get_resource_arn(resource) == arn_value
45+
46+
47+
@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_unassigned(class_name, arn_attr, arn_value):
    """An ARN attribute holding the Unassigned sentinel yields no ARN."""
    resource = MagicMock(**{arn_attr: Unassigned()})
    # Each MagicMock has its own class, so renaming it affects only this mock.
    type(resource).__name__ = class_name
    assert get_resource_arn(resource) is None
53+
54+
55+
@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_none_arn(class_name, arn_attr, arn_value):
    """An ARN attribute set to None yields no ARN."""
    resource = MagicMock(**{arn_attr: None})
    # Each MagicMock has its own class, so renaming it affects only this mock.
    type(resource).__name__ = class_name
    assert get_resource_arn(resource) is None
61+
62+
63+
# Verify string keys in _RESOURCE_ARN_ATTRIBUTES match actual class names
64+
@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_resource_class_name_matches_dict_key(class_name, arn_attr, arn_value):
    """The string keys of _RESOURCE_ARN_ATTRIBUTES match the real class names."""
    # Imported inside the test rather than at module level, as in the original.
    from sagemaker.core.resources import TrainingJob

    known_classes = {"TrainingJob": TrainingJob}
    resource_cls = known_classes.get(class_name)
    assert resource_cls is not None, f"No class found for key '{class_name}'"
    assert resource_cls.__name__ == class_name
    assert class_name in _RESOURCE_ARN_ATTRIBUTES

0 commit comments

Comments
 (0)