Skip to content

Commit d163934

Browse files
committed
Add type checks for unsupported expressions in select and sort methods
1 parent 3c16577 commit d163934

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

python/datafusion/dataframe.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,19 @@ def select(self, *exprs: Expr | str) -> DataFrame:
400400
df = df.select("a", col("b"), col("a").alias("alternate_a"))
401401
402402
"""
403-
exprs_internal = _to_expr_list(exprs)
403+
checked_exprs: list[Expr | str] = []
404+
for expr in exprs:
405+
if isinstance(expr, SortExpr):
406+
checked_exprs.append(expr.expr())
407+
elif isinstance(expr, (Expr, str)):
408+
checked_exprs.append(expr)
409+
else:
410+
msg = (
411+
f"Expected Expr or column name, got {type(expr).__name__}. "
412+
"Use col() or lit() to construct expressions."
413+
)
414+
raise TypeError(msg)
415+
exprs_internal = _to_expr_list(checked_exprs)
404416
return DataFrame(self.df.select(*exprs_internal))
405417

406418
def drop(self, *columns: str) -> DataFrame:
@@ -540,9 +552,20 @@ def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
540552
Returns:
541553
DataFrame after sorting.
542554
"""
543-
expr_seq = [e for e in exprs if not isinstance(e, SortExpr)]
555+
expr_seq: list[Expr | str] = []
556+
for e in exprs:
557+
if isinstance(e, SortExpr):
558+
continue
559+
if isinstance(e, (Expr, str)):
560+
expr_seq.append(e)
561+
else:
562+
msg = (
563+
f"Expected Expr or column name, got {type(e).__name__}. "
564+
"Use col() or lit() to construct expressions."
565+
)
566+
raise TypeError(msg)
544567
raw_exprs_iter = iter(_to_expr_list(expr_seq))
545-
exprs_raw = []
568+
exprs_raw: list[Any] = []
546569
for e in exprs:
547570
if isinstance(e, SortExpr):
548571
exprs_raw.append(sort_or_default(e))

python/tests/test_dataframe.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@ def test_select_mixed_expr_string(df):
227227
assert result.column(1) == pa.array([1, 2, 3])
228228

229229

230+
def test_select_unsupported(df):
231+
with pytest.raises(
232+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
233+
):
234+
df.select(1)
235+
236+
230237
def test_filter(df):
231238
df1 = df.filter(column("a") > literal(2)).select(
232239
column("a") + column("b"),
@@ -276,6 +283,13 @@ def test_sort_string_and_expression_equivalent(df):
276283
assert result_str == result_expr
277284

278285

286+
def test_sort_unsupported(df):
287+
with pytest.raises(
288+
TypeError, match=r"Expected Expr or column name.*col\(\) or lit\(\)"
289+
):
290+
df.sort(1)
291+
292+
279293
def test_aggregate_string_and_expression_equivalent(df):
280294
from datafusion import col
281295

0 commit comments

Comments
 (0)