4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , _to_expr_list , sort_or_default
43+ from datafusion .expr import (
44+ Expr ,
45+ SortExpr ,
46+ _ensure_expr ,
47+ _to_expr_list ,
48+ sort_or_default ,
49+ )
4450from datafusion .plan import ExecutionPlan , LogicalPlan
4551from datafusion .record_batch import RecordBatchStream
4652
@@ -424,13 +430,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
424430 """
425431 df = self .df
426432 for p in predicates :
427- if not isinstance (p , Expr ):
428- msg = (
429- f"Expected Expr, got { type (p ).__name__ } . "
430- "Use col() or lit() to construct expressions."
431- )
432- raise TypeError (msg )
433- df = df .filter (p .expr )
433+ df = df .filter (_ensure_expr (p ))
434434 return DataFrame (df )
435435
436436 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -443,13 +443,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
443443 Returns:
444444 DataFrame with the new column.
445445 """
446- if not isinstance (expr , Expr ):
447- msg = (
448- f"Expected Expr, got { type (expr ).__name__ } . "
449- "Use col() or lit() to construct expressions."
450- )
451- raise TypeError (msg )
452- return DataFrame (self .df .with_column (name , expr .expr ))
446+ return DataFrame (self .df .with_column (name , _ensure_expr (expr )))
453447
454448 def with_columns (
455449 self , * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
@@ -480,31 +474,13 @@ def _simplify_expression(
480474 ) -> list [expr_internal .Expr ]:
481475 expr_list = []
482476 for expr in exprs :
483- if isinstance (expr , Expr ):
484- expr_list .append (expr .expr )
485- elif isinstance (expr , Iterable ):
486- for inner_expr in expr :
487- if not isinstance (inner_expr , Expr ):
488- msg = (
489- f"Expected Expr, got { type (inner_expr ).__name__ } . "
490- "Use col() or lit() to construct expressions."
491- )
492- raise TypeError (msg )
493- expr_list .append (inner_expr .expr )
477+ if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
478+ expr_list .extend (_ensure_expr (inner_expr ) for inner_expr in expr )
494479 else :
495- msg = (
496- f"Expected Expr, got { type (expr ).__name__ } . "
497- "Use col() or lit() to construct expressions."
498- )
499- raise TypeError (msg )
480+ expr_list .append (_ensure_expr (expr ))
500481 if named_exprs :
501482 for alias , expr in named_exprs .items ():
502- if not isinstance (expr , Expr ):
503- msg = (
504- f"Expected Expr, got { type (expr ).__name__ } . "
505- "Use col() or lit() to construct expressions."
506- )
507- raise TypeError (msg )
483+ _ensure_expr (expr )
508484 expr_list .append (expr .alias (alias ).expr )
509485 return expr_list
510486
@@ -549,15 +525,7 @@ def aggregate(
549525 group_by_exprs = [
550526 Expr .column (e ).expr if isinstance (e , str ) else e .expr for e in group_by_list
551527 ]
552- aggs_exprs = []
553- for agg in aggs_list :
554- if not isinstance (agg , Expr ):
555- msg = (
556- f"Expected Expr, got { type (agg ).__name__ } . "
557- "Use col() or lit() to construct expressions."
558- )
559- raise TypeError (msg )
560- aggs_exprs .append (agg .expr )
528+ aggs_exprs = [_ensure_expr (agg ) for agg in aggs_list ]
561529 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
562530
563531 def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
@@ -803,15 +771,7 @@ def join_on(
803771 Returns:
804772 DataFrame after join.
805773 """
806- exprs = []
807- for expr in on_exprs :
808- if not isinstance (expr , Expr ):
809- msg = (
810- f"Expected Expr, got { type (expr ).__name__ } . "
811- "Use col() or lit() to construct expressions."
812- )
813- raise TypeError (msg )
814- exprs .append (expr .expr )
774+ exprs = [_ensure_expr (expr ) for expr in on_exprs ]
815775 return DataFrame (self .df .join_on (right .df , exprs , how ))
816776
817777 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments