|
32 | 32 | from py4j.protocol import Py4JJavaError |
33 | 33 |
|
34 | 34 | from pyspark import SparkConf, SparkContext |
35 | | -from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, eventually |
| 35 | +from pyspark.testing.utils import ( |
| 36 | + ReusedPySparkTestCase, |
| 37 | + PySparkTestCase, |
| 38 | + QuietTest, |
| 39 | + eventually, |
| 40 | +) |
36 | 41 |
|
37 | 42 |
|
38 | 43 | class WorkerTests(ReusedPySparkTestCase): |
@@ -272,6 +277,65 @@ def test_worker_crash(self): |
272 | 277 | rdd.map(lambda x: os.getpid()).collect() |
273 | 278 |
|
274 | 279 |
|
class SimpleWorkerTests(WorkerTests):
    """Re-run the worker test suite through the non-daemon (simple-worker) path.

    Windows always takes this code path; Linux/macOS take it only when
    spark.python.use.daemon=false, which is what this subclass forces.
    """

    @classmethod
    def conf(cls):
        # Disable the Python daemon so each executor launches plain
        # worker processes instead of forking from a daemon.
        config = super().conf()
        config.set("spark.python.use.daemon", "false")
        return config

    def test_create_dataframe(self):
        """DataFrame creation through the simple-worker path."""
        from pyspark.sql import SparkSession

        session = SparkSession(self.sc)
        frame = session.createDataFrame([("Alice", 30), ("Bob", 25)], ["name", "age"])
        self.assertEqual(2, frame.count())

    def test_udf(self):
        """UDF execution through the simple-worker path."""
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import udf
        from pyspark.sql.types import StringType

        session = SparkSession(self.sc)
        tag = udf(lambda v: f"val_{v}", StringType())
        collected = session.range(5).withColumn("x", tag("id")).collect()
        self.assertEqual(5, len(collected))

    def test_datasource_read(self):
        """Python Data Source read through the simple-worker path."""
        from pyspark.sql import SparkSession
        from pyspark.sql.datasource import DataSource, DataSourceReader

        class _Reader(DataSourceReader):
            # Emits two fixed rows regardless of the partition argument.
            def read(self, partition):
                yield (0, "a")
                yield (1, "b")

        class _Source(DataSource):
            @classmethod
            def name(cls):
                return "test_simple_worker"

            def schema(self):
                return "id INT, value STRING"

            def reader(self, schema):
                return _Reader()

        session = SparkSession(self.sc)
        session.dataSource.register(_Source)
        frame = session.read.format("test_simple_worker").load()
        self.assertEqual(2, frame.count())
| 337 | + |
| 338 | + |
275 | 339 | if __name__ == "__main__": |
276 | 340 | from pyspark.testing import main |
277 | 341 |
|
|
0 commit comments