diff --git a/python/pyspark/core/rdd.py b/python/pyspark/core/rdd.py
index 803bb6d5b882b..573cc7ae165c5 100644
--- a/python/pyspark/core/rdd.py
+++ b/python/pyspark/core/rdd.py
@@ -4843,7 +4843,10 @@ def sumApprox(
         jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
         assert self.ctx._jvm is not None
         jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
-        r = jdrdd.sumApprox(timeout, confidence).getFinalValue()
+        # initialValue() is the estimate available once `timeout` has elapsed;
+        # getFinalValue() would block until the whole job finishes.
+        partial_result = jdrdd.sumApprox(timeout, confidence)
+        r = partial_result.initialValue()
         return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
 
     def meanApprox(
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 4c43d6181cc36..507d13c74d815 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -638,6 +638,31 @@ def test_distinct(self):
         self.assertEqual(result.getNumPartitions(), 5)
         self.assertEqual(result.count(), 3)
 
+    def test_count_approx_respects_timeout(self):
+        # Every partition sleeps far longer than the timeout, so a prompt return
+        # shows countApprox handed back the partial result instead of blocking
+        # until the job finished.
+        def slow_partition(iterator):
+            time.sleep(3)
+            return iterator
+
+        rdd = self.sc.parallelize(range(2), 2).mapPartitions(slow_partition)
+        start = time.monotonic()
+        try:
+            result = rdd.countApprox(timeout=100)
+            elapsed = time.monotonic() - start
+        finally:
+            # Do not leave the sleeping job occupying executors for later tests.
+            self.sc.cancelAllJobs()
+        self.assertLess(elapsed, 3)
+        self.assertIsNotNone(result)
+
+    def test_count_approx_returns_exact_when_completed(self):
+        # With a generous timeout the job completes first, so the partial result
+        # is already the exact count.
+        rdd = self.sc.parallelize(range(1000), 8)
+        self.assertEqual(rdd.countApprox(timeout=5000), 1000)
+
     def test_external_group_by_key(self):
         self.sc._conf.set("spark.python.worker.memory", "1m")
         N = 2000001