Skip to content

Commit df1ead1

Browse files
timsaucerclaude
andcommitted
Accept str for key parameter in arrow_metadata for consistency
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 02eb255 commit df1ead1

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

python/datafusion/functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,19 +2619,24 @@ def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
26192619
return Expr(f.arrow_cast(expr.expr, data_type.expr))
26202620

26212621

2622-
def arrow_metadata(*args: Expr) -> Expr:
2622+
def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
26232623
"""Returns the metadata of the input expression.
26242624
26252625
If called with one argument, returns a Map of all metadata key-value pairs.
26262626
If called with two arguments, returns the value for the specified metadata key.
26272627
26282628
Args:
2629-
args: An expression, optionally followed by a metadata key string.
2629+
expr: An expression whose metadata to retrieve.
2630+
key: Optional metadata key to look up. Can be a string or an Expr.
26302631
26312632
Returns:
26322633
A Map of metadata or a specific metadata value.
26332634
"""
2634-
return Expr(f.arrow_metadata(*[arg.expr for arg in args]))
2635+
if key is None:
2636+
return Expr(f.arrow_metadata(expr.expr))
2637+
if isinstance(key, str):
2638+
key = Expr.string_literal(key)
2639+
return Expr(f.arrow_metadata(expr.expr, key.expr))
26352640

26362641

26372642
def get_field(expr: Expr, name: Expr | str) -> Expr:

python/tests/test_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import numpy as np
2121
import pyarrow as pa
2222
import pytest
23-
from datafusion import SessionContext, column, literal, string_literal
23+
from datafusion import SessionContext, column, literal
2424
from datafusion import functions as f
2525

2626
np.seterr(invalid="ignore")
@@ -1696,7 +1696,7 @@ def test_arrow_metadata():
16961696

16971697
# Two-argument form: returns the value for a specific metadata key
16981698
result = df.select(
1699-
f.arrow_metadata(column("val"), string_literal("key1")).alias("meta_val"),
1699+
f.arrow_metadata(column("val"), "key1").alias("meta_val"),
17001700
).collect()[0]
17011701
assert result.column(0)[0].as_py() == "value1"
17021702

0 commit comments

Comments
 (0)