Skip to content

Commit 3c16577

Browse files
committed
Add _ensure_expr function to validate and convert expressions in DataFrame methods
1 parent 0b6303e commit 3c16577

File tree

2 files changed

+26
-55
lines changed

2 files changed

+26
-55
lines changed

python/datafusion/dataframe.py

Lines changed: 15 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@
4040
from datafusion._internal import DataFrame as DataFrameInternal
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43-
from datafusion.expr import Expr, SortExpr, _to_expr_list, sort_or_default
43+
from datafusion.expr import (
44+
Expr,
45+
SortExpr,
46+
_ensure_expr,
47+
_to_expr_list,
48+
sort_or_default,
49+
)
4450
from datafusion.plan import ExecutionPlan, LogicalPlan
4551
from datafusion.record_batch import RecordBatchStream
4652

@@ -424,13 +430,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
424430
"""
425431
df = self.df
426432
for p in predicates:
427-
if not isinstance(p, Expr):
428-
msg = (
429-
f"Expected Expr, got {type(p).__name__}. "
430-
"Use col() or lit() to construct expressions."
431-
)
432-
raise TypeError(msg)
433-
df = df.filter(p.expr)
433+
df = df.filter(_ensure_expr(p))
434434
return DataFrame(df)
435435

436436
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -443,13 +443,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
443443
Returns:
444444
DataFrame with the new column.
445445
"""
446-
if not isinstance(expr, Expr):
447-
msg = (
448-
f"Expected Expr, got {type(expr).__name__}. "
449-
"Use col() or lit() to construct expressions."
450-
)
451-
raise TypeError(msg)
452-
return DataFrame(self.df.with_column(name, expr.expr))
446+
return DataFrame(self.df.with_column(name, _ensure_expr(expr)))
453447

454448
def with_columns(
455449
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
@@ -480,31 +474,13 @@ def _simplify_expression(
480474
) -> list[expr_internal.Expr]:
481475
expr_list = []
482476
for expr in exprs:
483-
if isinstance(expr, Expr):
484-
expr_list.append(expr.expr)
485-
elif isinstance(expr, Iterable):
486-
for inner_expr in expr:
487-
if not isinstance(inner_expr, Expr):
488-
msg = (
489-
f"Expected Expr, got {type(inner_expr).__name__}. "
490-
"Use col() or lit() to construct expressions."
491-
)
492-
raise TypeError(msg)
493-
expr_list.append(inner_expr.expr)
477+
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
478+
expr_list.extend(_ensure_expr(inner_expr) for inner_expr in expr)
494479
else:
495-
msg = (
496-
f"Expected Expr, got {type(expr).__name__}. "
497-
"Use col() or lit() to construct expressions."
498-
)
499-
raise TypeError(msg)
480+
expr_list.append(_ensure_expr(expr))
500481
if named_exprs:
501482
for alias, expr in named_exprs.items():
502-
if not isinstance(expr, Expr):
503-
msg = (
504-
f"Expected Expr, got {type(expr).__name__}. "
505-
"Use col() or lit() to construct expressions."
506-
)
507-
raise TypeError(msg)
483+
_ensure_expr(expr)
508484
expr_list.append(expr.alias(alias).expr)
509485
return expr_list
510486

@@ -549,15 +525,7 @@ def aggregate(
549525
group_by_exprs = [
550526
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
551527
]
552-
aggs_exprs = []
553-
for agg in aggs_list:
554-
if not isinstance(agg, Expr):
555-
msg = (
556-
f"Expected Expr, got {type(agg).__name__}. "
557-
"Use col() or lit() to construct expressions."
558-
)
559-
raise TypeError(msg)
560-
aggs_exprs.append(agg.expr)
528+
aggs_exprs = [_ensure_expr(agg) for agg in aggs_list]
561529
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
562530

563531
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
@@ -803,15 +771,7 @@ def join_on(
803771
Returns:
804772
DataFrame after join.
805773
"""
806-
exprs = []
807-
for expr in on_exprs:
808-
if not isinstance(expr, Expr):
809-
msg = (
810-
f"Expected Expr, got {type(expr).__name__}. "
811-
"Use col() or lit() to construct expressions."
812-
)
813-
raise TypeError(msg)
814-
exprs.append(expr.expr)
774+
exprs = [_ensure_expr(expr) for expr in on_exprs]
815775
return DataFrame(self.df.join_on(right.df, exprs, how))
816776

817777
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,17 @@ def _to_expr_list(exprs: Sequence[Expr | str]) -> list[expr_internal.Expr]:
220220
return [Expr.column(e).expr if isinstance(e, str) else e.expr for e in exprs]
221221

222222

223+
def _ensure_expr(value: Any) -> expr_internal.Expr:
224+
"""Return the internal expression or raise if the value is not an Expr."""
225+
if isinstance(value, Expr):
226+
return value.expr
227+
msg = (
228+
f"Expected Expr, got {type(value).__name__}. "
229+
"Use col() or lit() to construct expressions."
230+
)
231+
raise TypeError(msg)
232+
233+
223234
def expr_list_to_raw_expr_list(
224235
expr_list: Optional[list[Expr] | Expr],
225236
) -> Optional[list[expr_internal.Expr]]:

0 commit comments

Comments
 (0)