Skip to content

Commit 1688aef

Browse files
timsaucer and claude committed
Make map the primary function with make_map as alias
map() now supports three calling conventions matching upstream:

- map({"a": 1, "b": 2}) — from a Python dictionary
- map([keys], [values]) — two lists that get zipped
- map(k1, v1, k2, v2, ...) — variadic key-value pairs

Non-Expr keys and values are automatically converted to literals.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5467eda commit 1688aef

File tree

2 files changed

+85
-41
lines changed

2 files changed

+85
-41
lines changed

python/datafusion/functions.py

Lines changed: 37 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -3383,49 +3383,57 @@ def empty(array: Expr) -> Expr:
33833383
# map functions
33843384

33853385

3386-
def make_map(
3387-
data: dict[Any, Any] | None = None,
3388-
keys: list[Any] | None = None,
3389-
values: list[Any] | None = None,
3390-
) -> Expr:
3386+
def map(*args: Any) -> Expr:
33913387
"""Returns a map expression.
33923388
3393-
Can be called with either a Python dictionary or separate ``keys``
3394-
and ``values`` lists. Keys and values that are not already
3395-
:py:class:`~datafusion.expr.Expr` are automatically converted to
3396-
literal expressions.
3389+
Supports three calling conventions:
33973390
3398-
Args:
3399-
data: A Python dictionary of key-value pairs.
3400-
keys: A list of keys (use with ``values`` for column expressions).
3401-
values: A list of values (use with ``keys``).
3391+
- ``map({"a": 1, "b": 2})`` — from a Python dictionary.
3392+
- ``map([keys], [values])`` — two lists that get zipped.
3393+
- ``map(k1, v1, k2, v2, ...)`` — variadic key-value pairs.
3394+
3395+
Keys and values that are not already :py:class:`~datafusion.expr.Expr`
3396+
are automatically converted to literal expressions.
34023397
34033398
Examples:
34043399
>>> ctx = dfn.SessionContext()
34053400
>>> df = ctx.from_pydict({"a": [1]})
34063401
>>> result = df.select(
3407-
... dfn.functions.make_map({"a": 1, "b": 2}).alias("map"))
3408-
>>> result.collect_column("map")[0].as_py()
3402+
... dfn.functions.map({"a": 1, "b": 2}).alias("m"))
3403+
>>> result.collect_column("m")[0].as_py()
34093404
[('a', 1), ('b', 2)]
34103405
"""
3411-
if data is not None:
3412-
if keys is not None or values is not None:
3413-
msg = "Cannot specify both data and keys/values"
3414-
raise ValueError(msg)
3415-
key_list = list(data.keys())
3416-
value_list = list(data.values())
3417-
elif keys is not None and values is not None:
3418-
key_list = keys
3419-
value_list = values
3406+
if len(args) == 1 and isinstance(args[0], dict):
3407+
key_list = list(args[0].keys())
3408+
value_list = list(args[0].values())
3409+
elif (
3410+
len(args) == 2 # noqa: PLR2004
3411+
and isinstance(args[0], list)
3412+
and isinstance(args[1], list)
3413+
):
3414+
key_list = args[0]
3415+
value_list = args[1]
3416+
elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004
3417+
key_list = list(args[0::2])
3418+
value_list = list(args[1::2])
34203419
else:
3421-
msg = "Must specify either data or both keys and values"
3420+
msg = "map expects a dict, two lists, or an even number of key-value arguments"
34223421
raise ValueError(msg)
34233422

34243423
key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
34253424
val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
34263425
return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))
34273426

34283427

3428+
def make_map(*args: Any) -> Expr:
3429+
"""Returns a map expression.
3430+
3431+
See Also:
3432+
This is an alias for :py:func:`map`.
3433+
"""
3434+
return map(*args)
3435+
3436+
34293437
def map_keys(map: Expr) -> Expr:
34303438
"""Returns a list of all keys in the map.
34313439
@@ -3434,7 +3442,7 @@ def map_keys(map: Expr) -> Expr:
34343442
>>> df = ctx.from_pydict({"a": [1]})
34353443
>>> result = df.select(
34363444
... dfn.functions.map_keys(
3437-
... dfn.functions.make_map({"x": 1, "y": 2})
3445+
... dfn.functions.map({"x": 1, "y": 2})
34383446
... ).alias("keys"))
34393447
>>> result.collect_column("keys")[0].as_py()
34403448
['x', 'y']
@@ -3450,7 +3458,7 @@ def map_values(map: Expr) -> Expr:
34503458
>>> df = ctx.from_pydict({"a": [1]})
34513459
>>> result = df.select(
34523460
... dfn.functions.map_values(
3453-
... dfn.functions.make_map({"x": 1, "y": 2})
3461+
... dfn.functions.map({"x": 1, "y": 2})
34543462
... ).alias("vals"))
34553463
>>> result.collect_column("vals")[0].as_py()
34563464
[1, 2]
@@ -3466,7 +3474,7 @@ def map_extract(map: Expr, key: Expr) -> Expr:
34663474
>>> df = ctx.from_pydict({"a": [1]})
34673475
>>> result = df.select(
34683476
... dfn.functions.map_extract(
3469-
... dfn.functions.make_map({"x": 1, "y": 2}),
3477+
... dfn.functions.map({"x": 1, "y": 2}),
34703478
... dfn.lit("x"),
34713479
... ).alias("val"))
34723480
>>> result.collect_column("val")[0].as_py()
@@ -3483,7 +3491,7 @@ def map_entries(map: Expr) -> Expr:
34833491
>>> df = ctx.from_pydict({"a": [1]})
34843492
>>> result = df.select(
34853493
... dfn.functions.map_entries(
3486-
... dfn.functions.make_map({"x": 1, "y": 2})
3494+
... dfn.functions.map({"x": 1, "y": 2})
34873495
... ).alias("entries"))
34883496
>>> result.collect_column("entries")[0].as_py()
34893497
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]

python/tests/test_functions.py

Lines changed: 48 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -668,29 +668,29 @@ def test_array_function_obj_tests(stmt, py_expr):
668668
assert a == b
669669

670670

671-
def test_make_map():
671+
def test_map_from_dict():
672672
ctx = SessionContext()
673673
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
674674
df = ctx.create_dataframe([[batch]])
675675

676-
result = df.select(f.make_map({"x": 1, "y": 2}).alias("map")).collect()[0].column(0)
676+
result = df.select(f.map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0)
677677
assert result[0].as_py() == [("x", 1), ("y", 2)]
678678

679679

680-
def test_make_map_with_expr_values():
680+
def test_map_from_dict_with_expr_values():
681681
ctx = SessionContext()
682682
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
683683
df = ctx.create_dataframe([[batch]])
684684

685685
result = (
686-
df.select(f.make_map({"x": literal(1), "y": literal(2)}).alias("map"))
686+
df.select(f.map({"x": literal(1), "y": literal(2)}).alias("m"))
687687
.collect()[0]
688688
.column(0)
689689
)
690690
assert result[0].as_py() == [("x", 1), ("y", 2)]
691691

692692

693-
def test_make_map_with_column_data():
693+
def test_map_from_two_lists():
694694
ctx = SessionContext()
695695
batch = pa.RecordBatch.from_arrays(
696696
[
@@ -701,7 +701,7 @@ def test_make_map_with_column_data():
701701
)
702702
df = ctx.create_dataframe([[batch]])
703703

704-
m = f.make_map(keys=[column("keys")], values=[column("vals")])
704+
m = f.map([column("keys")], [column("vals")])
705705
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
706706
for i, expected in enumerate(["k1", "k2", "k3"]):
707707
assert result[i].as_py() == [expected]
@@ -711,12 +711,48 @@ def test_make_map_with_column_data():
711711
assert result[i].as_py() == [expected]
712712

713713

714+
def test_map_from_variadic_pairs():
715+
ctx = SessionContext()
716+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
717+
df = ctx.create_dataframe([[batch]])
718+
719+
result = df.select(f.map("x", 1, "y", 2).alias("m")).collect()[0].column(0)
720+
assert result[0].as_py() == [("x", 1), ("y", 2)]
721+
722+
723+
def test_map_variadic_with_exprs():
724+
ctx = SessionContext()
725+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
726+
df = ctx.create_dataframe([[batch]])
727+
728+
result = (
729+
df.select(f.map(literal("x"), literal(1), literal("y"), literal(2)).alias("m"))
730+
.collect()[0]
731+
.column(0)
732+
)
733+
assert result[0].as_py() == [("x", 1), ("y", 2)]
734+
735+
736+
def test_map_odd_args_raises():
737+
with pytest.raises(ValueError, match="map expects"):
738+
f.map("x", 1, "y")
739+
740+
741+
def test_make_map_is_alias():
742+
ctx = SessionContext()
743+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
744+
df = ctx.create_dataframe([[batch]])
745+
746+
result = df.select(f.make_map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0)
747+
assert result[0].as_py() == [("x", 1), ("y", 2)]
748+
749+
714750
def test_map_keys():
715751
ctx = SessionContext()
716752
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
717753
df = ctx.create_dataframe([[batch]])
718754

719-
m = f.make_map({"x": 1, "y": 2})
755+
m = f.map({"x": 1, "y": 2})
720756
result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0)
721757
assert result[0].as_py() == ["x", "y"]
722758

@@ -726,7 +762,7 @@ def test_map_values():
726762
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
727763
df = ctx.create_dataframe([[batch]])
728764

729-
m = f.make_map({"x": 1, "y": 2})
765+
m = f.map({"x": 1, "y": 2})
730766
result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0)
731767
assert result[0].as_py() == [1, 2]
732768

@@ -736,7 +772,7 @@ def test_map_extract():
736772
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
737773
df = ctx.create_dataframe([[batch]])
738774

739-
m = f.make_map({"x": 1, "y": 2})
775+
m = f.map({"x": 1, "y": 2})
740776
result = (
741777
df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0)
742778
)
@@ -748,7 +784,7 @@ def test_map_extract_missing_key():
748784
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
749785
df = ctx.create_dataframe([[batch]])
750786

751-
m = f.make_map({"x": 1})
787+
m = f.map({"x": 1})
752788
result = (
753789
df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0)
754790
)
@@ -760,7 +796,7 @@ def test_map_entries():
760796
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
761797
df = ctx.create_dataframe([[batch]])
762798

763-
m = f.make_map({"x": 1, "y": 2})
799+
m = f.map({"x": 1, "y": 2})
764800
result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0)
765801
assert result[0].as_py() == [
766802
{"key": "x", "value": 1},
@@ -773,7 +809,7 @@ def test_element_at():
773809
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
774810
df = ctx.create_dataframe([[batch]])
775811

776-
m = f.make_map({"a": 10, "b": 20})
812+
m = f.map({"a": 10, "b": 20})
777813
result = (
778814
df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0)
779815
)

0 commit comments

Comments
 (0)