Skip to content

Commit 308acb8

Browse files
committed
Add type checks for expressions in DataFrame methods to enforce correct usage
1 parent 0daf438 commit 308acb8

2 files changed

Lines changed: 75 additions & 4 deletions

File tree

python/datafusion/dataframe.py

Lines changed: 44 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -445,6 +445,12 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
445445
Returns:
446446
DataFrame with the new column.
447447
"""
448+
if not isinstance(expr, Expr):
449+
msg = (
450+
f"Expected Expr, got {type(expr).__name__}. "
451+
"Use col() or lit() to construct expressions."
452+
)
453+
raise TypeError(msg)
448454
return DataFrame(self.df.with_column(name, expr.expr))
449455

450456
def with_columns(
@@ -479,11 +485,28 @@ def _simplify_expression(
479485
if isinstance(expr, Expr):
480486
expr_list.append(expr.expr)
481487
elif isinstance(expr, Iterable):
482-
expr_list.extend(inner_expr.expr for inner_expr in expr)
488+
for inner_expr in expr:
489+
if not isinstance(inner_expr, Expr):
490+
msg = (
491+
f"Expected Expr, got {type(inner_expr).__name__}. "
492+
"Use col() or lit() to construct expressions."
493+
)
494+
raise TypeError(msg)
495+
expr_list.append(inner_expr.expr)
483496
else:
484-
raise NotImplementedError
497+
msg = (
498+
f"Expected Expr, got {type(expr).__name__}. "
499+
"Use col() or lit() to construct expressions."
500+
)
501+
raise TypeError(msg)
485502
if named_exprs:
486503
for alias, expr in named_exprs.items():
504+
if not isinstance(expr, Expr):
505+
msg = (
506+
f"Expected Expr, got {type(expr).__name__}. "
507+
"Use col() or lit() to construct expressions."
508+
)
509+
raise TypeError(msg)
487510
expr_list.append(expr.alias(alias).expr)
488511
return expr_list
489512

@@ -528,7 +551,16 @@ def aggregate(
528551
group_by_exprs = [
529552
Expr.column(e).expr if isinstance(e, str) else e.expr for e in group_by_list
530553
]
531-
aggs_exprs = [e.expr for e in aggs_list]
554+
555+
aggs_exprs: list[expr_internal.Expr] = []
556+
for agg in aggs_list:
557+
if not isinstance(agg, Expr):
558+
msg = (
559+
f"Expected Expr, got {type(agg).__name__}. "
560+
"Use col() or lit() to construct expressions."
561+
)
562+
raise TypeError(msg)
563+
aggs_exprs.append(agg.expr)
532564
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
533565

534566
def sort(self, *exprs: Expr | SortExpr | str) -> DataFrame:
@@ -770,7 +802,15 @@ def join_on(
770802
Returns:
771803
DataFrame after join.
772804
"""
773-
exprs = [expr.expr for expr in on_exprs]
805+
exprs = []
806+
for expr in on_exprs:
807+
if not isinstance(expr, Expr):
808+
msg = (
809+
f"Expected Expr, got {type(expr).__name__}. "
810+
"Use col() or lit() to construct expressions."
811+
)
812+
raise TypeError(msg)
813+
exprs.append(expr.expr)
774814
return DataFrame(self.df.join_on(right.df, exprs, how))
775815

776816
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/tests/test_dataframe.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -350,6 +350,11 @@ def test_with_column(df):
350350
assert result.column(2) == pa.array([5, 7, 9])
351351

352352

353+
def test_with_column_invalid_expr(df):
354+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
355+
df.with_column("c", "a")
356+
357+
353358
def test_with_columns(df):
354359
df = df.with_columns(
355360
(column("a") + column("b")).alias("c"),
@@ -381,6 +386,13 @@ def test_with_columns(df):
381386
assert result.column(6) == pa.array([5, 7, 9])
382387

383388

389+
def test_with_columns_invalid_expr(df):
390+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
391+
df.with_columns("a")
392+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
393+
df.with_columns(c="a")
394+
395+
384396
def test_cast(df):
385397
df = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})
386398
expected = pa.schema(
@@ -539,6 +551,25 @@ def test_join_on():
539551
assert table.to_pydict() == expected
540552

541553

554+
def test_join_on_invalid_expr():
555+
ctx = SessionContext()
556+
557+
batch = pa.RecordBatch.from_arrays(
558+
[pa.array([1, 2]), pa.array([4, 5])],
559+
names=["a", "b"],
560+
)
561+
df = ctx.create_dataframe([[batch]], "l")
562+
df1 = ctx.create_dataframe([[batch]], "r")
563+
564+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
565+
df.join_on(df1, "a")
566+
567+
568+
def test_aggregate_invalid_aggs(df):
569+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
570+
df.aggregate([], "a")
571+
572+
542573
def test_distinct():
543574
ctx = SessionContext()
544575

0 commit comments

Comments
 (0)