Skip to content

Commit 148f62e

Browse files
timsaucer and claude committed
Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row
Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes apache#1453.

- get_field: extracts a field from a struct or map by name
- union_extract: extracts a value from a union type by field name
- union_tag: returns the active field name of a union type
- arrow_metadata: returns Arrow field metadata (all or by key)
- version: returns the DataFusion version string
- row: alias for the struct constructor

Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 16feeb1 commit 148f62e

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

crates/core/src/functions.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,8 +644,29 @@ expr_fn_vec!(named_struct);
644644
expr_fn!(from_unixtime, unixtime);
645645
expr_fn!(arrow_typeof, arg_1);
646646
expr_fn!(arrow_cast, arg_1 datatype);
647+
// Binding for `arrow_metadata`, declared through the vector-argument macro
// form (presumably because it accepts a variable number of arguments —
// confirm against the `expr_fn_vec!` definition, which is not shown here).
expr_fn_vec!(arrow_metadata);
648+
// Binding for `union_tag`, declared with a single argument.
// NOTE(review): the argument is spelled `arg1` here, while the neighbouring
// `arrow_typeof` and `arrow_cast` bindings use `arg_1` — confirm whether the
// difference is intentional.
expr_fn!(union_tag, arg1);
647649
expr_fn!(random);
648650

651+
// Python-exposed wrapper around `functions::core::get_field()`.
//
// Builds a call to that function with `expr` and `name` as its two arguments,
// in that order, and converts the resulting expression back into a `PyExpr`.
// `name` is itself an expression (not a Rust string), so callers supply the
// field name as an expression value.
#[pyfunction]
652+
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
653+
functions::core::get_field()
654+
.call(vec![expr.into(), name.into()])
655+
.into()
656+
}
657+
658+
// Python-exposed wrapper around `functions::core::union_extract()`.
//
// Builds a call to that function with `union_expr` followed by `field_name`
// as its two arguments and converts the resulting expression into a `PyExpr`.
// Both parameters are expressions; `field_name` is not a plain Rust string.
#[pyfunction]
659+
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
660+
functions::core::union_extract()
661+
.call(vec![union_expr.into(), field_name.into()])
662+
.into()
663+
}
664+
665+
// Python-exposed wrapper around `functions::core::version()`.
//
// Takes no arguments: the function is called with an empty argument list and
// the resulting expression is converted into a `PyExpr`. Nothing is evaluated
// here; the caller only receives an expression.
#[pyfunction]
666+
fn version() -> PyExpr {
667+
functions::core::version().call(vec![]).into()
668+
}
669+
649670
// Array Functions
650671
array_fn!(array_append, array element);
651672
array_fn!(array_to_string, array delimiter);
@@ -953,6 +974,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
953974
m.add_wrapped(wrap_pyfunction!(array_agg))?;
954975
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
955976
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
977+
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
956978
m.add_wrapped(wrap_pyfunction!(ascii))?;
957979
m.add_wrapped(wrap_pyfunction!(asin))?;
958980
m.add_wrapped(wrap_pyfunction!(asinh))?;
@@ -1081,6 +1103,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
10811103
m.add_wrapped(wrap_pyfunction!(trim))?;
10821104
m.add_wrapped(wrap_pyfunction!(trunc))?;
10831105
m.add_wrapped(wrap_pyfunction!(upper))?;
1106+
m.add_wrapped(wrap_pyfunction!(get_field))?;
1107+
m.add_wrapped(wrap_pyfunction!(union_extract))?;
1108+
m.add_wrapped(wrap_pyfunction!(union_tag))?;
1109+
m.add_wrapped(wrap_pyfunction!(version))?;
10841110
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
10851111
m.add_wrapped(wrap_pyfunction!(var_pop))?;
10861112
m.add_wrapped(wrap_pyfunction!(var_sample))?;

python/datafusion/functions.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"array_to_string",
9191
"array_union",
9292
"arrow_cast",
93+
"arrow_metadata",
9394
"arrow_typeof",
9495
"ascii",
9596
"asin",
@@ -152,6 +153,7 @@
152153
"floor",
153154
"from_unixtime",
154155
"gcd",
156+
"get_field",
155157
"greatest",
156158
"ifnull",
157159
"in_list",
@@ -250,6 +252,7 @@
250252
"reverse",
251253
"right",
252254
"round",
255+
"row",
253256
"row_number",
254257
"rpad",
255258
"rtrim",
@@ -290,12 +293,15 @@
290293
"translate",
291294
"trim",
292295
"trunc",
296+
"union_extract",
297+
"union_tag",
293298
"upper",
294299
"uuid",
295300
"var",
296301
"var_pop",
297302
"var_samp",
298303
"var_sample",
304+
"version",
299305
"when",
300306
# Window Functions
301307
"window",
@@ -2612,6 +2618,86 @@ def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
26122618
return Expr(f.arrow_cast(expr.expr, data_type.expr))
26132619

26142620

2621+
# Python wrapper for the ``f.arrow_metadata`` binding. Every positional
# argument, including the optional metadata key, must be an Expr because each
# one is unwrapped through its ``.expr`` attribute below.
def arrow_metadata(*args: Expr) -> Expr:
2622+
"""Returns the metadata of the input expression.
2623+
2624+
If called with one argument, returns a Map of all metadata key-value pairs.
2625+
If called with two arguments, returns the value for the specified metadata key.
2626+
2627+
Args:
2628+
args: An expression, optionally followed by a metadata key string.
2629+
2630+
Returns:
2631+
A Map of metadata or a specific metadata value.
2632+
"""
2633+
# Unwrap each Expr to its underlying ``.expr`` object before forwarding.
args = [arg.expr for arg in args]
2634+
# Re-wrap the binding's result so callers receive an Expr.
return Expr(f.arrow_metadata(*args))
2635+
2636+
2637+
# Python wrapper for the ``f.get_field`` binding.
def get_field(expr: Expr, name: Expr) -> Expr:
2638+
"""Extracts a field from a struct or map by name.
2639+
2640+
Args:
2641+
expr: A struct or map expression.
2642+
name: The field name to extract.
2643+
2644+
Returns:
2645+
The value of the named field.
2646+
"""
2647+
# ``name`` must be an Expr: a bare ``str`` has no ``.expr`` attribute and
# would raise AttributeError on the line below.
return Expr(f.get_field(expr.expr, name.expr))
2648+
2649+
2650+
# Python wrapper for the ``f.union_extract`` binding.
def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
2651+
"""Extracts a value from a union type by field name.
2652+
2653+
Returns the value of the named field if it is the currently selected
2654+
variant, otherwise returns NULL.
2655+
2656+
Args:
2657+
union_expr: A union-typed expression.
2658+
field_name: The name of the field to extract.
2659+
2660+
Returns:
2661+
The extracted value or NULL.
2662+
"""
2663+
# ``field_name`` must be an Expr: a bare ``str`` has no ``.expr`` attribute
# and would raise AttributeError on the line below.
return Expr(f.union_extract(union_expr.expr, field_name.expr))
2664+
2665+
2666+
# Python wrapper for the ``f.union_tag`` binding.
def union_tag(union_expr: Expr) -> Expr:
2667+
"""Returns the tag (active field name) of a union type.
2668+
2669+
Args:
2670+
union_expr: A union-typed expression.
2671+
2672+
Returns:
2673+
The name of the currently selected field in the union.
2674+
"""
2675+
# Unwrap the argument to its underlying ``.expr`` object and re-wrap the
# binding's result in an Expr.
return Expr(f.union_tag(union_expr.expr))
2676+
2677+
2678+
# Python wrapper for the ``f.version`` binding; takes no arguments.
def version() -> Expr:
2679+
"""Returns the DataFusion version string.
2680+
2681+
Returns:
2682+
A string describing the DataFusion version.
2683+
"""
2684+
# The return value is an Expr wrapping the binding's result, not a Python
# ``str``.
return Expr(f.version())
2685+
2686+
2687+
# Convenience alias: forwards all positional arguments unchanged to ``struct``.
def row(*args: Expr) -> Expr:
2688+
"""Returns a struct with the given arguments.
2689+
2690+
This is an alias for :py:func:`struct`.
2691+
2692+
Args:
2693+
args: The expressions to include in the struct.
2694+
2695+
Returns:
2696+
A struct expression.
2697+
"""
2698+
# No unwrapping happens here; ``struct`` receives the Expr arguments as-is.
return struct(*args)
2699+
2700+
26152701
def random() -> Expr:
26162702
"""Returns a random value in the range ``0.0 <= x < 1.0``.
26172703

0 commit comments

Comments
 (0)