Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
# Partner App
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401

# Attribution
from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401

# Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
# They are not re-exported from core to avoid circular dependencies
41 changes: 41 additions & 0 deletions sagemaker-core/src/sagemaker/core/telemetry/attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Attribution module for tracking the provenance of SDK usage."""
from __future__ import absolute_import
import os
from enum import Enum

_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY"


class Attribution(Enum):
"""Enumeration of known SDK attribution sources."""

SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai"


def set_attribution(attribution: Attribution):
"""Sets the SDK usage attribution to the specified source.

Call this at the top of scripts generated by an agent or integration
to enable accurate telemetry attribution.

Args:
attribution (Attribution): The attribution source to set.

Raises:
TypeError: If attribution is not an Attribution enum member.
"""
if not isinstance(attribution, Attribution):
raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}")
os.environ[_CREATED_BY_ENV_VAR] = attribution.value
47 changes: 47 additions & 0 deletions sagemaker-core/src/sagemaker/core/telemetry/resource_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Resource creation module for tracking ARNs of resources created via SDK calls."""
from __future__ import absolute_import

# Maps class name (string) to the attribute name holding the resource ARN.
# String-based keys avoid cross-package imports and circular dependencies.
_RESOURCE_ARN_ATTRIBUTES = {
"TrainingJob": "training_job_arn",
}


def get_resource_arn(response):
"""Extract the ARN from a SDK response object if available.

Uses string-based type name lookup to avoid cross-package imports.

Args:
response: The return value of a _telemetry_emitter-decorated function.

Returns:
str: The ARN string if available, otherwise None.
"""
if response is None:
return None

arn_attr = _RESOURCE_ARN_ATTRIBUTES.get(type(response).__name__)
if not arn_attr:
return None

arn = getattr(response, arn_attr, None)

# Guard against Unassigned sentinel used in resources.py
if not arn or type(arn).__name__ == "Unassigned":
return None

return str(arn)
16 changes: 15 additions & 1 deletion sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
"""Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
from __future__ import absolute_import
import logging
import os
import platform
import sys
from time import perf_counter
from typing import List
import functools
import requests
from urllib.parse import quote

import boto3
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
from sagemaker.core.telemetry.resource_creation import get_resource_arn
from sagemaker.core.common_utils import resolve_value_from_config
from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH
from sagemaker.core.telemetry.constants import (
Expand Down Expand Up @@ -81,7 +85,7 @@ def wrapper(*args, **kwargs):
sagemaker_session = None
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
# Get the sagemaker_session from the instance method args
sagemaker_session = args[0].sagemaker_session
sagemaker_session = args[0].sagemaker_session or _get_default_sagemaker_session()
elif len(args) > 0 and hasattr(args[0], "_sagemaker_session"):
# Get the sagemaker_session from the instance method args (private attribute)
sagemaker_session = args[0]._sagemaker_session
Expand Down Expand Up @@ -137,13 +141,23 @@ def wrapper(*args, **kwargs):
if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

# Add created_by from environment variable if available
created_by = os.environ.get(_CREATED_BY_ENV_VAR, "")
if created_by:
extra += f"&x-createdBy={quote(created_by, safe='')}"

start_timer = perf_counter()
try:
# Call the original function
response = func(*args, **kwargs)
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
# For specified response types (e.g., TrainingJob), obtain the ARN of the
# resource created if present so that it can be included.
resource_arn = get_resource_arn(response)
if resource_arn:
extra += f"&x-resourceArn={resource_arn}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.SUCCESS)],
Expand Down
29 changes: 29 additions & 0 deletions sagemaker-core/src/sagemaker/core/utils/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,31 @@

import importlib_metadata

from string import ascii_letters, digits

from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR

SagemakerCore_PREFIX = "AWS-SageMakerCore"

_USERAGENT_ALLOWED_CHARACTERS = ascii_letters + digits + "!$%&'*+-.^_`|~,"


def sanitize_user_agent_string_component(raw_str, allow_hash=False):
"""Sanitize a User-Agent string component by replacing disallowed characters with '-'.

Args:
raw_str (str): The input string to sanitize.
allow_hash (bool): Whether '#' is considered an allowed character.

Returns:
str: The sanitized string.
"""
return "".join(
c if c in _USERAGENT_ALLOWED_CHARACTERS or (allow_hash and c == "#") else "-"
for c in raw_str
)


STUDIO_PREFIX = "AWS-SageMaker-Studio"
NOTEBOOK_PREFIX = "AWS-SageMaker-Notebook-Instance"

Expand Down Expand Up @@ -74,4 +98,9 @@ def get_user_agent_extra_suffix() -> str:
if studio_app_type:
suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type)

# Add created_by metadata if attribution has been set
created_by = os.environ.get(_CREATED_BY_ENV_VAR)
if created_by:
suffix = "{} md/{}#{}".format(suffix, "createdBy", sanitize_user_agent_string_component(created_by))

return suffix
62 changes: 49 additions & 13 deletions sagemaker-core/tests/unit/generated/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
from mock import patch, mock_open

import pytest

from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
from sagemaker.core.utils.user_agent import (
SagemakerCore_PREFIX,
SagemakerCore_VERSION,
Expand All @@ -24,8 +15,15 @@
process_notebook_metadata_file,
process_studio_metadata_file,
get_user_agent_extra_suffix,
sanitize_user_agent_string_component,
)
from sagemaker.core.utils.user_agent import SagemakerCore_PREFIX


@pytest.fixture(autouse=True)
def clean_env():
yield
if _CREATED_BY_ENV_VAR in os.environ:
del os.environ[_CREATED_BY_ENV_VAR]


# Test process_notebook_metadata_file function
Expand Down Expand Up @@ -58,6 +56,27 @@ def test_process_studio_metadata_file_not_exists(tmp_path):
assert process_studio_metadata_file() is None


# Test sanitize_user_agent_string_component function
def test_sanitize_replaces_slash_with_dash():
assert sanitize_user_agent_string_component("awslabs/agent-plugins/sagemaker-ai") == "awslabs-agent-plugins-sagemaker-ai"


def test_sanitize_allows_alphanumeric():
assert sanitize_user_agent_string_component("abc123") == "abc123"


def test_sanitize_replaces_hash_when_not_allowed():
assert sanitize_user_agent_string_component("foo#bar") == "foo-bar"


def test_sanitize_allows_hash_when_permitted():
assert sanitize_user_agent_string_component("foo#bar", allow_hash=True) == "foo#bar"


def test_sanitize_replaces_space_with_dash():
assert sanitize_user_agent_string_component("foo bar") == "foo-bar"


# Test get_user_agent_extra_suffix function
def test_get_user_agent_extra_suffix():
assert get_user_agent_extra_suffix() == f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION}"
Expand All @@ -78,3 +97,20 @@ def test_get_user_agent_extra_suffix():
get_user_agent_extra_suffix()
== f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION} md/{STUDIO_PREFIX}#studio_type"
)


def test_get_user_agent_extra_suffix_without_created_by():
suffix = get_user_agent_extra_suffix()
assert "createdBy" not in suffix


def test_get_user_agent_extra_suffix_with_created_by():
os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai"
suffix = get_user_agent_extra_suffix()
assert "md/createdBy#awslabs-agent-plugins-sagemaker-ai" in suffix


def test_get_user_agent_extra_suffix_created_by_sanitized():
os.environ[_CREATED_BY_ENV_VAR] = "my agent/v1.0 (test)"
suffix = get_user_agent_extra_suffix()
assert "md/createdBy#my-agent-v1.0--test-" in suffix
37 changes: 37 additions & 0 deletions sagemaker-core/tests/unit/telemetry/test_attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import os
import pytest
from sagemaker.core.telemetry.attribution import (
_CREATED_BY_ENV_VAR,
Attribution,
set_attribution,
)


@pytest.fixture(autouse=True)
def clean_env():
yield
if _CREATED_BY_ENV_VAR in os.environ:
del os.environ[_CREATED_BY_ENV_VAR]


def test_set_attribution_sagemaker_agent_plugin():
set_attribution(Attribution.SAGEMAKER_AGENT_PLUGIN)
assert os.environ[_CREATED_BY_ENV_VAR] == Attribution.SAGEMAKER_AGENT_PLUGIN.value


def test_set_attribution_invalid_type_raises():
with pytest.raises(TypeError):
set_attribution("awslabs/agent-plugins/sagemaker-ai")
74 changes: 74 additions & 0 deletions sagemaker-core/tests/unit/telemetry/test_resource_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import pytest
from unittest.mock import MagicMock
from sagemaker.core.utils.utils import Unassigned
from sagemaker.core.telemetry.resource_creation import _RESOURCE_ARN_ATTRIBUTES, get_resource_arn


# Each entry: (class_name, arn_attr, arn_value)
_RESOURCE_TEST_CASES = [
(
"TrainingJob",
"training_job_arn",
"arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job",
),
]


def test_get_resource_arn_none_response():
assert get_resource_arn(None) is None


def test_get_resource_arn_unknown_type():
assert get_resource_arn("some string") is None
assert get_resource_arn(42) is None


@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_valid_arn(class_name, arn_attr, arn_value):
mock_resource = MagicMock()
mock_resource.__class__.__name__ = class_name
setattr(mock_resource, arn_attr, arn_value)
assert get_resource_arn(mock_resource) == arn_value


@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_unassigned(class_name, arn_attr, arn_value):
mock_resource = MagicMock()
mock_resource.__class__.__name__ = class_name
setattr(mock_resource, arn_attr, Unassigned())
assert get_resource_arn(mock_resource) is None


@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_none_arn(class_name, arn_attr, arn_value):
mock_resource = MagicMock()
mock_resource.__class__.__name__ = class_name
setattr(mock_resource, arn_attr, None)
assert get_resource_arn(mock_resource) is None


# Verify string keys in _RESOURCE_ARN_ATTRIBUTES match actual class names
@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_resource_class_name_matches_dict_key(class_name, arn_attr, arn_value):
from sagemaker.core.resources import TrainingJob

_CLASS_MAP = {
"TrainingJob": TrainingJob,
}
cls = _CLASS_MAP.get(class_name)
assert cls is not None, f"No class found for key '{class_name}'"
assert cls.__name__ == class_name
assert class_name in _RESOURCE_ARN_ATTRIBUTES
Loading
Loading