Skip to content

Commit 4384c1f

Browse files
timsaucerclaude
andcommitted
Accept str for key parameter in arrow_metadata for consistency
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2771621 commit 4384c1f

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

python/datafusion/functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2550,19 +2550,24 @@ def arrow_cast(expr: Expr, data_type: Expr | str) -> Expr:
25502550
return Expr(f.arrow_cast(expr.expr, data_type.expr))
25512551

25522552

2553-
def arrow_metadata(*args: Expr) -> Expr:
2553+
def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
25542554
"""Returns the metadata of the input expression.
25552555
25562556
If called with one argument, returns a Map of all metadata key-value pairs.
25572557
If called with two arguments, returns the value for the specified metadata key.
25582558
25592559
Args:
2560-
args: An expression, optionally followed by a metadata key string.
2560+
expr: An expression whose metadata to retrieve.
2561+
key: Optional metadata key to look up. Can be a string or an Expr.
25612562
25622563
Returns:
25632564
A Map of metadata or a specific metadata value.
25642565
"""
2565-
return Expr(f.arrow_metadata(*[arg.expr for arg in args]))
2566+
if key is None:
2567+
return Expr(f.arrow_metadata(expr.expr))
2568+
if isinstance(key, str):
2569+
key = Expr.string_literal(key)
2570+
return Expr(f.arrow_metadata(expr.expr, key.expr))
25662571

25672572

25682573
def get_field(expr: Expr, name: Expr | str) -> Expr:

python/tests/test_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import numpy as np
2121
import pyarrow as pa
2222
import pytest
23-
from datafusion import SessionContext, column, literal, string_literal
23+
from datafusion import SessionContext, column, literal
2424
from datafusion import functions as f
2525

2626
np.seterr(invalid="ignore")
@@ -1505,7 +1505,7 @@ def test_arrow_metadata():
15051505

15061506
# Two-argument form: returns the value for a specific metadata key
15071507
result = df.select(
1508-
f.arrow_metadata(column("val"), string_literal("key1")).alias("meta_val"),
1508+
f.arrow_metadata(column("val"), "key1").alias("meta_val"),
15091509
).collect()[0]
15101510
assert result.column(0)[0].as_py() == "value1"
15111511

0 commit comments

Comments
 (0)