Skip to content

Commit 1ea4c54

Browse files
committed
Enhance DataFrame.aggregate method to accept column names as group by expressions
1 parent 99b45ff commit 1ea4c54

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

docs/source/user-guide/dataframe/index.rst

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,14 +142,19 @@ For such methods, you can pass column names directly:
142142

143143
.. code-block:: python
144144
145+
from datafusion import col, functions as f
146+
145147
df.sort('id')
148+
df.aggregate('id', [f.count(col('value'))])
146149
147150
The same operation can also be written with an explicit column expression:
148151

149152
.. code-block:: python
150153
151-
from datafusion import col
154+
from datafusion import col, functions as f
155+
152156
df.sort(col('id'))
157+
df.aggregate(col('id'), [f.count(col('value'))])
153158
154159
Whenever an argument represents an expression—such as in
155160
:py:meth:`~datafusion.DataFrame.filter` or

python/datafusion/dataframe.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -503,23 +503,27 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
503503
return DataFrame(self.df.with_column_renamed(old_name, new_name))
504504

505505
def aggregate(
506-
self, group_by: list[Expr] | Expr, aggs: list[Expr] | Expr
506+
self,
507+
group_by: list[Expr | str] | Expr | str,
508+
aggs: list[Expr] | Expr,
507509
) -> DataFrame:
508510
"""Aggregates the rows of the current DataFrame.
509511
510512
Args:
511-
group_by: List of expressions to group by.
513+
group_by: List of expressions or column names to group by.
512514
aggs: List of expressions to aggregate.
513515
514516
Returns:
515517
DataFrame after aggregation.
516518
"""
517-
group_by = group_by if isinstance(group_by, list) else [group_by]
518-
aggs = aggs if isinstance(aggs, list) else [aggs]
519+
group_by_list = group_by if isinstance(group_by, list) else [group_by]
520+
aggs_list = aggs if isinstance(aggs, list) else [aggs]
519521

520-
group_by = [e.expr for e in group_by]
521-
aggs = [e.expr for e in aggs]
522-
return DataFrame(self.df.aggregate(group_by, aggs))
522+
group_by_exprs = [
523+
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
524+
]
525+
aggs_exprs = [e.expr for e in aggs_list]
526+
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
523527

524528
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
525529
"""Sort the DataFrame by the specified sorting expressions or column names.

0 commit comments

Comments
 (0)