Skip to content

Commit 85c325b

Browse files
authored
Merge pull request #259 from m-aciek/spark-4-support
Add Spark 4.0 support via deequ:2.0.14-spark-4.0
2 parents c498e7d + 3e4e9fc commit 85c325b

7 files changed

Lines changed: 71 additions & 10 deletions

File tree

.github/workflows/base.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ jobs:
2525
- PYSPARK_VERSION: "3.5"
2626
PYTHON_VERSION: "3.9"
2727
JAVA_VERSION: "17"
28+
- PYSPARK_VERSION: "4.0"
29+
PYTHON_VERSION: "3.9"
30+
JAVA_VERSION: "17"
31+
PANDAS_VERSION_SPEC: ">=2.0.0"
2832

2933
steps:
3034
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -41,9 +45,11 @@ jobs:
4145
- name: Running tests with pyspark==${{matrix.PYSPARK_VERSION}}
4246
env:
4347
SPARK_VERSION: ${{matrix.PYSPARK_VERSION}}
48+
PANDAS_VERSION_SPEC: ${{matrix.PANDAS_VERSION_SPEC}}
4449
run: |
4550
pip install --upgrade pip
4651
pip install poetry==1.7.1
4752
poetry install
4853
poetry run pip install pyspark==$SPARK_VERSION
54+
if [ -n "$PANDAS_VERSION_SPEC" ]; then poetry run pip install "pandas$PANDAS_VERSION_SPEC"; fi
4955
poetry run python -m pytest -s tests --ignore=tests/test_bot.py

pydeequ/analyzers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pydeequ.pandas_utils import ensure_pyspark_df
1010
from pydeequ.repository import MetricsRepository, ResultKey
1111
from enum import Enum
12-
from pydeequ.scala_utils import to_scala_seq
12+
from pydeequ.scala_utils import empty_scala_seq, to_scala_seq
1313
from pydeequ.configs import SPARK_VERSION
1414

1515
class _AnalyzerObject:
@@ -311,7 +311,7 @@ def _analyzer_jvm(self):
311311
self.instance,
312312
self.predicate,
313313
self._jvm.scala.Option.apply(self.where),
314-
self._jvm.scala.collection.Seq.empty(),
314+
empty_scala_seq(self._jvm),
315315
self._jvm.scala.Option.apply(None)
316316
)
317317

pydeequ/checks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pyspark.sql import SparkSession
66

77
from pydeequ.check_functions import is_one
8-
from pydeequ.scala_utils import ScalaFunction1, to_scala_seq
8+
from pydeequ.scala_utils import ScalaFunction1, empty_scala_seq, to_scala_seq
99
from pydeequ.configs import SPARK_VERSION
1010

1111
# TODO implement custom assertions
@@ -563,7 +563,7 @@ def satisfies(self, columnCondition, constraintName, assertion=None, hint=None):
563563
constraintName,
564564
assertion_func,
565565
hint,
566-
self._jvm.scala.collection.Seq.empty(),
566+
empty_scala_seq(self._jvm),
567567
self._jvm.scala.Option.apply(None)
568568
)
569569
return self

pydeequ/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
SPARK_TO_DEEQU_COORD_MAPPING = {
8+
"4.0": "com.amazon.deequ:deequ:2.0.14-spark-4.0",
89
"3.5": "com.amazon.deequ:deequ:2.0.8-spark-3.5",
910
"3.3": "com.amazon.deequ:deequ:2.0.8-spark-3.3",
1011
"3.2": "com.amazon.deequ:deequ:2.0.8-spark-3.2",

pydeequ/profiles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def _columnProfilesFromColumnRunBuilderRun(self, run):
254254
:return: a setter for columnProfilerRunner result
255255
"""
256256
self._run_result = run
257-
profile_map = self._jvm.scala.collection.JavaConversions.mapAsJavaMap(run.profiles()) # TODO from ScalaUtils
257+
profile_map = scala_map_to_java_map(self._jvm, run.profiles())
258258
self._profiles = {column: self._columnProfileBuilder(column, profile_map[column]) for column in profile_map}
259259
return self
260260

pydeequ/scala_utils.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,30 @@ def get_or_else_none(scala_option):
6868
return scala_option.get()
6969

7070

71+
# Cache per JVM instance so version detection only happens once per session.
# Maps id(jvm) -> (jvm, style, converters). The jvm object itself is stored so
# that a lookup can verify identity: id() values may be reused by CPython once
# the original object has been garbage-collected, and a bare id key could then
# hand back converters that belong to a different (dead) JVM view.
_jvm_converters_cache: dict = {}


def _get_converters(jvm):
    """
    Return (style, converters) for the running Scala version.

    style='jdk'    -> scala.jdk.javaapi.CollectionConverters (Scala 2.13, Spark 4+)
    style='legacy' -> scala.collection.JavaConverters (Scala 2.12, Spark 3.x)

    Args:
        jvm: JVM view object (e.g. spark._jvm) used to resolve the Scala classes

    Returns:
        Tuple of (style, converters); the result is cached per jvm object
    """
    key = id(jvm)
    cached = _jvm_converters_cache.get(key)
    # Only trust the entry when it was recorded for this exact object.
    if cached is not None and cached[0] is jvm:
        return cached[1], cached[2]

    try:
        converters = jvm.scala.jdk.javaapi.CollectionConverters
        # On Scala 2.12, the path resolves to a JavaPackage placeholder (no class
        # exists), so attribute access succeeds but any method call raises TypeError.
        # Probe with an actual call to confirm the class is genuinely usable.
        converters.asScala(jvm.java.util.ArrayList())
        resolved = ("jdk", converters)
    except Exception:
        # Deliberately broad: any failure of the probe means the 2.13 API is
        # unusable here, so fall back to the Scala 2.12 converters.
        resolved = ("legacy", jvm.scala.collection.JavaConverters)

    _jvm_converters_cache[key] = (jvm, *resolved)
    return resolved
93+
94+
7195
def to_scala_seq(jvm, iterable):
7296
"""
7397
Helper method to take an iterable and turn it into a Scala sequence
@@ -77,7 +101,23 @@ def to_scala_seq(jvm, iterable):
77101
Returns:
78102
Scala sequence
79103
"""
80-
return jvm.scala.collection.JavaConversions.iterableAsScalaIterable(iterable).toSeq()
104+
style, converters = _get_converters(jvm)
105+
if style == "jdk":
106+
return converters.asScala(iterable).toSeq()
107+
return converters.iterableAsScalaIterableConverter(iterable).asScala().toSeq()
108+
109+
110+
def empty_scala_seq(jvm):
    """
    Build an empty Scala immutable List (Nil) that can be passed as a Seq[_].

    An empty java.util.ArrayList is converted to its Scala counterpart and then
    materialised with toList(). The aim is an immutable.List rather than a
    Stream, which Py4J constructor/method lookup needs on both Scala 2.12
    (Spark 3.x) and Scala 2.13 (Spark 4+).

    Args:
        jvm: JVM view object used to create the list and resolve the converters

    Returns:
        Empty Scala immutable List
    """
    style, converters = _get_converters(jvm)
    java_list = jvm.java.util.ArrayList()
    if style == "jdk":
        scala_view = converters.asScala(java_list)
    else:
        scala_view = converters.iterableAsScalaIterableConverter(java_list).asScala()
    return scala_view.toList()
81121

82122

83123
def to_scala_map(spark_session, d):
@@ -89,15 +129,29 @@ def to_scala_map(spark_session, d):
89129
Returns:
90130
Scala map
91131
"""
92-
return spark_session._jvm.PythonUtils.toScalaMap(d)
132+
jvm = spark_session._jvm
133+
try:
134+
# PythonUtils.toScalaMap is a PySpark internal that may be removed in future versions.
135+
return jvm.PythonUtils.toScalaMap(d)
136+
except Exception:
137+
style, converters = _get_converters(jvm)
138+
if style == "jdk":
139+
return converters.asScala(d).toMap()
140+
return converters.mapAsScalaMapConverter(d).asScala().toMap()
93141

94142

95143
def scala_map_to_dict(jvm, scala_map):
    """
    Convert a Scala map into a Python dict.

    The Scala map is first exposed as a Java map (using whichever converter
    API matches the running Scala version) and then copied into a dict.

    Args:
        jvm: JVM view object used to resolve the converters
        scala_map: Scala map to convert

    Returns:
        Python dict built from the converted map
    """
    java_view = scala_map_to_java_map(jvm, scala_map)
    return dict(java_view)
97148

98149

99150
def scala_map_to_java_map(jvm, scala_map):
    """
    Convert a Scala map into a Java map.

    Args:
        jvm: JVM view object used to resolve the converters
        scala_map: Scala map to convert

    Returns:
        Java map produced by the converter API for the running Scala version
    """
    style, converters = _get_converters(jvm)
    if style != "jdk":
        # Scala 2.12 path: wrap with the decorator, then unwrap via asJava().
        return converters.mapAsJavaMapConverter(scala_map).asJava()
    return converters.asJava(scala_map)
101155

102156

103157
def java_list_to_python_list(java_list: str, datatype):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ classifiers = [
3131
python = ">=3.8,<4"
3232
numpy = ">=1.14.1"
3333
pandas = ">=0.23.0"
34-
pyspark = { version = ">=2.4.7,<4.0.0", optional = true }
34+
pyspark = { version = ">=2.4.7,<5.0.0", optional = true }
3535

3636
[tool.poetry.dev-dependencies]
3737
pytest = "^6.2.4"

0 commit comments

Comments
 (0)