Skip to content

Commit beb078d

Browse files
committed
fix spark session bug
1 parent 5ca86dc commit beb078d

2 files changed

Lines changed: 12 additions & 11 deletions

File tree

sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def get_spark_session_with_iceberg_config(self, warehouse_s3_uri, catalog) -> Sp
180180
SparkSession: A SparkSession ready to support reading and writing data from an Iceberg
181181
Table.
182182
"""
183-
conf = self.spark_session._jvm.SparkSession().conf()
183+
conf = self.spark_session.conf
184184

185185
for cfg in self._get_iceberg_configs(warehouse_s3_uri, catalog):
186186
conf.set(cfg[0], cfg[1])

sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,20 @@ def test_spark_session_factory_with_iceberg_config(mock_spark_context):
117117
spark_session_factory = SparkSessionFactory(mock_env_helper)
118118

119119
spark_session = spark_session_factory.spark_session
120+
mock_conf = Mock()
120121

121-
spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config(
122-
"warehouse", "catalog"
123-
)
122+
with patch.object(type(spark_session), "conf", new_callable=lambda: property(lambda self: mock_conf)):
123+
spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config(
124+
"warehouse", "catalog"
125+
)
124126

125-
assert spark_session is spark_session_with_iceberg_config
126-
mock_spark_conf = spark_session._jvm.SparkSession().conf()
127-
expected_calls = [
128-
call.set(cfg[0], cfg[1])
129-
for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog")
130-
]
127+
assert spark_session is spark_session_with_iceberg_config
128+
expected_calls = [
129+
call.set(cfg[0], cfg[1])
130+
for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog")
131+
]
131132

132-
mock_spark_conf.assert_has_calls(expected_calls, any_order=False)
133+
mock_conf.assert_has_calls(expected_calls, any_order=False)
133134

134135

135136
@patch("pyspark.context.SparkContext.getOrCreate")

0 commit comments

Comments
 (0)