Skip to content

Commit 427565b

Browse files
committed
Refactor DataFrame methods to use _to_expr_list for expression handling
1 parent 308acb8 commit 427565b

File tree

2 files changed

+18
-22
lines changed

2 files changed

+18
-22
lines changed

python/datafusion/dataframe.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, sort_or_default
43+
from datafusion.expr import Expr, SortExpr, _to_expr_list, sort_or_default
4444
from datafusion.plan import ExecutionPlan, LogicalPlan
4545
from datafusion.record_batch import RecordBatchStream
4646

@@ -394,9 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394
df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396
"""
397-
exprs_internal = [
398-
Expr.column(arg).expr if isinstance(arg, str) else arg.expr for arg in exprs
399-
]
397+
exprs_internal = _to_expr_list(exprs)
400398
return DataFrame(self.df.select(*exprs_internal))
401399

402400
def drop(self, *columns: str) -> DataFrame:
@@ -548,19 +546,8 @@ def aggregate(
548546
group_by_list = group_by if isinstance(group_by, list) else [group_by]
549547
aggs_list = aggs if isinstance(aggs, list) else [aggs]
550548

551-
group_by_exprs = [
552-
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
553-
]
554-
555-
aggs_exprs: list[expr_internal.Expr] = []
556-
for agg in aggs_list:
557-
if not isinstance(agg, Expr):
558-
msg = (
559-
f"Expected Expr, got {type(agg).__name__}. "
560-
"Use col() or lit() to construct expressions."
561-
)
562-
raise TypeError(msg)
563-
aggs_exprs.append(agg.expr)
549+
group_by_exprs = _to_expr_list(group_by_list)
550+
aggs_exprs = _to_expr_list(aggs_list)
564551
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
565552

566553
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
@@ -575,10 +562,14 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
575562
Returns:
576563
DataFrame after sorting.
577564
"""
578-
exprs_raw = [
579-
sort_or_default(Expr.column(expr) if isinstance(expr, str) else expr)
580-
for expr in exprs
581-
]
565+
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
566+
raw_exprs_iter = iter(_to_expr_list(expr_seq))
567+
exprs_raw = []
568+
for e in exprs:
569+
if isinstance(e, SortExpr):
570+
exprs_raw.append(sort_or_default(e))
571+
else:
572+
exprs_raw.append(sort_or_default(Expr(next(raw_exprs_iter))))
582573
return DataFrame(self.df.sort(*exprs_raw))
583574

584575
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:

python/datafusion/expr.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from __future__ import annotations
2424

25-
from typing import TYPE_CHECKING, Any, ClassVar, Optional
25+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence
2626

2727
import pyarrow as pa
2828

@@ -215,6 +215,11 @@
215215
]
216216

217217

218+
def _to_expr_list(exprs: Sequence[Expr | str]) -> list[expr_internal.Expr]:
219+
"""Convert a sequence of expressions or column names to raw expressions."""
220+
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in exprs]
221+
222+
218223
def expr_list_to_raw_expr_list(
219224
expr_list: Optional[list[Expr] | Expr],
220225
) -> Optional[list[expr_internal.Expr]]:

0 commit comments

Comments
 (0)