Skip to content

Commit b627d30

Browse files
timsaucer and claude committed
Support pyarrow DataType in arrow_cast
Allow arrow_cast to accept a pyarrow DataType in addition to str and Expr. The DataType is converted to its string representation before being passed to DataFusion. Adds test coverage for the new input type.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a662e18 commit b627d30

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

python/datafusion/functions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2634,10 +2634,10 @@ def arrow_typeof(arg: Expr) -> Expr:
26342634
return Expr(f.arrow_typeof(arg.expr))
26352635

26362636

2637-
def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
2637+
def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
26382638
"""Casts an expression to a specified data type.
26392639
2640-
The ``data_type`` can be a string or an ``Expr``.
2640+
The ``data_type`` can be a string, a ``pyarrow.DataType``, or an ``Expr``.
26412641
26422642
Examples:
26432643
>>> ctx = dfn.SessionContext()
@@ -2647,7 +2647,18 @@ def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
26472647
... )
26482648
>>> result.collect_column("c")[0].as_py()
26492649
1.0
2650+
2651+
>>> import pyarrow as pa
2652+
>>> result = df.select(
2653+
... dfn.functions.arrow_cast(
2654+
... dfn.col("a"), data_type=pa.float64()
2655+
... ).alias("c")
2656+
... )
2657+
>>> result.collect_column("c")[0].as_py()
2658+
1.0
26502659
"""
2660+
if isinstance(data_type, pa.DataType):
2661+
data_type = str(data_type)
26512662
if isinstance(data_type, str):
26522663
data_type = Expr.string_literal(data_type)
26532664
return Expr(f.arrow_cast(expr.expr, data_type.expr))

python/tests/test_functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,19 @@ def test_arrow_cast(df):
13021302
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
13031303

13041304

1305+
def test_arrow_cast_with_pyarrow_type(df):
1306+
df = df.select(
1307+
f.arrow_cast(column("b"), pa.float64()).alias("b_as_float"),
1308+
f.arrow_cast(column("b"), pa.int32()).alias("b_as_int"),
1309+
f.arrow_cast(column("b"), pa.string()).alias("b_as_str"),
1310+
)
1311+
result = df.collect()[0]
1312+
1313+
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
1314+
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
1315+
assert result.column(2) == pa.array(["4", "5", "6"], type=pa.string())
1316+
1317+
13051318
def test_case(df):
13061319
df = df.select(
13071320
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),

0 commit comments

Comments
 (0)