Skip to content

Commit 8a755c5

Browse files
authored
test: unskip sm-core integ tests (aws#5856)
* test: unskip sm-core integ tests * fix: resolve intelligent defaults bug in ImageRetriever by replacing ineffective locals() assignment with args dict * fix: use PascalCase for config attribute lookup to match schema keys * test: update expected sagemaker-distribution image version from 3.0.0 to 3.2.0 * fix: ignore unknown fields in DefaultPayloadsModel to handle hub schema additions * fix: port kms_utils from SageMakerHulkPythonSDK and enable test_advanced_job_setting - Add tests/integ/kms_utils.py with get_or_create_kms_key utility ported from SageMakerHulkPythonSDK, adapted for sagemaker-core import paths - Uncomment kms_utils import in test_decorator.py to fix NameError in s3_kms_key fixture * fix: unskip test_decorator_with_spark_job with conditional skipif and fix test bugs - Add skipif for Python versions without Spark image (only 3.9/3.12 supported) - Fix Properties set literal to dict: {"spark.app.name": "remote-spark-test"} - Fix spark.conf.get() to use string key instead of attribute reference * Revert "fix: ignore unknown fields in DefaultPayloadsModel to handle hub schema additions" * update kms helper name
1 parent addb37b commit 8a755c5

4 files changed

Lines changed: 192 additions & 13 deletions

File tree

sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,15 @@
2525
config_for_framework,
2626
)
2727
from sagemaker.core.workflow.utilities import override_pipeline_parameter_var
28-
from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, SAGEMAKER, _simple_path
28+
from sagemaker.core.config.config_schema import IMAGE_RETRIEVER, MODULES, PYTHON_SDK, SAGEMAKER, _simple_path
2929
from sagemaker.core.config.config_manager import SageMakerConfig
3030

31+
32+
def _to_pascal_case(name):
33+
"""Convert snake_case to PascalCase."""
34+
camel = to_camel_case(name)
35+
return camel[0].upper() + camel[1:] if camel else camel
36+
3137
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3238
HUGGING_FACE_FRAMEWORK = "huggingface"
3339
PYTORCH_FRAMEWORK = "pytorch"
@@ -114,11 +120,25 @@ def retrieve_hugging_face_uri(
114120
if name in CONFIGURABLE_ATTRIBUTES and not val:
115121
default_value = ImageRetriever._config.resolve_value_from_config(
116122
config_path=_simple_path(
117-
SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
123+
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, _to_pascal_case(name)
118124
)
119125
)
120126
if default_value is not None:
121-
locals()[name] = default_value
127+
args[name] = default_value
128+
129+
# Apply resolved defaults back to local variables
130+
version = args.get("version", version)
131+
py_version = args.get("py_version", py_version)
132+
instance_type = args.get("instance_type", instance_type)
133+
accelerator_type = args.get("accelerator_type", accelerator_type)
134+
image_scope = args.get("image_scope", image_scope)
135+
container_version = args.get("container_version", container_version)
136+
distributed = args.get("distributed", distributed)
137+
base_framework_version = args.get("base_framework_version", base_framework_version)
138+
training_compiler_config = args.get("training_compiler_config", training_compiler_config)
139+
sdk_version = args.get("sdk_version", sdk_version)
140+
inference_tool = args.get("inference_tool", inference_tool)
141+
serverless_inference_config = args.get("serverless_inference_config", serverless_inference_config)
122142

123143
if training_compiler_config:
124144
final_image_scope = image_scope
@@ -503,11 +523,28 @@ def retrieve(
503523
if name in CONFIGURABLE_ATTRIBUTES and not val:
504524
default_value = ImageRetriever._config.resolve_value_from_config(
505525
config_path=_simple_path(
506-
SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
526+
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, _to_pascal_case(name)
507527
)
508528
)
509529
if default_value is not None:
510-
locals()[name] = default_value
530+
args[name] = default_value
531+
532+
# Apply resolved defaults back to local variables
533+
version = args.get("version", version)
534+
py_version = args.get("py_version", py_version)
535+
instance_type = args.get("instance_type", instance_type)
536+
accelerator_type = args.get("accelerator_type", accelerator_type)
537+
image_scope = args.get("image_scope", image_scope)
538+
container_version = args.get("container_version", container_version)
539+
distributed = args.get("distributed", distributed)
540+
smp = args.get("smp", smp)
541+
base_framework_version = args.get("base_framework_version", base_framework_version)
542+
training_compiler_config = args.get("training_compiler_config", training_compiler_config)
543+
model_id = args.get("model_id", model_id)
544+
model_version = args.get("model_version", model_version)
545+
sdk_version = args.get("sdk_version", sdk_version)
546+
inference_tool = args.get("inference_tool", inference_tool)
547+
serverless_inference_config = args.get("serverless_inference_config", serverless_inference_config)
511548

512549
for name, val in args.items():
513550
if is_pipeline_variable(val):

sagemaker-core/tests/integ/image_retriever/test_image_retriever.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_retrieve_base_python_image_uri():
9494
assert image_uri == "236514542706.dkr.ecr.us-west-2.amazonaws.com/sagemaker-base-python-310:1.0"
9595

9696

97-
@pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever")
97+
# @pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever")
9898
@patch.object(SageMakerConfig, "resolve_value_from_config")
9999
def test_retrieve_image_uri_intelligent_default(mock_load_config):
100100
def custom_return(config_path=None, **kwargs):
@@ -116,5 +116,5 @@ def custom_return(config_path=None, **kwargs):
116116
)
117117
assert (
118118
image_uri
119-
== "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.0.0-gpu"
119+
== "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.2.0-gpu"
120120
)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""KMS test helpers for integration tests.
14+
15+
Ported from SageMakerHulkPythonSDK/tests/integ/kms_utils.py.
16+
17+
NOTE: KMS keys created by these helpers use a fixed alias and are intentionally
18+
reused across test runs rather than deleted after each run. This is because KMS
19+
keys have a mandatory 7-day minimum deletion window (schedule_key_deletion), so
20+
per-run create/delete is not practical. The persistent shared key approach avoids
21+
accumulating orphaned keys and unnecessary costs.
22+
"""
23+
from __future__ import absolute_import
24+
25+
import json
26+
27+
from sagemaker.core.common_utils import aws_partition, sts_regional_endpoint
28+
29+
PRINCIPAL_TEMPLATE = (
30+
'["{account_id}", "{role_arn}", '
31+
'"arn:{partition}:iam::{account_id}:role/{sagemaker_role}"] '
32+
)
33+
34+
KEY_ALIAS = "SageMakerTestKMSKey"
35+
POLICY_NAME = "default"
36+
KEY_POLICY = """
37+
{{
38+
"Version": "2012-10-17",
39+
"Id": "{id}",
40+
"Statement": [
41+
{{
42+
"Sid": "Enable IAM User Permissions",
43+
"Effect": "Allow",
44+
"Principal": {{
45+
"AWS": {principal}
46+
}},
47+
"Action": "kms:*",
48+
"Resource": "*"
49+
}}
50+
]
51+
}}
52+
"""
53+
54+
55+
def _get_kms_key_arn(kms_client, alias):
56+
try:
57+
response = kms_client.describe_key(KeyId="alias/" + alias)
58+
return response["KeyMetadata"]["Arn"]
59+
except kms_client.exceptions.NotFoundException:
60+
return None
61+
62+
63+
def _get_kms_key_id(kms_client, alias):
64+
try:
65+
response = kms_client.describe_key(KeyId="alias/" + alias)
66+
return response["KeyMetadata"]["KeyId"]
67+
except kms_client.exceptions.NotFoundException:
68+
return None
69+
70+
71+
def _create_kms_key(
72+
kms_client, account_id, region, role_arn=None, sagemaker_role="SageMakerRole", alias=KEY_ALIAS
73+
):
74+
if role_arn:
75+
principal = PRINCIPAL_TEMPLATE.format(
76+
partition=aws_partition(region),
77+
account_id=account_id,
78+
role_arn=role_arn,
79+
sagemaker_role=sagemaker_role,
80+
)
81+
else:
82+
principal = '"{account_id}"'.format(account_id=account_id)
83+
84+
response = kms_client.create_key(
85+
Policy=KEY_POLICY.format(
86+
id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role
87+
),
88+
Description="KMS key for SageMaker Python SDK integ tests",
89+
)
90+
key_arn = response["KeyMetadata"]["Arn"]
91+
92+
if alias:
93+
kms_client.create_alias(AliasName="alias/" + alias, TargetKeyId=key_arn)
94+
return key_arn
95+
96+
97+
def _add_role_to_policy(
98+
kms_client, account_id, role_arn, region, alias=KEY_ALIAS, sagemaker_role="SageMakerRole"
99+
):
100+
key_id = _get_kms_key_id(kms_client, alias)
101+
policy = kms_client.get_key_policy(KeyId=key_id, PolicyName=POLICY_NAME)
102+
policy = json.loads(policy["Policy"])
103+
principal = policy["Statement"][0]["Principal"]["AWS"]
104+
105+
if role_arn not in principal or sagemaker_role not in principal:
106+
principal = PRINCIPAL_TEMPLATE.format(
107+
partition=aws_partition(region),
108+
account_id=account_id,
109+
role_arn=role_arn,
110+
sagemaker_role=sagemaker_role,
111+
)
112+
113+
kms_client.put_key_policy(
114+
KeyId=key_id,
115+
PolicyName=POLICY_NAME,
116+
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal),
117+
)
118+
119+
120+
def get_or_create_kms_key(
121+
sagemaker_session, role_arn=None, alias=KEY_ALIAS, sagemaker_role="SageMakerRole"
122+
):
123+
kms_client = sagemaker_session.boto_session.client("kms")
124+
kms_key_arn = _get_kms_key_arn(kms_client, alias)
125+
126+
region = sagemaker_session.boto_region_name
127+
sts_client = sagemaker_session.boto_session.client(
128+
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
129+
)
130+
account_id = sts_client.get_caller_identity()["Account"]
131+
132+
if kms_key_arn is None:
133+
return _create_kms_key(kms_client, account_id, region, role_arn, sagemaker_role, alias)
134+
135+
if role_arn:
136+
_add_role_to_policy(kms_client, account_id, role_arn, region, alias, sagemaker_role)
137+
138+
return kms_key_arn

sagemaker-core/tests/integ/remote_function/test_decorator.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import sys
1415
import time
1516
from typing import Union
1617

@@ -40,7 +41,7 @@
4041
from sagemaker.core.common_utils import unique_name_from_base
4142
from tests.integ.s3_utils import assert_s3_files_exist
4243

43-
# from tests.integ.kms_utils import get_or_create_kms_key # TODO: provide KMS utils
44+
from tests.integ.integ_test_kms_helpers import get_or_create_kms_key
4445
import os
4546

4647
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data")
@@ -122,7 +123,7 @@ def divide(x, y):
122123

123124

124125
# TODO: add VPC settings, update SageMakerRole with KMS permissions
125-
@pytest.mark.skip
126+
# @pytest.mark.skip
126127
def test_advanced_job_setting(
127128
sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key
128129
):
@@ -573,7 +574,10 @@ def my_func():
573574
assert client_error_message in str(error)
574575

575576

576-
@pytest.mark.skip
577+
@pytest.mark.skipif(
578+
sys.version_info[:2] not in [(3, 9), (3, 12)],
579+
reason="SageMaker Spark image only available for Python 3.9 and 3.12",
580+
)
577581
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
578582
@remote(
579583
role=ROLE,
@@ -584,7 +588,7 @@ def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
584588
configuration=[
585589
{
586590
"Classification": "spark-defaults",
587-
"Properties": {"spark.app.name", "remote-spark-test"},
591+
"Properties": {"spark.app.name": "remote-spark-test"},
588592
}
589593
]
590594
),
@@ -594,12 +598,12 @@ def test_spark_transform():
594598

595599
spark = SparkSession.builder.getOrCreate()
596600

597-
assert spark.conf.get(spark.app.name) == "remote-spark-test"
601+
assert spark.conf.get("spark.app.name") == "remote-spark-test"
598602

599603
test_spark_transform()
600604

601605

602-
@pytest.mark.skip
606+
# @pytest.mark.skip
603607
def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
604608
"""
605609
This test runs a docker container. The Container invocation will execute a python script

0 commit comments

Comments
 (0)