Skip to content

Commit 5467eda

Browse files
timsaucer and claude committed
Change make_map to accept a Python dictionary
make_map now takes a dict for the common case and also supports separate keys/values lists for column expressions. Non-Expr keys and values are automatically converted to literals.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2cd3d94 commit 5467eda

File tree

2 files changed

+70
-68
lines changed

2 files changed

+70
-68
lines changed

python/datafusion/functions.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from __future__ import annotations
2020

21-
import builtins
2221
from typing import TYPE_CHECKING, Any
2322

2423
import pyarrow as pa
@@ -3384,29 +3383,47 @@ def empty(array: Expr) -> Expr:
33843383
# map functions
33853384

33863385

3387-
def make_map(*args: Expr) -> Expr:
3388-
"""Returns a map created from key and value expressions.
3386+
def make_map(
3387+
data: dict[Any, Any] | None = None,
3388+
keys: list[Any] | None = None,
3389+
values: list[Any] | None = None,
3390+
) -> Expr:
3391+
"""Returns a map expression.
3392+
3393+
Can be called with either a Python dictionary or separate ``keys``
3394+
and ``values`` lists. Keys and values that are not already
3395+
:py:class:`~datafusion.expr.Expr` are automatically converted to
3396+
literal expressions.
33893397
3390-
Accepts an even number of arguments, alternating between keys and values.
3391-
For example, ``make_map(k1, v1, k2, v2)`` creates a map ``{k1: v1, k2: v2}``.
3398+
Args:
3399+
data: A Python dictionary of key-value pairs.
3400+
keys: A list of keys (use with ``values`` for column expressions).
3401+
values: A list of values (use with ``keys``).
33923402
33933403
Examples:
33943404
>>> ctx = dfn.SessionContext()
33953405
>>> df = ctx.from_pydict({"a": [1]})
33963406
>>> result = df.select(
3397-
... dfn.functions.make_map(
3398-
... dfn.lit("a"), dfn.lit(1),
3399-
... dfn.lit("b"), dfn.lit(2),
3400-
... ).alias("map"))
3407+
... dfn.functions.make_map({"a": 1, "b": 2}).alias("map"))
34013408
>>> result.collect_column("map")[0].as_py()
34023409
[('a', 1), ('b', 2)]
34033410
"""
3404-
if len(args) % 2 != 0:
3405-
msg = "make_map requires an even number of arguments"
3411+
if data is not None:
3412+
if keys is not None or values is not None:
3413+
msg = "Cannot specify both data and keys/values"
3414+
raise ValueError(msg)
3415+
key_list = list(data.keys())
3416+
value_list = list(data.values())
3417+
elif keys is not None and values is not None:
3418+
key_list = keys
3419+
value_list = values
3420+
else:
3421+
msg = "Must specify either data or both keys and values"
34063422
raise ValueError(msg)
3407-
keys = [args[i].expr for i in builtins.range(0, len(args), 2)]
3408-
values = [args[i].expr for i in builtins.range(1, len(args), 2)]
3409-
return Expr(f.make_map(keys, values))
3423+
3424+
key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
3425+
val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
3426+
return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))
34103427

34113428

34123429
def map_keys(map: Expr) -> Expr:
@@ -3417,10 +3434,7 @@ def map_keys(map: Expr) -> Expr:
34173434
>>> df = ctx.from_pydict({"a": [1]})
34183435
>>> result = df.select(
34193436
... dfn.functions.map_keys(
3420-
... dfn.functions.make_map(
3421-
... dfn.lit("x"), dfn.lit(1),
3422-
... dfn.lit("y"), dfn.lit(2),
3423-
... )
3437+
... dfn.functions.make_map({"x": 1, "y": 2})
34243438
... ).alias("keys"))
34253439
>>> result.collect_column("keys")[0].as_py()
34263440
['x', 'y']
@@ -3436,10 +3450,7 @@ def map_values(map: Expr) -> Expr:
34363450
>>> df = ctx.from_pydict({"a": [1]})
34373451
>>> result = df.select(
34383452
... dfn.functions.map_values(
3439-
... dfn.functions.make_map(
3440-
... dfn.lit("x"), dfn.lit(1),
3441-
... dfn.lit("y"), dfn.lit(2),
3442-
... )
3453+
... dfn.functions.make_map({"x": 1, "y": 2})
34433454
... ).alias("vals"))
34443455
>>> result.collect_column("vals")[0].as_py()
34453456
[1, 2]
@@ -3455,10 +3466,7 @@ def map_extract(map: Expr, key: Expr) -> Expr:
34553466
>>> df = ctx.from_pydict({"a": [1]})
34563467
>>> result = df.select(
34573468
... dfn.functions.map_extract(
3458-
... dfn.functions.make_map(
3459-
... dfn.lit("x"), dfn.lit(1),
3460-
... dfn.lit("y"), dfn.lit(2),
3461-
... ),
3469+
... dfn.functions.make_map({"x": 1, "y": 2}),
34623470
... dfn.lit("x"),
34633471
... ).alias("val"))
34643472
>>> result.collect_column("val")[0].as_py()
@@ -3475,10 +3483,7 @@ def map_entries(map: Expr) -> Expr:
34753483
>>> df = ctx.from_pydict({"a": [1]})
34763484
>>> result = df.select(
34773485
... dfn.functions.map_entries(
3478-
... dfn.functions.make_map(
3479-
... dfn.lit("x"), dfn.lit(1),
3480-
... dfn.lit("y"), dfn.lit(2),
3481-
... )
3486+
... dfn.functions.make_map({"x": 1, "y": 2})
34823487
... ).alias("entries"))
34833488
>>> result.collect_column("entries")[0].as_py()
34843489
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]

python/tests/test_functions.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -673,32 +673,50 @@ def test_make_map():
673673
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
674674
df = ctx.create_dataframe([[batch]])
675675

676+
result = df.select(f.make_map({"x": 1, "y": 2}).alias("map")).collect()[0].column(0)
677+
assert result[0].as_py() == [("x", 1), ("y", 2)]
678+
679+
680+
def test_make_map_with_expr_values():
681+
ctx = SessionContext()
682+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
683+
df = ctx.create_dataframe([[batch]])
684+
676685
result = (
677-
df.select(
678-
f.make_map(
679-
literal("x"),
680-
literal(1),
681-
literal("y"),
682-
literal(2),
683-
).alias("map")
684-
)
686+
df.select(f.make_map({"x": literal(1), "y": literal(2)}).alias("map"))
685687
.collect()[0]
686688
.column(0)
687689
)
688690
assert result[0].as_py() == [("x", 1), ("y", 2)]
689691

690692

691-
def test_make_map_odd_args():
692-
with pytest.raises(ValueError, match="even number of arguments"):
693-
f.make_map(literal("x"), literal(1), literal("y"))
693+
def test_make_map_with_column_data():
694+
ctx = SessionContext()
695+
batch = pa.RecordBatch.from_arrays(
696+
[
697+
pa.array(["k1", "k2", "k3"]),
698+
pa.array([10, 20, 30]),
699+
],
700+
names=["keys", "vals"],
701+
)
702+
df = ctx.create_dataframe([[batch]])
703+
704+
m = f.make_map(keys=[column("keys")], values=[column("vals")])
705+
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
706+
for i, expected in enumerate(["k1", "k2", "k3"]):
707+
assert result[i].as_py() == [expected]
708+
709+
result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
710+
for i, expected in enumerate([10, 20, 30]):
711+
assert result[i].as_py() == [expected]
694712

695713

696714
def test_map_keys():
697715
ctx = SessionContext()
698716
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
699717
df = ctx.create_dataframe([[batch]])
700718

701-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
719+
m = f.make_map({"x": 1, "y": 2})
702720
result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0)
703721
assert result[0].as_py() == ["x", "y"]
704722

@@ -708,7 +726,7 @@ def test_map_values():
708726
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
709727
df = ctx.create_dataframe([[batch]])
710728

711-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
729+
m = f.make_map({"x": 1, "y": 2})
712730
result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0)
713731
assert result[0].as_py() == [1, 2]
714732

@@ -718,7 +736,7 @@ def test_map_extract():
718736
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
719737
df = ctx.create_dataframe([[batch]])
720738

721-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
739+
m = f.make_map({"x": 1, "y": 2})
722740
result = (
723741
df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0)
724742
)
@@ -730,7 +748,7 @@ def test_map_extract_missing_key():
730748
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
731749
df = ctx.create_dataframe([[batch]])
732750

733-
m = f.make_map(literal("x"), literal(1))
751+
m = f.make_map({"x": 1})
734752
result = (
735753
df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0)
736754
)
@@ -742,7 +760,7 @@ def test_map_entries():
742760
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
743761
df = ctx.create_dataframe([[batch]])
744762

745-
m = f.make_map(literal("x"), literal(1), literal("y"), literal(2))
763+
m = f.make_map({"x": 1, "y": 2})
746764
result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0)
747765
assert result[0].as_py() == [
748766
{"key": "x", "value": 1},
@@ -755,34 +773,13 @@ def test_element_at():
755773
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
756774
df = ctx.create_dataframe([[batch]])
757775

758-
m = f.make_map(literal("a"), literal(10), literal("b"), literal(20))
776+
m = f.make_map({"a": 10, "b": 20})
759777
result = (
760778
df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0)
761779
)
762780
assert result[0].as_py() == [20]
763781

764782

765-
def test_map_functions_with_column_data():
766-
ctx = SessionContext()
767-
batch = pa.RecordBatch.from_arrays(
768-
[
769-
pa.array(["k1", "k2", "k3"]),
770-
pa.array([10, 20, 30]),
771-
],
772-
names=["keys", "vals"],
773-
)
774-
df = ctx.create_dataframe([[batch]])
775-
776-
m = f.make_map(column("keys"), column("vals"))
777-
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
778-
for i, expected in enumerate(["k1", "k2", "k3"]):
779-
assert result[i].as_py() == [expected]
780-
781-
result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
782-
for i, expected in enumerate([10, 20, 30]):
783-
assert result[i].as_py() == [expected]
784-
785-
786783
@pytest.mark.parametrize(
787784
("function", "expected_result"),
788785
[

0 commit comments

Comments
 (0)