Skip to content

Commit 8e51ce7

Browse files
max-sixtyclaude
andcommitted
Fix Dataset.map to handle non-DataArray outputs
After PR pydata#10602, Dataset.map started failing when functions returned non-DataArray values (e.g., scalars), raising AttributeError when trying to access .coords on the returned values. This restores backward compatibility by converting non-DataArray outputs to DataArrays, which was the behavior before the regression. Fixes pydata#10835 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 43a7f1e commit 8e51ce7

2 files changed

Lines changed: 32 additions & 0 deletions

File tree

xarray/core/dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6923,12 +6923,19 @@ def map(
69236923
foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773
69246924
bar (x) float64 16B 1.0 2.0
69256925
"""
6926+
from xarray.core.dataarray import DataArray
6927+
69266928
if keep_attrs is None:
69276929
keep_attrs = _get_keep_attrs(default=False)
69286930
variables = {
69296931
k: maybe_wrap_array(v, func(v, *args, **kwargs))
69306932
for k, v in self.data_vars.items()
69316933
}
6934+
# Convert non-DataArray values to DataArrays
6935+
variables = {
6936+
k: v if isinstance(v, DataArray) else DataArray(v)
6937+
for k, v in variables.items()
6938+
}
69326939
coord_vars, indexes = merge_coordinates_without_align(
69336940
[v.coords for v in variables.values()]
69346941
)

xarray/tests/test_dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6268,6 +6268,31 @@ def func(arr):
62686268
ds["x"].attrs["y"] = "x"
62696269
assert ds["x"].attrs != actual["x"].attrs
62706270

6271+
def test_map_non_dataarray_outputs(self) -> None:
6272+
# Test that map handles non-DataArray outputs by converting them
6273+
# Regression test for GH10835
6274+
ds = xr.Dataset({"foo": ("x", [1, 2, 3]), "bar": ("y", [4, 5])})
6275+
6276+
# Scalar output
6277+
result = ds.map(lambda x: 1)
6278+
expected = xr.Dataset({"foo": 1, "bar": 1})
6279+
assert_identical(result, expected)
6280+
6281+
# Numpy array output with same shape
6282+
result = ds.map(lambda x: x.values)
6283+
expected = ds.copy()
6284+
assert_identical(result, expected)
6285+
6286+
# Mixed: some return scalars, some return arrays
6287+
def mixed_func(x):
6288+
if "x" in x.dims:
6289+
return 42
6290+
return x
6291+
6292+
result = ds.map(mixed_func)
6293+
expected = xr.Dataset({"foo": 42, "bar": ("y", [4, 5])})
6294+
assert_identical(result, expected)
6295+
62716296
def test_apply_pending_deprecated_map(self) -> None:
62726297
data = create_test_data()
62736298
data.attrs["foo"] = "bar"

0 commit comments

Comments
 (0)