Skip to content

Commit 31cd9c7

Browse files
committed
[SPARK-56373][PYSPARK] Add docstring annotations to classify PySpark APIs for Spark Connect compatibility
Adds three RST directives to PySpark modules, classes, and methods to indicate Spark Connect compatibility status: - `.. classic:: true` -- API is only available in Classic Spark (not Spark Connect) - `.. connect:: true` -- API is available in Spark Connect - `.. connect_migration:: <message>` -- migration guidance for users transitioning to Spark Connect Annotations are resolved by inheriting from the nearest annotated ancestor; a child annotation overrides the parent's. No functional code changes -- docstrings only. The annotation spec is documented in `python/pyspark/__init__.py`.
1 parent 98cdaee commit 31cd9c7

File tree

27 files changed

+176
-3
lines changed

27 files changed

+176
-3
lines changed

python/pyspark/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@
4444
Information about a barrier task.
4545
- :class:`InheritableThread`:
4646
An inheritable thread to use in Spark when the pinned thread mode is on.
47+
48+
Spark Connect compatibility annotations
49+
=======================================
50+
51+
The following RST directives annotate PySpark modules, classes, and methods with
52+
their Spark Connect compatibility status:
53+
54+
- ``.. classic:: true`` -- the API is only available in Classic Spark (not Spark Connect).
55+
- ``.. connect:: true`` -- the API is available in Spark Connect.
56+
- ``.. connect_migration:: <message>`` -- migration guidance for users transitioning
57+
from Classic Spark to Spark Connect.
58+
59+
Annotations are resolved by inheriting from the nearest annotated ancestor. A child
60+
annotation overrides the parent's.
61+
62+
.. classic:: true
63+
.. connect:: true
4764
"""
4865

4966
import sys

python/pyspark/accumulators.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515
# limitations under the License.
1616
#
1717

18+
"""
19+
.. classic:: true
20+
21+
.. connect_migration:: Use `df.observe(name, *exprs)` to collect named metrics during
22+
query execution.
23+
"""
24+
1825
import os
1926
import sys
2027
import select

python/pyspark/conf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515
# limitations under the License.
1616
#
1717

18+
"""
19+
.. classic:: true
20+
21+
.. connect_migration:: Read Spark SQL configuration values using `spark.conf.get(key)`
22+
and write them using `spark.conf.set(key, value)`.
23+
"""
24+
1825
__all__ = ["SparkConf"]
1926

2027
import sys

python/pyspark/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
18+
"""
19+
.. classic:: true
20+
"""

python/pyspark/core/context.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,8 @@ def setLogLevel(self, logLevel: str) -> None:
558558
Examples
559559
--------
560560
>>> sc.setLogLevel("WARN") # doctest: +SKIP
561+
562+
.. connect_migration:: Replace sc.setLogLevel(level) with spark.log.level(level)
561563
"""
562564
self._jsc.setLogLevel(logLevel)
563565

@@ -630,6 +632,8 @@ def applicationId(self) -> str:
630632
--------
631633
>>> sc.applicationId # doctest: +ELLIPSIS
632634
'local-...'
635+
636+
.. connect_migration:: Replace spark.sparkContext.applicationId with spark.conf.get("spark.app.id")
633637
"""
634638
return self._jsc.sc().applicationId()
635639

@@ -675,6 +679,9 @@ def defaultParallelism(self) -> int:
675679
--------
676680
>>> sc.defaultParallelism > 0
677681
True
682+
683+
.. connect_migration:: Replace spark.sparkContext.defaultParallelism with
684+
int(spark.conf.get("spark.default.parallelism", "200"))
678685
"""
679686
return self._jsc.sc().defaultParallelism()
680687

@@ -734,6 +741,9 @@ def emptyRDD(self) -> RDD[Any]:
734741
EmptyRDD...
735742
>>> sc.emptyRDD().count()
736743
0
744+
745+
.. connect_migration:: Replace sc.emptyRDD with an empty list. When used with
746+
createDataFrame: spark.createDataFrame([], schema)
737747
"""
738748
return RDD(self._jsc.emptyRDD(), self, NoOpSerializer())
739749

@@ -828,6 +838,9 @@ def parallelize(self, c: Iterable[T], numSlices: Optional[int] = None) -> RDD[T]
828838
>>> strings = ["a", "b", "c"]
829839
>>> sc.parallelize(strings, 2).glom().collect()
830840
[['a'], ['b', 'c']]
841+
842+
.. connect_migration:: Replace sc.parallelize(data) with the Python collection directly.
843+
When used with createDataFrame: spark.createDataFrame(data, schema)
831844
"""
832845
numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
833846
if isinstance(c, range):
@@ -2212,6 +2225,10 @@ def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool =
22122225
>>> suppress = lock.acquire()
22132226
>>> print(result)
22142227
Cancelled
2228+
2229+
.. connect_migration:: Replace sc.setJobGroup(groupId, desc) with
2230+
spark.conf.set("spark.job.group.id", groupId) and
2231+
spark.conf.set("spark.job.description", desc)
22152232
"""
22162233
self._jsc.setJobGroup(groupId, description, interruptOnCancel)
22172234

@@ -2410,6 +2427,9 @@ def setLocalProperty(self, key: str, value: str) -> None:
24102427
-----
24112428
If you run jobs in parallel, use :class:`pyspark.InheritableThread` for thread
24122429
local inheritance.
2430+
2431+
.. connect_migration:: Replace spark.sparkContext.setLocalProperty(key, value) with
2432+
spark.conf.set(key, value)
24132433
"""
24142434
self._jsc.setLocalProperty(key, value)
24152435

@@ -2441,6 +2461,9 @@ def setJobDescription(self, value: str) -> None:
24412461
-----
24422462
If you run jobs in parallel, use :class:`pyspark.InheritableThread` for thread
24432463
local inheritance.
2464+
2465+
.. connect_migration:: Replace sc.setJobDescription(desc) with
2466+
spark.conf.set("spark.job.description", desc)
24442467
"""
24452468
self._jsc.setJobDescription(value)
24462469

@@ -2610,6 +2633,9 @@ def getConf(self) -> SparkConf:
26102633
"""Return a copy of this SparkContext's configuration :class:`SparkConf`.
26112634
26122635
.. versionadded:: 2.1.0
2636+
2637+
.. connect_migration:: Replace sc.getConf() with spark.conf. For a specific key use
2638+
spark.conf.get(key)
26132639
"""
26142640
conf = SparkConf()
26152641
conf.setAll(self._conf.getAll())

python/pyspark/core/rdd.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,9 @@ def map(self: "RDD[T]", f: Callable[[T], U], preservesPartitioning: bool = False
609609
>>> rdd = sc.parallelize(["b", "a", "c"])
610610
>>> sorted(rdd.map(lambda x: (x, 1)).collect())
611611
[('a', 1), ('b', 1), ('c', 1)]
612+
613+
.. connect_migration:: Replace rdd.map() with DataFrame operations. Use
614+
df.withColumn(), df.select() with a UDF, or a pandas UDF instead.
612615
"""
613616

614617
def func(_: int, iterator: Iterable[T]) -> Iterable[U]:
@@ -697,6 +700,9 @@ def mapPartitions(
697700
...
698701
>>> rdd.mapPartitions(f).collect()
699702
[3, 7]
703+
704+
.. connect_migration:: Replace rdd.mapPartitions() with a pandas UDF using
705+
applyInPandas.
700706
"""
701707

702708
def func(_: int, iterator: Iterable[T]) -> Iterable[U]:

python/pyspark/errors/exceptions/connect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
18+
"""
19+
.. connect:: true
20+
"""
21+
1722
import grpc
1823
import json
1924
from grpc import StatusCode

python/pyspark/java_gateway.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
# limitations under the License.
1616
#
1717

18+
"""
19+
.. classic:: true
20+
"""
21+
1822
import atexit
1923
import os
2024
import signal

python/pyspark/ml/clustering.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,10 @@ class LDAModel(JavaModel, _LDAParams):
14291429
including local and distributed data structures.
14301430
14311431
.. versionadded:: 2.0.0
1432+
1433+
.. classic:: true
1434+
1435+
.. connect_migration:: LDA model family is not supported in Spark Connect.
14321436
"""
14331437

14341438
@since("3.0.0")
@@ -1530,6 +1534,10 @@ class DistributedLDAModel(LDAModel, JavaMLReadable["DistributedLDAModel"], JavaM
15301534
for each training document.
15311535
15321536
.. versionadded:: 2.0.0
1537+
1538+
.. classic:: true
1539+
1540+
.. connect_migration:: LDA model family is not supported in Spark Connect.
15331541
"""
15341542

15351543
@functools.cache
@@ -1608,6 +1616,10 @@ class LocalLDAModel(LDAModel, JavaMLReadable["LocalLDAModel"], JavaMLWritable):
16081616
This model stores the inferred topics only; it does not store info about the training dataset.
16091617
16101618
.. versionadded:: 2.0.0
1619+
1620+
.. classic:: true
1621+
1622+
.. connect_migration:: LDA model family is not supported in Spark Connect.
16111623
"""
16121624

16131625
pass
@@ -1682,6 +1694,10 @@ class LDA(JavaEstimator[LDAModel], _LDAParams, JavaMLReadable["LDA"], JavaMLWrit
16821694
>>> sameLocalModel = LocalLDAModel.load(local_model_path)
16831695
>>> model.transform(df).take(1) == sameLocalModel.transform(df).take(1)
16841696
True
1697+
1698+
.. classic:: true
1699+
1700+
.. connect_migration:: LDA model family is not supported in Spark Connect.
16851701
"""
16861702

16871703
_input_kwargs: Dict[str, Any]

python/pyspark/ml/connect/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
# limitations under the License.
1616
#
1717

18-
"""Spark Connect Python Client - ML module"""
18+
"""
19+
Spark Connect Python Client - ML module
20+
21+
.. connect:: true
22+
"""
1923

2024
from pyspark.sql.connect.utils import check_dependencies
2125

0 commit comments

Comments
 (0)