Skip to content

Commit 60786b0

Browse files
committed
Merge datasets pr top level param in xarray export
1 parent 4629dd3 commit 60786b0

2 files changed

Lines changed: 34 additions & 50 deletions

File tree

src/qcodes/dataset/data_set_cache.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def data(self) -> ParameterData:
7979
Loads data from the database on disk if needed and returns
8080
the cached data. The cached data is in almost the same format as
8181
:py:class:`.DataSet.get_parameter_data`. However if a shape is provided
82-
as part of the dataset metadata and fewer datapoints than expected are
82+
as part of the dataset metadata and fewer data points than expected are
8383
returned the missing values will be replaced by `NaN` or zeroes
8484
depending on the datatype.
8585
@@ -118,7 +118,7 @@ def _empty_data_dict(
118118

119119
def prepare(self) -> None:
120120
"""
121-
Set up the internal datastructure of the cache.
121+
Set up the internal data structure of the cache.
122122
Must be called after the dataset has been setup with
123123
interdependencies but before data is added to the dataset.
124124
"""
@@ -200,9 +200,10 @@ def to_xarray_dataarray_dict(
200200
201201
"""
202202
data = self.data()
203-
return load_to_xarray_dataarray_dict(
203+
data_dict = load_to_xarray_dataarray_dict(
204204
self._dataset, data, use_multi_index=use_multi_index
205205
)
206+
return data_dict
206207

207208
def to_xarray_dataset(
208209
self, *, use_multi_index: Literal["auto", "always", "never"] = "auto"
@@ -489,11 +490,11 @@ def load_data_from_db(self) -> None:
489490
)
490491

491492
def _load_xr_dataset(self) -> xr.Dataset:
492-
import cf_xarray as cfxr
493+
import cf_xarray as cf_xr
493494
import xarray as xr
494495

495496
loaded_data = xr.load_dataset(self._xr_dataset_path, engine="h5netcdf")
496-
loaded_data = cfxr.coding.decode_compress_to_multi_index(loaded_data)
497+
loaded_data = cf_xr.coding.decode_compress_to_multi_index(loaded_data)
497498
export_info = ExportInfo.from_str(loaded_data.attrs.get("export_info", ""))
498499
export_info.export_paths["nc"] = str(self._xr_dataset_path)
499500
loaded_data.attrs["export_info"] = export_info.to_str()

src/qcodes/dataset/exporters/export_to_xarray.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from importlib.metadata import version
66
from math import prod
7-
from typing import TYPE_CHECKING, Literal, cast
7+
from typing import TYPE_CHECKING, Literal
88

99
from packaging import version as pversion
1010

@@ -61,20 +61,18 @@ def _calculate_index_shape(idx: pd.Index | pd.MultiIndex) -> dict[Hashable, int]
6161
return expanded_shape
6262

6363

64-
def _load_to_xarray_dataarray_dict_no_metadata(
64+
def _load_to_xarray_dataset_dict_no_metadata(
6565
dataset: DataSetProtocol,
6666
datadict: Mapping[str, Mapping[str, npt.NDArray]],
6767
*,
6868
use_multi_index: Literal["auto", "always", "never"] = "auto",
69-
) -> dict[str, xr.DataArray]:
70-
import xarray as xr
71-
69+
) -> dict[str, xr.Dataset]:
7270
if use_multi_index not in ("auto", "always", "never"):
7371
raise ValueError(
7472
f"Invalid value for use_multi_index. Expected one of 'auto', 'always', 'never' but got {use_multi_index}"
7573
)
7674

77-
data_xrdarray_dict: dict[str, xr.DataArray] = {}
75+
data_xrdarray_dict: dict[str, xr.Dataset] = {}
7876

7977
for name, subdict in datadict.items():
8078
shape_is_consistent = (
@@ -96,11 +94,9 @@ def _load_to_xarray_dataarray_dict_no_metadata(
9694
)
9795

9896
if index is None:
99-
xrdarray: xr.DataArray = (
100-
_data_to_dataframe(subdict, index=index)
101-
.to_xarray()
102-
.get(name, xr.DataArray())
103-
)
97+
xrdarray: xr.Dataset = _data_to_dataframe(
98+
subdict, index=index
99+
).to_xarray()
104100
data_xrdarray_dict[name] = xrdarray
105101
elif index_is_unique:
106102
df = _data_to_dataframe(subdict, index)
@@ -109,9 +105,7 @@ def _load_to_xarray_dataarray_dict_no_metadata(
109105
)
110106
else:
111107
df = _data_to_dataframe(subdict, index)
112-
xrdata_temp = df.reset_index().to_xarray()
113-
for _name in subdict:
114-
data_xrdarray_dict[_name] = xrdata_temp[_name]
108+
data_xrdarray_dict[name] = df.reset_index().to_xarray()
115109

116110
return data_xrdarray_dict
117111

@@ -122,7 +116,7 @@ def _xarray_data_array_from_pandas_multi_index(
122116
name: str,
123117
df: pd.DataFrame,
124118
index: pd.Index | pd.MultiIndex,
125-
) -> xr.DataArray:
119+
) -> xr.Dataset:
126120
import pandas as pd
127121
import xarray as xr
128122

@@ -148,16 +142,16 @@ def _xarray_data_array_from_pandas_multi_index(
148142
)
149143

150144
coords = xr.Coordinates.from_pandas_multiindex(df.index, "multi_index")
151-
xrdarray = xr.DataArray(df[name], coords=coords)
145+
xrdarray = xr.DataArray(df[name], coords=coords).to_dataset(name=name)
152146
else:
153-
xrdarray = df.to_xarray().get(name, xr.DataArray())
147+
xrdarray = df.to_xarray()
154148

155149
return xrdarray
156150

157151

158152
def _xarray_data_array_direct(
159153
dataset: DataSetProtocol, name: str, subdict: Mapping[str, npt.NDArray]
160-
) -> xr.DataArray:
154+
) -> xr.Dataset:
161155
import xarray as xr
162156

163157
meas_paramspec = dataset.description.interdeps.graph.nodes[name]["value"]
@@ -213,12 +207,7 @@ def _xarray_data_array_direct(
213207
data_vars.update(extra_data_vars)
214208

215209
ds = xr.Dataset(data_vars, coords=coords)
216-
da = ds[name]
217-
if len(extra_data_vars) > 0:
218-
# stash extra data vars to be added at dataset assembly time
219-
# mapping: var_name -> (dims_tuple, numpy array)
220-
da.attrs["_qcodes_extra_data_vars"] = extra_data_vars
221-
return da
210+
return ds
222211

223212

224213
def load_to_xarray_dataarray_dict(
@@ -227,17 +216,20 @@ def load_to_xarray_dataarray_dict(
227216
*,
228217
use_multi_index: Literal["auto", "always", "never"] = "auto",
229218
) -> dict[str, xr.DataArray]:
230-
dataarrays = _load_to_xarray_dataarray_dict_no_metadata(
219+
xr_datasets = _load_to_xarray_dataset_dict_no_metadata(
231220
dataset, datadict, use_multi_index=use_multi_index
232221
)
222+
data_arrays: dict[str, xr.DataArray] = {}
233223

234-
for dataname, dataarray in dataarrays.items():
235-
_add_param_spec_to_xarray_coords(dataset, dataarray)
224+
for dataname, xr_dataset in xr_datasets.items():
225+
data_array = xr_dataset[dataname]
226+
_add_param_spec_to_xarray_coords(dataset, data_array)
236227
paramspec_dict = _paramspec_dict_with_extras(dataset, str(dataname))
237-
dataarray.attrs.update(paramspec_dict.items())
238-
_add_metadata_to_xarray(dataset, dataarray)
228+
data_array.attrs.update(paramspec_dict.items())
229+
_add_metadata_to_xarray(dataset, data_array)
230+
data_arrays[dataname] = data_array
239231

240-
return dataarrays
232+
return data_arrays
241233

242234

243235
def _add_metadata_to_xarray(
@@ -276,26 +268,17 @@ def load_to_xarray_dataset(
276268
) -> xr.Dataset:
277269
import xarray as xr
278270

279-
data_xrdarray_dict = _load_to_xarray_dataarray_dict_no_metadata(
271+
xr_dataset_dict = _load_to_xarray_dataset_dict_no_metadata(
280272
dataset, data, use_multi_index=use_multi_index
281273
)
282274

283-
# Casting Hashable for the key type until python/mypy#1114
284-
# and python/typing#445 are resolved.
285-
xrdataset = xr.Dataset(cast("dict[Hashable, xr.DataArray]", data_xrdarray_dict))
286-
287-
# add any stashed extra data variables created during direct export
288-
for _, dataarray in data_xrdarray_dict.items():
289-
extras = dataarray.attrs.pop("_qcodes_extra_data_vars", None)
290-
if isinstance(extras, dict):
291-
for var_name, (dims, values) in extras.items():
292-
xrdataset[var_name] = (dims, values)
275+
xr_dataset = xr.merge(xr_dataset_dict.values(), compat="equals", join="outer")
293276

294-
_add_param_spec_to_xarray_coords(dataset, xrdataset)
295-
_add_param_spec_to_xarray_data_vars(dataset, xrdataset)
296-
_add_metadata_to_xarray(dataset, xrdataset)
277+
_add_param_spec_to_xarray_coords(dataset, xr_dataset)
278+
_add_param_spec_to_xarray_data_vars(dataset, xr_dataset)
279+
_add_metadata_to_xarray(dataset, xr_dataset)
297280

298-
return xrdataset
281+
return xr_dataset
299282

300283

301284
def _add_param_spec_to_xarray_coords(

0 commit comments

Comments
 (0)