Skip to content

Commit ea2370a

Browse files
timsaucerclaude
andcommitted
Add tests for new scalar functions
Tests for get_field, arrow_metadata, version, row, union_tag, and union_extract. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 148f62e commit ea2370a

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

python/tests/test_functions.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,3 +1660,73 @@ def df_with_nulls():
16601660
def test_conditional_functions(df_with_nulls, expr, expected):
16611661
result = df_with_nulls.select(expr.alias("result")).collect()[0]
16621662
assert result.column(0) == expected
1663+
1664+
1665+
def test_get_field(df):
1666+
df = df.with_column(
1667+
"s",
1668+
f.named_struct(
1669+
[
1670+
("x", column("a")),
1671+
("y", column("b")),
1672+
]
1673+
),
1674+
)
1675+
result = df.select(
1676+
f.get_field(column("s"), string_literal("x")).alias("x_val"),
1677+
f.get_field(column("s"), string_literal("y")).alias("y_val"),
1678+
).collect()[0]
1679+
1680+
assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
1681+
assert result.column(1) == pa.array([4, 5, 6])
1682+
1683+
1684+
def test_arrow_metadata(df):
1685+
result = df.select(
1686+
f.arrow_metadata(column("a")).alias("meta"),
1687+
).collect()[0]
1688+
# The metadata column should be returned as a map type (possibly empty)
1689+
assert result.column(0).type == pa.map_(pa.utf8(), pa.utf8())
1690+
1691+
1692+
def test_version():
1693+
ctx = SessionContext()
1694+
df = ctx.from_pydict({"a": [1]})
1695+
result = df.select(f.version().alias("v")).collect()[0]
1696+
version_str = result.column(0)[0].as_py()
1697+
assert "Apache DataFusion" in version_str
1698+
1699+
1700+
def test_row(df):
1701+
result = df.select(
1702+
f.row(column("a"), column("b")).alias("r"),
1703+
f.struct(column("a"), column("b")).alias("s"),
1704+
).collect()[0]
1705+
# row is an alias for struct, so they should produce the same output
1706+
assert result.column(0) == result.column(1)
1707+
1708+
1709+
def test_union_tag():
1710+
ctx = SessionContext()
1711+
types = pa.array([0, 1, 0], type=pa.int8())
1712+
offsets = pa.array([0, 0, 1], type=pa.int32())
1713+
children = [pa.array([1, 2]), pa.array(["hello"])]
1714+
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
1715+
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])
1716+
1717+
result = df.select(f.union_tag(column("u")).alias("tag")).collect()[0]
1718+
assert result.column(0).to_pylist() == ["int", "str", "int"]
1719+
1720+
1721+
def test_union_extract():
1722+
ctx = SessionContext()
1723+
types = pa.array([0, 1, 0], type=pa.int8())
1724+
offsets = pa.array([0, 0, 1], type=pa.int32())
1725+
children = [pa.array([1, 2]), pa.array(["hello"])]
1726+
arr = pa.UnionArray.from_dense(types, offsets, children, ["int", "str"], [0, 1])
1727+
df = ctx.create_dataframe([[pa.RecordBatch.from_arrays([arr], names=["u"])]])
1728+
1729+
result = df.select(
1730+
f.union_extract(column("u"), string_literal("int")).alias("val")
1731+
).collect()[0]
1732+
assert result.column(0).to_pylist() == [1, None, 2]

0 commit comments

Comments
 (0)