4040from datafusion ._internal import DataFrame as DataFrameInternal
4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
43- from datafusion .expr import Expr , SortExpr , sort_or_default
43+ from datafusion .expr import Expr , SortExpr , _to_expr_list , sort_or_default
4444from datafusion .plan import ExecutionPlan , LogicalPlan
4545from datafusion .record_batch import RecordBatchStream
4646
@@ -394,9 +394,7 @@ def select(self, *exprs: Expr | str) -> DataFrame:
394394 df = df.select("a", col("b"), col("a").alias("alternate_a"))
395395
396396 """
397- exprs_internal = [
398- Expr .column (arg ).expr if isinstance (arg , str ) else arg .expr for arg in exprs
399- ]
397+ exprs_internal = _to_expr_list (exprs )
400398 return DataFrame (self .df .select (* exprs_internal ))
401399
402400 def drop (self , * columns : str ) -> DataFrame :
@@ -548,19 +546,8 @@ def aggregate(
548546 group_by_list = group_by if isinstance (group_by , list ) else [group_by ]
549547 aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
550548
551- group_by_exprs = [
552- Expr .column (e ).expr if isinstance (e , str ) else e .expr for e in group_by_list
553- ]
554-
555- aggs_exprs : list [expr_internal .Expr ] = []
556- for agg in aggs_list :
557- if not isinstance (agg , Expr ):
558- msg = (
559- f"Expected Expr, got { type (agg ).__name__ } . "
560- "Use col() or lit() to construct expressions."
561- )
562- raise TypeError (msg )
563- aggs_exprs .append (agg .expr )
549+ group_by_exprs = _to_expr_list (group_by_list )
550+ aggs_exprs = _to_expr_list (aggs_list )
564551 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
565552
566553 def sort (self , * exprs : Expr | SortExpr | str ) -> DataFrame :
@@ -575,10 +562,14 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
575562 Returns:
576563 DataFrame after sorting.
577564 """
578- exprs_raw = [
579- sort_or_default (Expr .column (expr ) if isinstance (expr , str ) else expr )
580- for expr in exprs
581- ]
565+ expr_seq = [e for e in exprs if not isinstance (e , SortExpr )]
566+ raw_exprs_iter = iter (_to_expr_list (expr_seq ))
567+ exprs_raw = []
568+ for e in exprs :
569+ if isinstance (e , SortExpr ):
570+ exprs_raw .append (sort_or_default (e ))
571+ else :
572+ exprs_raw .append (sort_or_default (Expr (next (raw_exprs_iter ))))
582573 return DataFrame (self .df .sort (* exprs_raw ))
583574
584575 def cast (self , mapping : dict [str , pa .DataType [Any ]]) -> DataFrame :
0 commit comments