|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific |
12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | from __future__ import absolute_import |
| 14 | +import os |
14 | 15 | import unittest |
15 | 16 | import pytest |
16 | 17 | import requests |
17 | 18 | from unittest.mock import Mock, patch, MagicMock |
18 | 19 | import boto3 |
19 | 20 | import sagemaker |
20 | 21 | from sagemaker.core.telemetry.constants import Feature |
| 22 | +from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR |
21 | 23 | from sagemaker.core.telemetry.telemetry_logging import ( |
22 | 24 | _send_telemetry_request, |
23 | 25 | _telemetry_emitter, |
|
33 | 35 |
|
34 | 36 | # Try to import sagemaker-serve exceptions, skip tests if not available |
35 | 37 | try: |
36 | | - from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException |
| 38 | + from sagemaker.serve.utils.exceptions import ( |
| 39 | + ModelBuilderException, |
| 40 | + LocalModelOutOfMemoryException, |
| 41 | + ) |
| 42 | + |
37 | 43 | SAGEMAKER_SERVE_AVAILABLE = True |
38 | 44 | except ImportError: |
39 | 45 | SAGEMAKER_SERVE_AVAILABLE = False |
| 46 | + |
40 | 47 | # Create mock exceptions for type hints |
41 | 48 | class ModelBuilderException(Exception): |
42 | 49 | pass |
| 50 | + |
43 | 51 | class LocalModelOutOfMemoryException(Exception): |
44 | 52 | pass |
45 | 53 |
|
| 54 | + |
46 | 55 | MOCK_SESSION = Mock() |
47 | 56 | MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") |
48 | 57 | MOCK_FEATURE = Feature.SDK_DEFAULTS |
@@ -158,10 +167,7 @@ def test_telemetry_emitter_decorator_success( |
158 | 167 | 1, [11, 12], MOCK_SESSION, None, None, expected_extra_str |
159 | 168 | ) |
160 | 169 |
|
161 | | - @pytest.mark.skipif( |
162 | | - not SAGEMAKER_SERVE_AVAILABLE, |
163 | | - reason="Requires sagemaker-serve package" |
164 | | - ) |
| 170 | + @pytest.mark.skipif(not SAGEMAKER_SERVE_AVAILABLE, reason="Requires sagemaker-serve package") |
165 | 171 | @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") |
166 | 172 | @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") |
167 | 173 | def test_telemetry_emitter_decorator_handle_exception_success( |
@@ -194,7 +200,7 @@ def test_telemetry_emitter_decorator_handle_exception_success( |
194 | 200 |
|
195 | 201 | mock_send_telemetry_request.assert_called_once_with( |
196 | 202 | 0, |
197 | | - [1, 2], |
| 203 | + [11, 12], |
198 | 204 | MOCK_SESSION, |
199 | 205 | str(mock_exception_obj), |
200 | 206 | mock_exception_obj.__class__.__name__, |
@@ -357,3 +363,135 @@ def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_a |
357 | 363 | _send_telemetry_request(1, [1, 2], mock_session) |
358 | 364 | # Assert telemetry request was not sent |
359 | 365 | mock_requests_helper.assert_not_called() |
| 366 | + |
| 367 | + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") |
| 368 | + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") |
| 369 | + def test_telemetry_emitter_with_created_by_env_var( |
| 370 | + self, mock_resolve_config, mock_send_telemetry_request |
| 371 | + ): |
| 372 | + """Test that x-createdBy is included when SAGEMAKER_PYSDK_CREATED_BY env var is set""" |
| 373 | + mock_resolve_config.return_value = False |
| 374 | + |
| 375 | + # Set environment variable |
| 376 | + os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai" |
| 377 | + |
| 378 | + try: |
| 379 | + mock_local_client = LocalSagemakerClientMock() |
| 380 | + mock_local_client.mock_create_model() |
| 381 | + |
| 382 | + args = mock_send_telemetry_request.call_args.args |
| 383 | + extra_str = str(args[5]) |
| 384 | + |
| 385 | + # Verify x-createdBy is in the extra string with URL encoding |
| 386 | + self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", extra_str) |
| 387 | + |
| 388 | + # Verify forward slashes are encoded as %2F |
| 389 | + self.assertNotIn("x-createdBy=awslabs/agent-plugins", extra_str) |
| 390 | + finally: |
| 391 | + # Clean up environment variable |
| 392 | + if _CREATED_BY_ENV_VAR in os.environ: |
| 393 | + del os.environ[_CREATED_BY_ENV_VAR] |
| 394 | + |
| 395 | + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") |
| 396 | + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") |
| 397 | + def test_telemetry_emitter_without_created_by_env_var( |
| 398 | + self, mock_resolve_config, mock_send_telemetry_request |
| 399 | + ): |
| 400 | + """Test that x-createdBy is NOT included when env var is not set""" |
| 401 | + mock_resolve_config.return_value = False |
| 402 | + |
| 403 | + # Ensure environment variable is not set |
| 404 | + if _CREATED_BY_ENV_VAR in os.environ: |
| 405 | + del os.environ[_CREATED_BY_ENV_VAR] |
| 406 | + |
| 407 | + mock_local_client = LocalSagemakerClientMock() |
| 408 | + mock_local_client.mock_create_model() |
| 409 | + |
| 410 | + args = mock_send_telemetry_request.call_args.args |
| 411 | + extra_str = str(args[5]) |
| 412 | + |
| 413 | + # Verify x-createdBy is NOT in the extra string |
| 414 | + self.assertNotIn("x-createdBy", extra_str) |
| 415 | + |
| 416 | + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") |
| 417 | + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") |
| 418 | + def test_telemetry_emitter_created_by_with_special_chars( |
| 419 | + self, mock_resolve_config, mock_send_telemetry_request |
| 420 | + ): |
| 421 | + """Test that x-createdBy properly URL-encodes special characters""" |
| 422 | + mock_resolve_config.return_value = False |
| 423 | + |
| 424 | + # Set environment variable with special characters |
| 425 | + os.environ[_CREATED_BY_ENV_VAR] = "My App & Tools (v2.0)" |
| 426 | + |
| 427 | + try: |
| 428 | + mock_local_client = LocalSagemakerClientMock() |
| 429 | + mock_local_client.mock_create_model() |
| 430 | + |
| 431 | + args = mock_send_telemetry_request.call_args.args |
| 432 | + extra_str = str(args[5]) |
| 433 | + |
| 434 | + # Verify special characters are URL-encoded |
| 435 | + self.assertIn("x-createdBy=My%20App%20%26%20Tools%20%28v2.0%29", extra_str) |
| 436 | + |
| 437 | + # Verify raw special characters are NOT in the URL |
| 438 | + self.assertNotIn("My App & Tools", extra_str) |
| 439 | + self.assertNotIn("(v2.0)", extra_str) |
| 440 | + finally: |
| 441 | + if _CREATED_BY_ENV_VAR in os.environ: |
| 442 | + del os.environ[_CREATED_BY_ENV_VAR] |
| 443 | + |
| 444 | + @patch("sagemaker.core.telemetry.telemetry_logging._send_telemetry_request") |
| 445 | + @patch("sagemaker.core.telemetry.telemetry_logging.resolve_value_from_config") |
| 446 | + def test_telemetry_emitter_created_by_empty_string( |
| 447 | + self, mock_resolve_config, mock_send_telemetry_request |
| 448 | + ): |
| 449 | + """Test that x-createdBy is NOT included when env var is empty string""" |
| 450 | + mock_resolve_config.return_value = False |
| 451 | + |
| 452 | + # Set environment variable to empty string |
| 453 | + os.environ[_CREATED_BY_ENV_VAR] = "" |
| 454 | + |
| 455 | + try: |
| 456 | + mock_local_client = LocalSagemakerClientMock() |
| 457 | + mock_local_client.mock_create_model() |
| 458 | + |
| 459 | + args = mock_send_telemetry_request.call_args.args |
| 460 | + extra_str = str(args[5]) |
| 461 | + |
| 462 | + # Verify x-createdBy is NOT added for empty string |
| 463 | + self.assertNotIn("x-createdBy", extra_str) |
| 464 | + finally: |
| 465 | + if _CREATED_BY_ENV_VAR in os.environ: |
| 466 | + del os.environ[_CREATED_BY_ENV_VAR] |
| 467 | + |
| 468 | + def test_construct_url_with_created_by(self): |
| 469 | + """Test URL construction includes x-createdBy in extra_info""" |
| 470 | + mock_accountId = "123456789012" |
| 471 | + mock_region = "us-west-2" |
| 472 | + mock_status = "1" |
| 473 | + mock_feature = "15" |
| 474 | + mock_extra_info = ( |
| 475 | + "DataSet.create&x-sdkVersion=3.0&x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai" |
| 476 | + ) |
| 477 | + |
| 478 | + url = _construct_url( |
| 479 | + accountId=mock_accountId, |
| 480 | + region=mock_region, |
| 481 | + status=mock_status, |
| 482 | + feature=mock_feature, |
| 483 | + failure_reason=None, |
| 484 | + failure_type=None, |
| 485 | + extra_info=mock_extra_info, |
| 486 | + ) |
| 487 | + |
| 488 | + expected_url = ( |
| 489 | + f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?" |
| 490 | + f"x-accountId={mock_accountId}" |
| 491 | + f"&x-status={mock_status}" |
| 492 | + f"&x-feature={mock_feature}" |
| 493 | + f"&x-extra={mock_extra_info}" |
| 494 | + ) |
| 495 | + |
| 496 | + self.assertEqual(url, expected_url) |
| 497 | + self.assertIn("x-createdBy=awslabs%2Fagent-plugins%2Fsagemaker-ai", url) |
0 commit comments