Skip to content

Commit 052d4ab

Browse files
committed
test(feature-processor): Mock classpath_jars in spark session factory tests
- Add @patch decorator to mock feature_store_pyspark.classpath_jars in test_spark_session_factory_configuration - Add @patch decorator to mock feature_store_pyspark.classpath_jars in test_spark_session_factory_configuration_on_training_job - Add @patch decorator to mock feature_store_pyspark.classpath_jars in test_spark_session_factory - Add @patch decorator to mock feature_store_pyspark.classpath_jars in test_spark_session_factory_with_iceberg_config - Add @patch decorator to mock feature_store_pyspark.classpath_jars in test_spark_session_factory_same_instance - Add @patch decorator to mock feature_store_pyspark.classpath_jars in test_spark_configs_use_dynamic_hadoop_version - Replace direct calls to feature_store_pyspark.classpath_jars() with mock_classpath_jars.return_value - Update test_repack_model.py to use resolved path variable for consistency in _get_safe_members test
1 parent caa8650 commit 052d4ab

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def env_helper():
3232
)
3333

3434

35-
def test_spark_session_factory_configuration():
35+
@patch("feature_store_pyspark.classpath_jars", return_value=["/path/to/jar.jar"])
36+
def test_spark_session_factory_configuration(mock_classpath_jars):
3637
env_helper = Mock()
3738
spark_config = {"spark.test.key": "spark.test.value"}
3839
spark_session_factory = SparkSessionFactory(env_helper, spark_config)
@@ -72,7 +73,7 @@ def test_spark_session_factory_configuration():
7273
assert jsc_hadoop_configs.get("mapreduce.fileoutputcommitter.marksuccessfuljobs") == "false"
7374

7475
# Verify configurations when not running on a training job
75-
assert ",".join(feature_store_pyspark.classpath_jars()) in spark_configs.get("spark.jars")
76+
assert ",".join(mock_classpath_jars.return_value) in spark_configs.get("spark.jars")
7677
from sagemaker.mlops.feature_store.feature_processor._spark_factory import _get_hadoop_version
7778
hadoop_version = _get_hadoop_version()
7879
assert ",".join(
@@ -83,7 +84,8 @@ def test_spark_session_factory_configuration():
8384
) in spark_configs.get("spark.jars.packages")
8485

8586

86-
def test_spark_session_factory_configuration_on_training_job():
87+
@patch("feature_store_pyspark.classpath_jars", return_value=["/path/to/jar.jar"])
88+
def test_spark_session_factory_configuration_on_training_job(mock_classpath_jars):
8789
env_helper = Mock()
8890
spark_config = {"spark.test.key": "spark.test.value"}
8991
spark_session_factory = SparkSessionFactory(env_helper, spark_config)
@@ -94,11 +96,12 @@ def test_spark_session_factory_configuration_on_training_job():
9496
assert all(tup[0] != "spark.jars.packages" for tup in spark_config)
9597

9698
# spark.jars should always be present (Feature Store JARs are always on the classpath)
97-
assert ",".join(feature_store_pyspark.classpath_jars()) in dict(spark_config).get("spark.jars")
99+
assert ",".join(mock_classpath_jars.return_value) in dict(spark_config).get("spark.jars")
98100

99101

102+
@patch("feature_store_pyspark.classpath_jars", return_value=["/path/to/jar.jar"])
100103
@patch("pyspark.context.SparkContext.getOrCreate")
101-
def test_spark_session_factory(mock_spark_context):
104+
def test_spark_session_factory(mock_spark_context, mock_classpath_jars):
102105
env_helper = Mock()
103106
env_helper.get_instance_count.return_value = 1
104107
spark_session_factory = SparkSessionFactory(env_helper)
@@ -114,8 +117,9 @@ def test_spark_session_factory(mock_spark_context):
114117
assert spark_conf.get(cfg[0]) == cfg[1]
115118

116119

120+
@patch("feature_store_pyspark.classpath_jars", return_value=["/path/to/jar.jar"])
117121
@patch("pyspark.context.SparkContext.getOrCreate")
118-
def test_spark_session_factory_with_iceberg_config(mock_spark_context):
122+
def test_spark_session_factory_with_iceberg_config(mock_spark_context, mock_classpath_jars):
119123
mock_env_helper = Mock()
120124
mock_spark_context.side_effect = [Mock(), Mock()]
121125

@@ -138,8 +142,9 @@ def test_spark_session_factory_with_iceberg_config(mock_spark_context):
138142
mock_conf.assert_has_calls(expected_calls, any_order=False)
139143

140144

145+
@patch("feature_store_pyspark.classpath_jars", return_value=["/path/to/jar.jar"])
141146
@patch("pyspark.context.SparkContext.getOrCreate")
142-
def test_spark_session_factory_same_instance(mock_spark_context):
147+
def test_spark_session_factory_same_instance(mock_spark_context, mock_classpath_jars):
143148
mock_env_helper = Mock()
144149
mock_spark_context.side_effect = [Mock(), Mock()]
145150

@@ -202,7 +207,8 @@ def test_get_hadoop_version_unknown_falls_back():
202207
assert _get_hadoop_version() == "3.3.4"
203208

204209

205-
def test_spark_configs_use_dynamic_hadoop_version():
210+
@patch("feature_store_pyspark.classpath_jars", return_value=["/path/to/jar.jar"])
211+
def test_spark_configs_use_dynamic_hadoop_version(mock_classpath_jars):
206212
with patch.object(pyspark, "__version__", "3.5.1"):
207213
env_helper = Mock()
208214
factory = SparkSessionFactory(env_helper)

sagemaker-mlops/tests/unit/workflow/test_repack_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_is_bad_link_unsafe():
9292

9393
def test_get_safe_members_all_safe():
9494
"""Test _get_safe_members yields all safe members."""
95-
base = _get_resolved_path("")
95+
base = _get_resolved_path("/tmp/extract")
9696

9797
mock_member1 = Mock()
9898
mock_member1.name = "safe/file1.txt"
@@ -105,7 +105,7 @@ def test_get_safe_members_all_safe():
105105
mock_member2.islnk = Mock(return_value=False)
106106

107107
members = [mock_member1, mock_member2]
108-
safe_members = list(_get_safe_members(members, "/tmp/extract"))
108+
safe_members = list(_get_safe_members(members, base))
109109

110110
assert len(safe_members) == 2
111111
assert mock_member1 in safe_members

0 commit comments

Comments
 (0)