Skip to content

Commit 8624808

Browse files
committed
feat(feature-processor): Add dynamic Hadoop version resolution
Replace hardcoded hadoop-aws:3.3.1 and hadoop-common:3.3.1 Maven coordinates with dynamically resolved versions based on the installed PySpark version. Add SPARK_TO_HADOOP_MAP supporting Spark 3.1-3.5 with fallback to latest known Hadoop version for unknown Spark versions. --- X-AI-Prompt: implement dynamic hadoop version resolution in _spark_factory.py for multi-spark version compatibility X-AI-Tool: kiro
1 parent b31f02a commit 8624808

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
"""Contains factory classes for instantiating Spark objects."""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
from functools import lru_cache
1718
from typing import List, Tuple, Dict
1819

1920
import feature_store_pyspark
21+
import pyspark
2022
import feature_store_pyspark.FeatureStoreManager as fsm
2123
from pyspark.conf import SparkConf
2224
from pyspark.context import SparkContext
@@ -26,6 +28,33 @@
2628

2729
SPARK_APP_NAME = "FeatureProcessor"
2830

31+
logger = logging.getLogger(__name__)

# Maps an installed Spark "major.minor" version to the Hadoop version whose
# hadoop-aws / hadoop-common Maven artifacts are pulled in for that release.
# Covers Spark 3.1 through 3.5; anything else falls back to the default below.
SPARK_TO_HADOOP_MAP = {
    "3.1": "3.2.0",
    "3.2": "3.3.1",
    "3.3": "3.3.2",
    "3.4": "3.3.4",
    "3.5": "3.3.4",
}

# Fallback for Spark versions not present in SPARK_TO_HADOOP_MAP — the latest
# Hadoop version listed above (see _get_hadoop_version, which logs a warning
# when this fallback is taken).
_DEFAULT_HADOOP_VERSION = "3.3.4"
42+
43+
44+
def _get_hadoop_version():
    """Resolve the Hadoop version for the installed PySpark version.

    Looks up the installed Spark ``major.minor`` in ``SPARK_TO_HADOOP_MAP``.
    Unknown Spark versions fall back to ``_DEFAULT_HADOOP_VERSION`` and emit
    a warning so the mismatch is visible in logs.

    Returns:
        str: The Hadoop version to use for Maven package coordinates.
    """
    spark_version = pyspark.__version__
    # Only the first two version components select the Hadoop build.
    major_minor = ".".join(spark_version.split(".")[:2])
    try:
        return SPARK_TO_HADOOP_MAP[major_minor]
    except KeyError:
        logger.warning(
            "Unknown Spark version %s. Falling back to Hadoop %s.",
            spark_version,
            _DEFAULT_HADOOP_VERSION,
        )
        return _DEFAULT_HADOOP_VERSION
57+
2958

3059
class SparkSessionFactory:
3160
"""Lazy loading, memoizing, instantiation of SparkSessions.
@@ -116,9 +145,10 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]:
116145

117146
if not is_training_job:
118147
fp_spark_jars = feature_store_pyspark.classpath_jars()
148+
hadoop_version = _get_hadoop_version()
119149
fp_spark_packages = [
120-
"org.apache.hadoop:hadoop-aws:3.3.1",
121-
"org.apache.hadoop:hadoop-common:3.3.1",
150+
f"org.apache.hadoop:hadoop-aws:{hadoop_version}",
151+
f"org.apache.hadoop:hadoop-common:{hadoop_version}",
122152
]
123153

124154
if self.spark_config and "spark.jars" in self.spark_config:

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import feature_store_pyspark
17+
import pyspark
1718
import pytest
1819
from mock import Mock, patch, call
1920

@@ -72,10 +73,12 @@ def test_spark_session_factory_configuration():
7273

7374
# Verify configurations when not running on a training job
7475
assert ",".join(feature_store_pyspark.classpath_jars()) in spark_configs.get("spark.jars")
76+
from sagemaker.mlops.feature_store.feature_processor._spark_factory import _get_hadoop_version
77+
hadoop_version = _get_hadoop_version()
7578
assert ",".join(
7679
[
77-
"org.apache.hadoop:hadoop-aws:3.3.1",
78-
"org.apache.hadoop:hadoop-common:3.3.1",
80+
f"org.apache.hadoop:hadoop-aws:{hadoop_version}",
81+
f"org.apache.hadoop:hadoop-common:{hadoop_version}",
7982
]
8083
) in spark_configs.get("spark.jars.packages")
8184

@@ -173,3 +176,34 @@ def test_spark_session_factory_get_spark_session_with_iceberg_config(env_helper)
173176
== "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO"
174177
)
175178
assert iceberg_configs.get("spark.sql.catalog.catalog.glue.skip-name-validation") == "true"
179+
180+
181+
@pytest.mark.parametrize(
    "spark_version,expected_hadoop",
    [
        ("3.1.3", "3.2.0"),
        ("3.2.2", "3.3.1"),
        ("3.3.2", "3.3.2"),
        ("3.4.1", "3.3.4"),
        ("3.5.1", "3.3.4"),
    ],
)
def test_get_hadoop_version(spark_version, expected_hadoop):
    """Each supported Spark major.minor resolves to its mapped Hadoop version."""
    from sagemaker.mlops.feature_store.feature_processor._spark_factory import _get_hadoop_version

    # Patch the pyspark module's version so resolution sees the parametrized
    # value rather than whatever PySpark is locally installed.
    with patch.object(pyspark, "__version__", spark_version):
        resolved = _get_hadoop_version()

    assert resolved == expected_hadoop
195+
196+
197+
def test_get_hadoop_version_unknown_falls_back():
    """A Spark version absent from the map falls back to the default Hadoop version."""
    from sagemaker.mlops.feature_store.feature_processor._spark_factory import (
        _get_hadoop_version,
    )

    # 3.6 has no entry in SPARK_TO_HADOOP_MAP, so the known-latest default applies.
    with patch.object(pyspark, "__version__", "3.6.0"):
        resolved = _get_hadoop_version()

    assert resolved == "3.3.4"
201+
202+
203+
def test_spark_configs_use_dynamic_hadoop_version():
    """Non-training-job Spark configs pin hadoop-aws/-common to the resolved version."""
    with patch.object(pyspark, "__version__", "3.5.1"):
        factory = SparkSessionFactory(Mock())
        configs = dict(factory._get_spark_configs(is_training_job=False))

    packages = configs.get("spark.jars.packages")
    # Spark 3.5 maps to Hadoop 3.3.4; both artifacts must carry that version.
    for artifact in ("hadoop-aws", "hadoop-common"):
        assert f"org.apache.hadoop:{artifact}:3.3.4" in packages

0 commit comments

Comments
 (0)