Skip to content

Commit 95ef9f7

Browse files
committed
[spark] Add aggregate functions: count_distinct, collect_set, count_if, max_by, min_by, bool_and, bool_or, kurtosis
Adds 8 PySpark aggregate functions (plus 3 aliases: countDistinct, every, some) tracked in duckdb/duckdb#14525:

- count_distinct / countDistinct: array_length(array_distinct(list(x)))
- collect_set: array_distinct(list(x)) (excludes NULL)
- count_if: count_if(x)
- max_by(col, ord): arg_max(arg, val)
- min_by(col, ord): arg_min(arg, val)
- bool_and / every: bool_and(x)
- bool_or / some: bool_or(x)
- kurtosis: kurtosis(x)

Single-column count_distinct only (matching existing approx_count_distinct). Multi-column variant left for a follow-up due to Spark/SQL NULL-handling semantics.
1 parent 5c2a7f7 commit 95ef9f7

2 files changed

Lines changed: 253 additions & 0 deletions

File tree

duckdb/experimental/spark/sql/functions.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import operator
12
import warnings
23
from collections.abc import Callable
4+
from functools import reduce
35
from typing import TYPE_CHECKING, Any, Optional, Union, overload
46

57
from duckdb import (
@@ -6208,6 +6210,164 @@ def expr(str: str) -> Column:
62086210
return Column(SQLExpression(str))
62096211

62106212

6213+
def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
    """Aggregate function: returns the number of distinct rows considering the given columns.

    Rows where any of the supplied columns is NULL are excluded from the count,
    matching Spark / standard SQL `COUNT(DISTINCT col1, col2, ...)` semantics.
    Aggregating over zero input rows yields 0 rather than NULL.

    .. versionadded:: 3.2.0

    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or str
        first column to compute on.
    cols : :class:`~pyspark.sql.Column` or str
        other columns to compute on.

    Returns:
    -------
    :class:`~pyspark.sql.Column`
        the number of distinct, fully non-NULL value combinations.

    Examples:
    --------
    >>> df = spark.createDataFrame([(1,), (1,), (2,), (None,)], ["v"])
    >>> df.select(count_distinct(df.v).alias("d")).collect()
    [Row(d=2)]

    >>> df = spark.createDataFrame(
    ...     [(1, "a"), (1, "a"), (1, "b"), (None, "c"), (2, None)], ["a", "b"]
    ... )
    >>> df.select(count_distinct("a", "b").alias("d")).collect()
    [Row(d=2)]
    """
    exprs = [_to_column_expr(c) for c in (col, *cols)]
    if len(exprs) == 1:
        arg = exprs[0]
    else:
        # Collapse every row that has a NULL in any column into a plain NULL so
        # that it is dropped together with the other NULLs by array_distinct.
        any_null = reduce(operator.or_, (e.isnull() for e in exprs))
        arg = CaseExpression(any_null, ConstantExpression(None)).otherwise(FunctionExpression("struct_pack", *exprs))
    distinct_count = FunctionExpression(
        "array_length",
        FunctionExpression(
            "array_distinct",
            FunctionExpression("list", arg),
        ),
    )
    # list() produces NULL (not an empty list) when it sees no input rows, which
    # would surface as a NULL count; COUNT(DISTINCT ...) must report 0 there.
    return Column(CaseExpression(distinct_count.isnull(), ConstantExpression(0)).otherwise(distinct_count))
6246+
6247+
6248+
def countDistinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column:
    """Camel-case alias of :func:`count_distinct`; forwards every argument unchanged."""
    columns = (col, *cols)
    return count_distinct(*columns)
6251+
6252+
6253+
def collect_set(col: "ColumnOrName") -> Column:
    """Aggregate function: gathers the distinct values of `col` into a single collection.

    NULL values are excluded. The order of elements is non-deterministic.

    .. versionadded:: 1.6.0

    Examples:
    --------
    >>> df = spark.createDataFrame([(1,), (1,), (2,)], ["v"])
    >>> sorted(df.select(collect_set("v")).first()[0])
    [1, 2]
    """
    # Aggregate all values into one list first, then de-duplicate that list.
    collected = FunctionExpression("list", _to_column_expr(col))
    return _invoke_function("array_distinct", collected)
6270+
6271+
6272+
def count_if(col: "ColumnOrName") -> Column:
    """Aggregate function: counts the rows for which the expression is `TRUE`.

    .. versionadded:: 3.5.0

    Examples:
    --------
    >>> df = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
    >>> df.select(count_if(df.v > 1).alias("c")).collect()
    [Row(c=2)]
    """
    aggregate_name = "count_if"
    return _invoke_function_over_columns(aggregate_name, col)
6284+
6285+
6286+
def max_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
    """Aggregate function: returns the `col` value found on the row where `ord` is largest.

    Implemented with the ``arg_max`` aggregate.

    .. versionadded:: 3.3.0

    Examples:
    --------
    >>> df = spark.createDataFrame([("a", 1), ("b", 3), ("c", 2)], ["k", "v"])
    >>> df.select(max_by("k", "v")).first()[0]
    'b'
    """
    value_expr = _to_column_expr(col)
    ordering_expr = _to_column_expr(ord)
    return _invoke_function("arg_max", value_expr, ordering_expr)
6298+
6299+
6300+
def min_by(col: "ColumnOrName", ord: "ColumnOrName") -> Column:
    """Aggregate function: returns the `col` value found on the row where `ord` is smallest.

    Implemented with the ``arg_min`` aggregate.

    .. versionadded:: 3.3.0

    Examples:
    --------
    >>> df = spark.createDataFrame([("a", 1), ("b", 3), ("c", 2)], ["k", "v"])
    >>> df.select(min_by("k", "v")).first()[0]
    'a'
    """
    value_expr = _to_column_expr(col)
    ordering_expr = _to_column_expr(ord)
    return _invoke_function("arg_min", value_expr, ordering_expr)
6312+
6313+
6314+
def bool_and(col: "ColumnOrName") -> Column:
    """Aggregate function: true only when every value of `col` is true.

    .. versionadded:: 3.5.0

    Examples:
    --------
    >>> df = spark.createDataFrame([(True,), (True,), (False,)], ["b"])
    >>> df.select(bool_and("b")).first()[0]
    False
    """
    aggregate_name = "bool_and"
    return _invoke_function_over_columns(aggregate_name, col)
6326+
6327+
6328+
def every(col: "ColumnOrName") -> Column:
    """Aggregate function: alias of :func:`bool_and`."""
    conjunction = bool_and(col)
    return conjunction
6331+
6332+
6333+
def bool_or(col: "ColumnOrName") -> Column:
    """Aggregate function: true as soon as one value of `col` is true.

    .. versionadded:: 3.5.0

    Examples:
    --------
    >>> df = spark.createDataFrame([(True,), (True,), (False,)], ["b"])
    >>> df.select(bool_or("b")).first()[0]
    True
    """
    aggregate_name = "bool_or"
    return _invoke_function_over_columns(aggregate_name, col)
6345+
6346+
6347+
def some(col: "ColumnOrName") -> Column:
    """Aggregate function: alias of :func:`bool_or`."""
    disjunction = bool_or(col)
    return disjunction
6350+
6351+
6352+
def any(col: "ColumnOrName") -> Column:
    """Aggregate function: alias of :func:`bool_or`.

    Note that this module-level definition shadows the builtin ``any`` inside
    this module; code here that needs the builtin must reach it explicitly
    (e.g. via ``builtins.any``).
    """
    disjunction = bool_or(col)
    return disjunction
6355+
6356+
6357+
def kurtosis(col: "ColumnOrName") -> Column:
    """Aggregate function: returns the kurtosis of the values in a group.

    Spark reports the population excess kurtosis, i.e. without any
    sample-size bias correction. DuckDB's plain ``kurtosis`` aggregate applies
    that bias correction, so the Spark-compatible result comes from
    ``kurtosis_pop`` instead.

    .. versionadded:: 1.6.0

    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or str
        target column to compute on.

    Returns:
    -------
    :class:`~pyspark.sql.Column`
        kurtosis of the given column.

    Examples:
    --------
    >>> df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,)], ["v"])
    >>> df.select(kurtosis("v")).first()[0] is not None
    True
    """
    # kurtosis_pop = excess kurtosis without bias correction (Spark's definition).
    return _invoke_function_over_columns("kurtosis_pop", col)
6369+
6370+
62116371
def broadcast(df: "DataFrame") -> "DataFrame":
62126372
"""The broadcast function in Spark is used to optimize joins by broadcasting a smaller
62136373
dataset to all the worker nodes. However, DuckDB operates on a single-node architecture .
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import pytest
2+
3+
_ = pytest.importorskip("duckdb.experimental.spark")
4+
5+
from spark_namespace.sql import functions as F
6+
from spark_namespace.sql.types import Row
7+
8+
9+
class TestSparkAggregateFunctions:
    """Exercise the Spark-style aggregate helpers through ``groupBy("k").agg(...)``."""

    def test_count_distinct(self, spark):
        # The trailing NULL must not be counted as a distinct value.
        rows = [("g", 1), ("g", 1), ("g", 2), ("g", None)]
        frame = spark.createDataFrame(rows, ["k", "v"])
        distinct = F.count_distinct("v").alias("d")
        assert frame.groupBy("k").agg(distinct).collect() == [Row(k="g", d=2)]

    def test_countDistinct_alias(self, spark):
        rows = [("g", 1), ("g", 1), ("g", 2)]
        frame = spark.createDataFrame(rows, ["k", "v"])
        distinct = F.countDistinct("v").alias("d")
        assert frame.groupBy("k").agg(distinct).collect() == [Row(k="g", d=2)]

    def test_count_distinct_multi_col(self, spark):
        # Only (1, "a") and (1, "b") are fully non-NULL combinations.
        rows = [
            ("g", 1, "a"),
            ("g", 1, "a"),
            ("g", 1, "b"),
            ("g", None, "c"),
            ("g", 2, None),
            ("g", None, None),
        ]
        frame = spark.createDataFrame(rows, ["k", "a", "b"])
        distinct = F.count_distinct("a", "b").alias("d")
        assert frame.groupBy("k").agg(distinct).collect() == [Row(k="g", d=2)]

    def test_collect_set(self, spark):
        rows = [("g", 1), ("g", 1), ("g", 2), ("g", None)]
        frame = spark.createDataFrame(rows, ["k", "v"])
        collected = frame.groupBy("k").agg(F.collect_set("v").alias("s")).collect()
        first = collected[0]
        assert first.k == "g"
        # Element order is not guaranteed, so compare the sorted contents.
        assert sorted(first.s) == [1, 2]

    def test_count_if(self, spark):
        rows = [("g", 1), ("g", 2), ("g", 3)]
        frame = spark.createDataFrame(rows, ["k", "v"])
        predicate = F.col("v") > 1
        got = frame.groupBy("k").agg(F.count_if(predicate).alias("c")).collect()
        assert got == [Row(k="g", c=2)]

    def test_max_by(self, spark):
        rows = [("g", "a", 1), ("g", "b", 3), ("g", "c", 2)]
        frame = spark.createDataFrame(rows, ["k", "name", "v"])
        got = frame.groupBy("k").agg(F.max_by("name", "v").alias("m")).collect()
        assert got == [Row(k="g", m="b")]

    def test_min_by(self, spark):
        rows = [("g", "a", 1), ("g", "b", 3), ("g", "c", 2)]
        frame = spark.createDataFrame(rows, ["k", "name", "v"])
        got = frame.groupBy("k").agg(F.min_by("name", "v").alias("m")).collect()
        assert got == [Row(k="g", m="a")]

    def test_bool_and(self, spark):
        # One False makes the conjunction False; all True keeps it True.
        scenarios = [
            ([("g", True), ("g", True), ("g", False)], False),
            ([("g", True), ("g", True), ("g", True)], True),
        ]
        for rows, expected in scenarios:
            frame = spark.createDataFrame(rows, ["k", "b"])
            got = frame.groupBy("k").agg(F.bool_and("b").alias("r")).collect()
            assert got == [Row(k="g", r=expected)]

    def test_every_alias(self, spark):
        rows = [("g", True), ("g", False)]
        frame = spark.createDataFrame(rows, ["k", "b"])
        got = frame.groupBy("k").agg(F.every("b").alias("r")).collect()
        assert got == [Row(k="g", r=False)]

    def test_bool_or(self, spark):
        # One True makes the disjunction True; all False keeps it False.
        scenarios = [
            ([("g", True), ("g", False), ("g", False)], True),
            ([("g", False), ("g", False)], False),
        ]
        for rows, expected in scenarios:
            frame = spark.createDataFrame(rows, ["k", "b"])
            got = frame.groupBy("k").agg(F.bool_or("b").alias("r")).collect()
            assert got == [Row(k="g", r=expected)]

    def test_some_alias(self, spark):
        rows = [("g", True), ("g", False)]
        frame = spark.createDataFrame(rows, ["k", "b"])
        got = frame.groupBy("k").agg(F.some("b").alias("r")).collect()
        assert got == [Row(k="g", r=True)]

    def test_any_alias(self, spark):
        rows = [("g", True), ("g", False)]
        frame = spark.createDataFrame(rows, ["k", "b"])
        got = frame.groupBy("k").agg(F.any("b").alias("r")).collect()
        assert got == [Row(k="g", r=True)]

    def test_kurtosis(self, spark):
        rows = [("g", 1.0), ("g", 2.0), ("g", 3.0), ("g", 4.0)]
        frame = spark.createDataFrame(rows, ["k", "v"])
        collected = frame.groupBy("k").agg(F.kurtosis("v").alias("kur")).collect()
        first = collected[0]
        assert first.k == "g"
        # Only checks that a value is produced, not its magnitude.
        assert first.kur is not None

0 commit comments

Comments
 (0)