Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .ai/skills/check-upstream/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
- Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions`
- Rust bindings: `crates/core/src/functions.rs` — `#[pyfunction]` definitions registered via `init_module()`

**Evaluated and not requiring separate Python exposure:**
- `get_field_path` — already covered by `get_field(expr, *names)`, which takes a
variadic field path and dispatches to the same underlying
`functions::core::get_field` UDF as the upstream `get_field_path` helper.

**How to check:**
1. Fetch the upstream scalar function documentation page
2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions)
3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding.
4. Only report functions that are missing from the Python `__all__` list / function definitions
4. Check against the "evaluated and not requiring exposure" list before flagging as a gap
5. Only report functions that are missing from the Python `__all__` list / function definitions

### 2. Aggregate Functions

Expand Down
8 changes: 4 additions & 4 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,10 @@ expr_fn!(union_tag, arg1);
expr_fn!(random);

#[pyfunction]
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
functions::core::get_field()
.call(vec![expr.into(), name.into()])
.into()
fn get_field(expr: PyExpr, names: Vec<PyExpr>) -> PyExpr {
let mut args = vec![expr.into()];
args.extend(names.into_iter().map(Into::into));
functions::core::get_field().call(args).into()
}

#[pyfunction]
Expand Down
5 changes: 2 additions & 3 deletions examples/datafusion-ffi-example/src/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

use std::sync::Arc;

use datafusion_catalog::{TableFunctionImpl, TableProvider};
use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
use datafusion_common::error::Result as DataFusionResult;
use datafusion_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use datafusion_python_util::ffi_logical_codec_from_pycapsule;
use pyo3::types::PyCapsule;
Expand Down Expand Up @@ -59,7 +58,7 @@ impl MyTableFunction {
}

impl TableFunctionImpl for MyTableFunction {
fn call(&self, _args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
fn call_with_args(&self, _args: TableFunctionArgs) -> DataFusionResult<Arc<dyn TableProvider>> {
let provider = MyTableProvider::new(4, 3, 2).create_table()?;
Ok(Arc::new(provider))
}
Expand Down
19 changes: 15 additions & 4 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@

import pandas as pd
import polars as pl # type: ignore[import]
from _typeshed import CapsuleType as _PyCapsule

from datafusion.catalog import CatalogProvider, Table
from datafusion.common import DFSchema
from datafusion.expr import Expr, SortKey
from datafusion.plan import ExecutionPlan, LogicalPlan
from datafusion.user_defined import (
AggregateUDF,
LogicalExtensionCodecExportable,
PhysicalExtensionCodecExportable,
ScalarUDF,
TableFunction,
WindowUDF,
Expand Down Expand Up @@ -1744,11 +1747,15 @@ def __datafusion_logical_extension_codec__(self) -> Any:
"""Access the PyCapsule FFI_LogicalExtensionCodec."""
return self.ctx.__datafusion_logical_extension_codec__()

def with_logical_extension_codec(self, codec: Any) -> SessionContext:
def with_logical_extension_codec(
self, codec: LogicalExtensionCodecExportable | _PyCapsule
) -> SessionContext:
"""Create a new session context with specified codec.

This only supports codecs that have been implemented using the
FFI interface.
FFI interface. ``codec`` must either be a raw ``FFI_LogicalExtensionCodec``
``PyCapsule`` or an object exposing
``__datafusion_logical_extension_codec__``.
Comment thread
timsaucer marked this conversation as resolved.
Outdated
"""
new_internal = self.ctx.with_logical_extension_codec(codec)
new = SessionContext.__new__(SessionContext)
Expand All @@ -1759,11 +1766,15 @@ def __datafusion_physical_extension_codec__(self) -> Any:
"""Access the PyCapsule FFI_PhysicalExtensionCodec."""
return self.ctx.__datafusion_physical_extension_codec__()

def with_physical_extension_codec(self, codec: Any) -> SessionContext:
def with_physical_extension_codec(
self, codec: PhysicalExtensionCodecExportable | _PyCapsule
) -> SessionContext:
"""Create a new session context with the specified physical codec.

This only supports codecs that have been implemented using the
FFI interface.
FFI interface. ``codec`` must either be a raw
Comment thread
timsaucer marked this conversation as resolved.
Outdated
``FFI_PhysicalExtensionCodec`` ``PyCapsule`` or an object exposing
``__datafusion_physical_extension_codec__``.
"""
new_internal = self.ctx.with_physical_extension_codec(codec)
new = SessionContext.__new__(SessionContext)
Expand Down
42 changes: 34 additions & 8 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2727,14 +2727,24 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
return Expr(f.arrow_metadata(expr.expr, key.expr))


def get_field(expr: Expr, name: Expr | str) -> Expr:
"""Extracts a field from a struct or map by name.
def get_field(expr: Expr, *names: Expr | str) -> Expr:
Comment thread
timsaucer marked this conversation as resolved.
"""Extracts a (possibly nested) field from a struct or map by name.

When the field name is a static string, the bracket operator
``expr["field"]`` is a convenient shorthand. Use ``get_field``
when the field name is a dynamic expression.
Pass one name for a single-level lookup, or several names to walk a path
of nested struct/map fields in a single ``get_field`` call. For a single
static-string name, ``expr["field"]`` is a convenient shorthand; use
``get_field`` when the field name is a dynamic
:py:class:`~datafusion.expr.Expr` or when traversing multiple levels at
once.

Args:
expr: The struct or map expression to read from.
*names: One or more field names (``str``) or expressions
(:py:class:`~datafusion.expr.Expr`).

Examples:
Single-level lookup:

>>> ctx = dfn.SessionContext()
>>> df = ctx.from_pydict({"a": [1], "b": [2]})
>>> df = df.with_column(
Expand All @@ -2756,10 +2766,26 @@ def get_field(expr: Expr, name: Expr | str) -> Expr:
... )
>>> result.collect_column("x_val")[0].as_py()
1

Multi-level lookup:

>>> df = df.with_column(
... "outer",
... dfn.functions.named_struct([("inner", dfn.col("s"))]),
Comment thread
timsaucer marked this conversation as resolved.
Outdated
... )
>>> result = df.select(
... dfn.functions.get_field(
... dfn.col("outer"), "inner", "x"
... ).alias("x_val")
... )
>>> result.collect_column("x_val")[0].as_py()
1
"""
if isinstance(name, str):
name = Expr.string_literal(name)
return Expr(f.get_field(expr.expr, name.expr))
if not names:
msg = "get_field requires at least one field name"
raise ValueError(msg)
resolved = [Expr.string_literal(n) if isinstance(n, str) else n for n in names]
return Expr(f.get_field(expr.expr, [n.expr for n in resolved]))


def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
Expand Down
12 changes: 12 additions & 0 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
return value.__class__.__name__ == "PyCapsule"


class LogicalExtensionCodecExportable(Protocol):
"""Type hint for objects exposing ``__datafusion_logical_extension_codec__``."""

def __datafusion_logical_extension_codec__(self) -> object: ... # noqa: D105


class PhysicalExtensionCodecExportable(Protocol):
"""Type hint for objects exposing ``__datafusion_physical_extension_codec__``."""

def __datafusion_physical_extension_codec__(self) -> object: ... # noqa: D105


class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).

Expand Down
31 changes: 31 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,37 @@ def test_get_field(df):
assert result.column(1) == pa.array([4, 5, 6])


def test_get_field_path(df):
df = df.with_column(
"outer",
f.named_struct(
[
(
"inner",
f.named_struct(
[
("x", column("a")),
("y", column("b")),
]
),
),
]
),
)
result = df.select(
f.get_field(column("outer"), "inner", "x").alias("x_val"),
f.get_field(column("outer"), "inner", "y").alias("y_val"),
).collect()[0]

assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
assert result.column(1) == pa.array([4, 5, 6])


def test_get_field_requires_a_name():
with pytest.raises(ValueError, match="at least one field name"):
f.get_field(column("s"))


def test_arrow_metadata():
ctx = SessionContext()
field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})
Expand Down