Skip to content

Commit dffc2fb

Browse files
committed
feat(feature-processor): Add Spark container image version resolver
Create _image_resolver.py with SPARK_IMAGE_SUPPORT_MATRIX and _get_spark_image_uri(session) that auto-detects PySpark + Python versions and returns the correct SageMaker Spark container image URI. Raises ValueError for unsupported combinations (e.g. Spark 3.4 has no container image, Python 3.12 only with Spark 3.5). --- X-AI-Prompt: create image resolver module for multi-spark version container image resolution X-AI-Tool: kiro
1 parent 8624808 commit dffc2fb

File tree

2 files changed

+169
-0
lines changed

2 files changed

+169
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Resolves SageMaker Spark container image URIs based on installed PySpark and Python versions."""
14+
from __future__ import absolute_import
15+
16+
import sys
17+
18+
import pyspark
19+
20+
from sagemaker.core import image_uris
21+
22+
# Maps each supported Spark major.minor version to the Python version tags
# ("pyXY") for which a SageMaker Spark container image is published.
# Spark 3.4 is deliberately absent: no container image exists for it, so a
# 3.4 install must fail resolution rather than silently pick another image.
# NOTE(review): presumably kept in sync with the images actually published
# to ECR — confirm when new Spark/Python images ship.
SPARK_IMAGE_SUPPORT_MATRIX = {
    "3.1": ["py37"],
    "3.2": ["py39"],
    "3.3": ["py39"],
    "3.5": ["py39", "py312"],
}
28+
29+
30+
def _get_spark_image_uri(session):
    """Return the SageMaker Spark container image URI matching the runtime.

    Auto-detects the installed PySpark version (major.minor) and the running
    interpreter's Python version, validates the pair against
    ``SPARK_IMAGE_SUPPORT_MATRIX``, and retrieves the corresponding ECR URI.

    Args:
        session: SageMaker Session with boto_region_name attribute.

    Returns:
        str: The ECR image URI for the matching Spark container.

    Raises:
        ValueError: If the Spark/Python version combination is not supported.
    """
    # "3.5.1" -> "3.5": container images are published per major.minor only.
    version_parts = pyspark.__version__.split(".")
    spark_version = ".".join(version_parts[:2])
    # Interpreter tag in the container naming convention, e.g. "py39", "py312".
    py_version = "py{}{}".format(sys.version_info[0], sys.version_info[1])

    if spark_version not in SPARK_IMAGE_SUPPORT_MATRIX:
        # Unknown Spark line (e.g. 3.4, which has no published image).
        supported = ", ".join(sorted(SPARK_IMAGE_SUPPORT_MATRIX.keys()))
        raise ValueError(
            f"No SageMaker Spark container image available for Spark {spark_version}. "
            f"Supported versions for remote execution: {supported}."
        )

    supported_py = SPARK_IMAGE_SUPPORT_MATRIX[spark_version]
    if py_version not in supported_py:
        # Known Spark line, but no image built for this interpreter version.
        raise ValueError(
            f"SageMaker Spark {spark_version} container images support "
            f"{', '.join(supported_py)}. Current Python version: {py_version}."
        )

    return image_uris.retrieve(
        framework="spark",
        region=session.boto_region_name,
        version=spark_version,
        py_version=py_version,
        container_version="v1",
    )
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sys
16+
17+
import pyspark
18+
import pytest
19+
from mock import Mock, patch
20+
21+
from sagemaker.mlops.feature_store.feature_processor._image_resolver import _get_spark_image_uri
22+
23+
24+
@patch("sagemaker.mlops.feature_store.feature_processor._image_resolver.image_uris.retrieve")
def test_spark_33_py39(mock_retrieve):
    """Spark 3.3 with Python 3.9 resolves to the py39 v1 container image."""
    expected_uri = "123456.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark-processing:3.3-cpu-py39-v1"
    mock_retrieve.return_value = expected_uri
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.3.2"):
        with patch.object(sys, "version_info", (3, 9, 0)):
            uri = _get_spark_image_uri(sagemaker_session)
    mock_retrieve.assert_called_once_with(
        framework="spark",
        region="us-west-2",
        version="3.3",
        py_version="py39",
        container_version="v1",
    )
    assert uri == expected_uri
39+
40+
41+
@patch("sagemaker.mlops.feature_store.feature_processor._image_resolver.image_uris.retrieve")
def test_spark_35_py39(mock_retrieve):
    """Spark 3.5 with Python 3.9 resolves to the py39 v1 container image."""
    expected_uri = "123456.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark-processing:3.5-cpu-py39-v1"
    mock_retrieve.return_value = expected_uri
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.5.1"):
        with patch.object(sys, "version_info", (3, 9, 0)):
            uri = _get_spark_image_uri(sagemaker_session)
    mock_retrieve.assert_called_once_with(
        framework="spark",
        region="us-west-2",
        version="3.5",
        py_version="py39",
        container_version="v1",
    )
    assert uri == expected_uri
56+
57+
58+
@patch("sagemaker.mlops.feature_store.feature_processor._image_resolver.image_uris.retrieve")
def test_spark_35_py312(mock_retrieve):
    """Spark 3.5 with Python 3.12 resolves to the py312 v1 container image."""
    expected_uri = "123456.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark-processing:3.5-cpu-py312-v1"
    mock_retrieve.return_value = expected_uri
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.5.1"):
        with patch.object(sys, "version_info", (3, 12, 0)):
            uri = _get_spark_image_uri(sagemaker_session)
    mock_retrieve.assert_called_once_with(
        framework="spark",
        region="us-west-2",
        version="3.5",
        py_version="py312",
        container_version="v1",
    )
    assert uri == expected_uri
73+
74+
75+
def test_spark_34_raises():
    """Spark 3.4 has no published container image, so resolution must fail."""
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.4.1"):
        with patch.object(sys, "version_info", (3, 9, 0)):
            with pytest.raises(
                ValueError,
                match="No SageMaker Spark container image available for Spark 3.4",
            ):
                _get_spark_image_uri(sagemaker_session)
81+
82+
83+
def test_spark_35_py310_raises():
    """Python 3.10 is not a supported interpreter for the Spark 3.5 images."""
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.5.1"):
        with patch.object(sys, "version_info", (3, 10, 0)):
            with pytest.raises(
                ValueError, match="SageMaker Spark 3.5 container images support"
            ):
                _get_spark_image_uri(sagemaker_session)
89+
90+
91+
def test_spark_33_py312_raises():
    """Python 3.12 is only available with Spark 3.5, not with Spark 3.3."""
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.3.2"):
        with patch.object(sys, "version_info", (3, 12, 0)):
            with pytest.raises(
                ValueError, match="SageMaker Spark 3.3 container images support"
            ):
                _get_spark_image_uri(sagemaker_session)
97+
98+
99+
def test_unknown_spark_version_raises():
    """A Spark version absent from the support matrix (3.6) must fail."""
    sagemaker_session = Mock(boto_region_name="us-west-2")
    with patch.object(pyspark, "__version__", "3.6.0"):
        with patch.object(sys, "version_info", (3, 9, 0)):
            with pytest.raises(
                ValueError,
                match="No SageMaker Spark container image available for Spark 3.6",
            ):
                _get_spark_image_uri(sagemaker_session)

0 commit comments

Comments
 (0)