Skip to content

Commit 02eb255

Browse files
timsaucer and claude committed
Accept str for field name and type parameters in scalar functions
Allow arrow_cast, get_field, and union_extract to accept plain str arguments instead of requiring Expr wrappers. Also improve arrow_metadata test coverage and fix parameter shadowing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ea2370a commit 02eb255

File tree

2 files changed

+34
-20
lines changed

2 files changed

+34
-20
lines changed

python/datafusion/functions.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2602,19 +2602,20 @@ def arrow_typeof(arg: Expr) -> Expr:
26022602
return Expr(f.arrow_typeof(arg.expr))
26032603

26042604

2605-
def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
2605+
def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
26062606
"""Casts an expression to a specified data type.
26072607
26082608
Examples:
26092609
>>> ctx = dfn.SessionContext()
26102610
>>> df = ctx.from_pydict({"a": [1]})
2611-
>>> data_type = dfn.string_literal("Float64")
26122611
>>> result = df.select(
2613-
... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c")
2612+
... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c")
26142613
... )
26152614
>>> result.collect_column("c")[0].as_py()
26162615
1.0
26172616
"""
2617+
if isinstance(data_type, str):
2618+
data_type = Expr.string_literal(data_type)
26182619
return Expr(f.arrow_cast(expr.expr, data_type.expr))
26192620

26202621

@@ -2630,11 +2631,10 @@ def arrow_metadata(*args: Expr) -> Expr:
26302631
Returns:
26312632
A Map of metadata or a specific metadata value.
26322633
"""
2633-
args = [arg.expr for arg in args]
2634-
return Expr(f.arrow_metadata(*args))
2634+
return Expr(f.arrow_metadata(*[arg.expr for arg in args]))
26352635

26362636

2637-
def get_field(expr: Expr, name: Expr) -> Expr:
2637+
def get_field(expr: Expr, name: Expr | str) -> Expr:
26382638
"""Extracts a field from a struct or map by name.
26392639
26402640
Args:
@@ -2644,10 +2644,12 @@ def get_field(expr: Expr, name: Expr) -> Expr:
26442644
Returns:
26452645
The value of the named field.
26462646
"""
2647+
if isinstance(name, str):
2648+
name = Expr.string_literal(name)
26472649
return Expr(f.get_field(expr.expr, name.expr))
26482650

26492651

2650-
def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
2652+
def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
26512653
"""Extracts a value from a union type by field name.
26522654
26532655
Returns the value of the named field if it is the currently selected
@@ -2660,6 +2662,8 @@ def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
26602662
Returns:
26612663
The extracted value or NULL.
26622664
"""
2665+
if isinstance(field_name, str):
2666+
field_name = Expr.string_literal(field_name)
26632667
return Expr(f.union_extract(union_expr.expr, field_name.expr))
26642668

26652669

python/tests/test_functions.py

Lines changed: 23 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1143,11 +1143,8 @@ def test_make_time(df):
11431143

11441144
def test_arrow_cast(df):
11451145
df = df.select(
1146-
# we use `string_literal` to return utf8 instead of `literal` which returns
1147-
# utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
1148-
# https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
1149-
f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
1150-
f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
1146+
f.arrow_cast(column("b"), "Float64").alias("b_as_float"),
1147+
f.arrow_cast(column("b"), "Int32").alias("b_as_int"),
11511148
)
11521149
result = df.collect()
11531150
assert len(result) == 1
@@ -1673,20 +1670,35 @@ def test_get_field(df):
16731670
),
16741671
)
16751672
result = df.select(
1676-
f.get_field(column("s"), string_literal("x")).alias("x_val"),
1677-
f.get_field(column("s"), string_literal("y")).alias("y_val"),
1673+
f.get_field(column("s"), "x").alias("x_val"),
1674+
f.get_field(column("s"), "y").alias("y_val"),
16781675
).collect()[0]
16791676

16801677
assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
16811678
assert result.column(1) == pa.array([4, 5, 6])
16821679

16831680

1684-
def test_arrow_metadata(df):
1681+
def test_arrow_metadata():
1682+
ctx = SessionContext()
1683+
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})
1684+
schema = pa.schema([field])
1685+
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema)
1686+
df = ctx.create_dataframe([[batch]])
1687+
1688+
# One-argument form: returns a Map of all metadata key-value pairs
16851689
result = df.select(
1686-
f.arrow_metadata(column("a")).alias("meta"),
1690+
f.arrow_metadata(column("val")).alias("meta"),
16871691
).collect()[0]
1688-
# The metadata column should be returned as a map type (possibly empty)
16891692
assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8())
1693+
meta = result.column(0)[0].as_py()
1694+
assert ("key1", "value1") in meta
1695+
assert ("key2", "value2") in meta
1696+
1697+
# Two-argument form: returns the value for a specific metadata key
1698+
result = df.select(
1699+
f.arrow_metadata(column("val"), string_literal("key1")).alias("meta_val"),
1700+
).collect()[0]
1701+
assert result.column(0)[0].as_py() == "value1"
16901702

16911703

16921704
def test_version():
@@ -1726,7 +1738,5 @@ def test_union_extract():
17261738
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
17271739
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])
17281740

1729-
result = df.select(
1730-
f.union_extract(column("u"), string_literal("int")).alias("val")
1731-
).collect()[0]
1741+
result = df.select(f.union_extract(column("u"), "int").alias("val")).collect()[0]
17321742
assert result.column(0).to_pylist() == [1, None, 2]

0 commit comments

Comments
 (0)