Skip to content

Commit 86142ae

Browse files
committed
Enhance type annotations for file_sort_order and order_by parameters to support string inputs
1 parent d7a466d commit 86142ae

4 files changed

Lines changed: 37 additions & 35 deletions

File tree

python/datafusion/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def register_listing_table(
553553
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
554554
file_extension: str = ".parquet",
555555
schema: pa.Schema | None = None,
556-
file_sort_order: list[list[Expr | SortExpr]] | None = None,
556+
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
557557
) -> None:
558558
"""Register multiple files as a single table.
559559
@@ -808,7 +808,7 @@ def register_parquet(
808808
file_extension: str = ".parquet",
809809
skip_metadata: bool = True,
810810
schema: pa.Schema | None = None,
811-
file_sort_order: list[list[SortExpr]] | None = None,
811+
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
812812
) -> None:
813813
"""Register a Parquet file as a table.
814814
@@ -1099,7 +1099,7 @@ def read_parquet(
10991099
file_extension: str = ".parquet",
11001100
skip_metadata: bool = True,
11011101
schema: pa.Schema | None = None,
1102-
file_sort_order: list[list[Expr | SortExpr]] | None = None,
1102+
file_sort_order: list[list[Expr | SortExpr | str]] | None = None,
11031103
) -> DataFrame:
11041104
"""Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
11051105

python/datafusion/dataframe.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
Expr,
4646
SortExpr,
4747
expr_list_to_raw_expr_list,
48-
sort_or_default,
48+
sort_list_to_raw_sort_list,
4949
)
5050
from datafusion.plan import ExecutionPlan, LogicalPlan
5151
from datafusion.record_batch import RecordBatchStream
@@ -551,20 +551,7 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
551551
Returns:
552552
DataFrame after sorting.
553553
"""
554-
exprs_raw = []
555-
for e in exprs:
556-
if isinstance(e, SortExpr):
557-
exprs_raw.append(sort_or_default(e))
558-
elif isinstance(e, str):
559-
exprs_raw.append(sort_or_default(Expr.column(e)))
560-
elif isinstance(e, Expr):
561-
exprs_raw.append(sort_or_default(e))
562-
else:
563-
error = (
564-
"Expected Expr or column name, found:"
565-
f" {type(e).__name__}. {_EXPR_TYPE_ERROR}."
566-
)
567-
raise TypeError(error)
554+
exprs_raw = sort_list_to_raw_sort_list(list(exprs))
568555
return DataFrame(self.df.sort(*exprs_raw))
569556

570557
def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:

python/datafusion/expr.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,27 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
250250

251251

252252
def sort_list_to_raw_sort_list(
253-
sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
253+
sort_list: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str],
254254
) -> Optional[list[expr_internal.SortExpr]]:
255255
"""Helper function to return an optional sort list to raw variant."""
256-
if isinstance(sort_list, (Expr, SortExpr)):
256+
if isinstance(sort_list, (Expr, SortExpr, str)):
257257
sort_list = [sort_list]
258-
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
258+
if sort_list is None:
259+
return None
260+
raw_sort_list = []
261+
for item in sort_list:
262+
if isinstance(item, str):
263+
expr_obj = Expr.column(item)
264+
elif isinstance(item, (Expr, SortExpr)):
265+
expr_obj = item
266+
else:
267+
error = (
268+
"Expected Expr or column name, found:"
269+
f" {type(item).__name__}. {_EXPR_TYPE_ERROR}."
270+
)
271+
raise TypeError(error)
272+
raw_sort_list.append(sort_or_default(expr_obj))
273+
return raw_sort_list
259274

260275

261276
class Expr:

python/datafusion/functions.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def window(
429429
name: str,
430430
args: list[Expr],
431431
partition_by: list[Expr] | Expr | None = None,
432-
order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None,
432+
order_by: list[Expr | SortExpr | str] | Expr | SortExpr | str | None = None,
433433
window_frame: WindowFrame | None = None,
434434
ctx: SessionContext | None = None,
435435
) -> Expr:
@@ -1723,7 +1723,7 @@ def array_agg(
17231723
expression: Expr,
17241724
distinct: bool = False,
17251725
filter: Optional[Expr] = None,
1726-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
1726+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
17271727
) -> Expr:
17281728
"""Aggregate values into an array.
17291729
@@ -2222,7 +2222,7 @@ def regr_syy(
22222222
def first_value(
22232223
expression: Expr,
22242224
filter: Optional[Expr] = None,
2225-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2225+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
22262226
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22272227
) -> Expr:
22282228
"""Returns the first value in a group of values.
@@ -2254,7 +2254,7 @@ def first_value(
22542254
def last_value(
22552255
expression: Expr,
22562256
filter: Optional[Expr] = None,
2257-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2257+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
22582258
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22592259
) -> Expr:
22602260
"""Returns the last value in a group of values.
@@ -2287,7 +2287,7 @@ def nth_value(
22872287
expression: Expr,
22882288
n: int,
22892289
filter: Optional[Expr] = None,
2290-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2290+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
22912291
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
22922292
) -> Expr:
22932293
"""Returns the n-th value in a group of values.
@@ -2408,7 +2408,7 @@ def lead(
24082408
shift_offset: int = 1,
24092409
default_value: Optional[Any] = None,
24102410
partition_by: Optional[list[Expr] | Expr] = None,
2411-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2411+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
24122412
) -> Expr:
24132413
"""Create a lead window function.
24142414
@@ -2461,7 +2461,7 @@ def lag(
24612461
shift_offset: int = 1,
24622462
default_value: Optional[Any] = None,
24632463
partition_by: Optional[list[Expr] | Expr] = None,
2464-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2464+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
24652465
) -> Expr:
24662466
"""Create a lag window function.
24672467
@@ -2508,7 +2508,7 @@ def lag(
25082508

25092509
def row_number(
25102510
partition_by: Optional[list[Expr] | Expr] = None,
2511-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2511+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
25122512
) -> Expr:
25132513
"""Create a row number window function.
25142514
@@ -2542,7 +2542,7 @@ def row_number(
25422542

25432543
def rank(
25442544
partition_by: Optional[list[Expr] | Expr] = None,
2545-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2545+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
25462546
) -> Expr:
25472547
"""Create a rank window function.
25482548
@@ -2581,7 +2581,7 @@ def rank(
25812581

25822582
def dense_rank(
25832583
partition_by: Optional[list[Expr] | Expr] = None,
2584-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2584+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
25852585
) -> Expr:
25862586
"""Create a dense_rank window function.
25872587
@@ -2615,7 +2615,7 @@ def dense_rank(
26152615

26162616
def percent_rank(
26172617
partition_by: Optional[list[Expr] | Expr] = None,
2618-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2618+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
26192619
) -> Expr:
26202620
"""Create a percent_rank window function.
26212621
@@ -2650,7 +2650,7 @@ def percent_rank(
26502650

26512651
def cume_dist(
26522652
partition_by: Optional[list[Expr] | Expr] = None,
2653-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2653+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
26542654
) -> Expr:
26552655
"""Create a cumulative distribution window function.
26562656
@@ -2686,7 +2686,7 @@ def cume_dist(
26862686
def ntile(
26872687
groups: int,
26882688
partition_by: Optional[list[Expr] | Expr] = None,
2689-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2689+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
26902690
) -> Expr:
26912691
"""Create a n-tile window function.
26922692
@@ -2727,7 +2727,7 @@ def string_agg(
27272727
expression: Expr,
27282728
delimiter: str,
27292729
filter: Optional[Expr] = None,
2730-
order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
2730+
order_by: Optional[list[Expr | SortExpr | str] | Expr | SortExpr | str] = None,
27312731
) -> Expr:
27322732
"""Concatenates the input strings.
27332733

0 commit comments

Comments
 (0)