Skip to content

Commit ffe5001

Browse files
author
Ryan Tanaka
committed
feature: add telemetry attribution module for SDK usage provenance
1 parent e965a1b commit ffe5001

File tree

5 files changed

+233
-6
lines changed

5 files changed

+233
-6
lines changed

sagemaker-core/src/sagemaker/core/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@
1515
# Partner App
1616
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
1717

18+
# Attribution
19+
from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401
20+
1821
# Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
1922
# They are not re-exported from core to avoid circular dependencies
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Attribution module for tracking the provenance of SDK usage."""
14+
from __future__ import absolute_import
15+
import os
16+
from enum import Enum
17+
18+
_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY"
19+
20+
21+
class Attribution(Enum):
22+
"""Enumeration of known SDK attribution sources."""
23+
24+
SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai"
25+
26+
27+
def set_attribution(attribution: Attribution):
28+
"""Sets the SDK usage attribution to the specified source.
29+
30+
Call this at the top of scripts generated by an agent or integration
31+
to enable accurate telemetry attribution.
32+
33+
Args:
34+
attribution (Attribution): The attribution source to set.
35+
36+
Raises:
37+
TypeError: If attribution is not an Attribution enum member.
38+
"""
39+
if not isinstance(attribution, Attribution):
40+
raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}")
41+
os.environ[_CREATED_BY_ENV_VAR] = attribution.value

sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,18 @@
1313
"""Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
import platform
1718
import sys
1819
from time import perf_counter
1920
from typing import List
2021
import functools
2122
import requests
23+
from urllib.parse import quote
2224

2325
import boto3
2426
from sagemaker.core.helper.session_helper import Session
27+
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
2528
from sagemaker.core.common_utils import resolve_value_from_config
2629
from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH
2730
from sagemaker.core.telemetry.constants import (
@@ -137,6 +140,11 @@ def wrapper(*args, **kwargs):
137140
if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
138141
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"
139142

143+
# Add created_by from environment variable if available
144+
created_by = os.environ.get(_CREATED_BY_ENV_VAR, "")
145+
if created_by:
146+
extra += f"&x-createdBy={quote(created_by, safe='')}"
147+
140148
start_timer = perf_counter()
141149
try:
142150
# Call the original function
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import os
15+
import pytest
16+
from sagemaker.core.telemetry.attribution import (
17+
_CREATED_BY_ENV_VAR,
18+
Attribution,
19+
set_attribution,
20+
)
21+
22+
23+
@pytest.fixture(autouse=True)
24+
def clean_env():
25+
yield
26+
if _CREATED_BY_ENV_VAR in os.environ:
27+
del os.environ[_CREATED_BY_ENV_VAR]
28+
29+
30+
def test_set_attribution_sagemaker_agent_plugin():
31+
set_attribution(Attribution.SAGEMAKER_AGENT_PLUGIN)
32+
assert os.environ[_CREATED_BY_ENV_VAR] == Attribution.SAGEMAKER_AGENT_PLUGIN.value
33+
34+
35+
def test_set_attribution_invalid_type_raises():
36+
with pytest.raises(TypeError):
37+
set_attribution("awslabs/agent-plugins/sagemaker-ai")

sagemaker-core/tests/unit/telemetry/test_telemetry_logging.py

Lines changed: 144 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import os
1415
import unittest
1516
import pytest
1617
import requests
1718
from unittest.mock import Mock, patch, MagicMock
1819
import boto3
1920
import sagemaker
2021
from sagemaker.core.telemetry.constants import Feature
22+
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
2123
from sagemaker.core.telemetry.telemetry_logging import (
2224
_send_telemetry_request,
2325
_telemetry_emitter,
@@ -33,16 +35,23 @@
3335

3436
# Try to import sagemaker-serve exceptions, skip tests if not available
3537
try:
36-
from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
38+
from sagemaker.serve.utils.exceptions import (
39+
ModelBuilderException,
40+
LocalModelOutOfMemoryException,
41+
)
42+
3743
SAGEMAKER_SERVE_AVAILABLE = True
3844
except ImportError:
3945
SAGEMAKER_SERVE_AVAILABLE = False
46+
4047
# Create mock exceptions for type hints
4148
class ModelBuilderException(Exception):
4249
pass
50+
4351
class LocalModelOutOfMemoryException(Exception):
4452
pass
4553

54+
4655
MOCK_SESSION = Mock()
4756
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
4857
MOCK_FEATURE = Feature.SDK_DEFAULTS
@@ -158,10 +167,7 @@ def test_telemetry_emitter_decorator_success(
158167
1, [11, 12], MOCK_SESSION, None, None, expected_extra_str
159168
)
160169

161-
@pytest.mark.skipif(
162-
not SAGEMAKER_SERVE_AVAILABLE,
163-
reason="Requires sagemaker-serve package"
164-
)
170+
@pytest.mark.skipif(not SAGEMAKER_SERVE_AVAILABLE, reason="Requires sagemaker-serve package")
165171
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
166172
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
167173
def test_telemetry_emitter_decorator_handle_exception_success(
@@ -194,7 +200,7 @@ def test_telemetry_emitter_decorator_handle_exception_success(
194200

195201
mock_send_telemetry_request.assert_called_once_with(
196202
0,
197-
[1, 2],
203+
[11, 12],
198204
MOCK_SESSION,
199205
str(mock_exception_obj),
200206
mock_exception_obj.__class__.__name__,
@@ -357,3 +363,135 @@ def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_a
357363
_send_telemetry_request(1, [1, 2], mock_session)
358364
# Assert telemetry request was not sent
359365
mock_requests_helper.assert_not_called()
366+
367+
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
368+
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
369+
def test_telemetry_emitter_with_created_by_env_var(
370+
self, mock_resolve_config, mock_send_telemetry_request
371+
):
372+
"""Test that x-createdBy is included when SAGEMAKER_PYSDK_CREATED_BY env var is set"""
373+
mock_resolve_config.return_value = False
374+
375+
# Set environment variable
376+
os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai"
377+
378+
try:
379+
mock_local_client = LocalSagemakerClientMock()
380+
mock_local_client.mock_create_model()
381+
382+
args = mock_send_telemetry_request.call_args.args
383+
extra_str = str(args[5])
384+
385+
# Verify x-createdBy is in the extra string with URL encoding
386+
self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", extra_str)
387+
388+
# Verify forward slashes are encoded as %2F
389+
self.assertNotIn("x-createdBy=awslabs/agent-plugins", extra_str)
390+
finally:
391+
# Clean up environment variable
392+
if _CREATED_BY_ENV_VAR in os.environ:
393+
del os.environ[_CREATED_BY_ENV_VAR]
394+
395+
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
396+
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
397+
def test_telemetry_emitter_without_created_by_env_var(
398+
self, mock_resolve_config, mock_send_telemetry_request
399+
):
400+
"""Test that x-createdBy is NOT included when env var is not set"""
401+
mock_resolve_config.return_value = False
402+
403+
# Ensure environment variable is not set
404+
if _CREATED_BY_ENV_VAR in os.environ:
405+
del os.environ[_CREATED_BY_ENV_VAR]
406+
407+
mock_local_client = LocalSagemakerClientMock()
408+
mock_local_client.mock_create_model()
409+
410+
args = mock_send_telemetry_request.call_args.args
411+
extra_str = str(args[5])
412+
413+
# Verify x-createdBy is NOT in the extra string
414+
self.assertNotIn("x-createdBy", extra_str)
415+
416+
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
417+
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
418+
def test_telemetry_emitter_created_by_with_special_chars(
419+
self, mock_resolve_config, mock_send_telemetry_request
420+
):
421+
"""Test that x-createdBy properly URL-encodes special characters"""
422+
mock_resolve_config.return_value = False
423+
424+
# Set environment variable with special characters
425+
os.environ[_CREATED_BY_ENV_VAR] = "My App & Tools (v2.0)"
426+
427+
try:
428+
mock_local_client = LocalSagemakerClientMock()
429+
mock_local_client.mock_create_model()
430+
431+
args = mock_send_telemetry_request.call_args.args
432+
extra_str = str(args[5])
433+
434+
# Verify special characters are URL-encoded
435+
self.assertIn("x-createdBy=My%20App%20%26%20Tools%20%28v2.0%29", extra_str)
436+
437+
# Verify raw special characters are NOT in the URL
438+
self.assertNotIn("My App & Tools", extra_str)
439+
self.assertNotIn("(v2.0)", extra_str)
440+
finally:
441+
if _CREATED_BY_ENV_VAR in os.environ:
442+
del os.environ[_CREATED_BY_ENV_VAR]
443+
444+
@patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request")
445+
@patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config")
446+
def test_telemetry_emitter_created_by_empty_string(
447+
self, mock_resolve_config, mock_send_telemetry_request
448+
):
449+
"""Test that x-createdBy is NOT included when env var is empty string"""
450+
mock_resolve_config.return_value = False
451+
452+
# Set environment variable to empty string
453+
os.environ[_CREATED_BY_ENV_VAR] = ""
454+
455+
try:
456+
mock_local_client = LocalSagemakerClientMock()
457+
mock_local_client.mock_create_model()
458+
459+
args = mock_send_telemetry_request.call_args.args
460+
extra_str = str(args[5])
461+
462+
# Verify x-createdBy is NOT added for empty string
463+
self.assertNotIn("x-createdBy", extra_str)
464+
finally:
465+
if _CREATED_BY_ENV_VAR in os.environ:
466+
del os.environ[_CREATED_BY_ENV_VAR]
467+
468+
def test_construct_url_with_created_by(self):
469+
"""Test URL construction includes x-createdBy in extra_info"""
470+
mock_accountId = "123456789012"
471+
mock_region = "us-west-2"
472+
mock_status = "1"
473+
mock_feature = "15"
474+
mock_extra_info = (
475+
"DataSet.create&x-sdkVersion=3.0&x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai"
476+
)
477+
478+
url = _construct_url(
479+
accountId=mock_accountId,
480+
region=mock_region,
481+
status=mock_status,
482+
feature=mock_feature,
483+
failure_reason=None,
484+
failure_type=None,
485+
extra_info=mock_extra_info,
486+
)
487+
488+
expected_url = (
489+
f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?"
490+
f"x-accountId={mock_accountId}"
491+
f"&x-status={mock_status}"
492+
f"&x-feature={mock_feature}"
493+
f"&x-extra={mock_extra_info}"
494+
)
495+
496+
self.assertEqual(url, expected_url)
497+
self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", url)

0 commit comments

Comments
 (0)