44import warnings
55from importlib .metadata import version
66from math import prod
7- from typing import TYPE_CHECKING , Literal , cast
7+ from typing import TYPE_CHECKING , Literal
88
99from packaging import version as pversion
1010
@@ -61,20 +61,18 @@ def _calculate_index_shape(idx: pd.Index | pd.MultiIndex) -> dict[Hashable, int]
6161 return expanded_shape
6262
6363
64- def _load_to_xarray_dataarray_dict_no_metadata (
64+ def _load_to_xarray_dataset_dict_no_metadata (
6565 dataset : DataSetProtocol ,
6666 datadict : Mapping [str , Mapping [str , npt .NDArray ]],
6767 * ,
6868 use_multi_index : Literal ["auto" , "always" , "never" ] = "auto" ,
69- ) -> dict [str , xr .DataArray ]:
70- import xarray as xr
71-
69+ ) -> dict [str , xr .Dataset ]:
7270 if use_multi_index not in ("auto" , "always" , "never" ):
7371 raise ValueError (
7472 f"Invalid value for use_multi_index. Expected one of 'auto', 'always', 'never' but got { use_multi_index } "
7573 )
7674
77- data_xrdarray_dict : dict [str , xr .DataArray ] = {}
75+ data_xrdarray_dict : dict [str , xr .Dataset ] = {}
7876
7977 for name , subdict in datadict .items ():
8078 shape_is_consistent = (
@@ -96,11 +94,9 @@ def _load_to_xarray_dataarray_dict_no_metadata(
9694 )
9795
9896 if index is None :
99- xrdarray : xr .DataArray = (
100- _data_to_dataframe (subdict , index = index )
101- .to_xarray ()
102- .get (name , xr .DataArray ())
103- )
97+ xrdarray : xr .Dataset = _data_to_dataframe (
98+ subdict , index = index
99+ ).to_xarray ()
104100 data_xrdarray_dict [name ] = xrdarray
105101 elif index_is_unique :
106102 df = _data_to_dataframe (subdict , index )
@@ -109,9 +105,7 @@ def _load_to_xarray_dataarray_dict_no_metadata(
109105 )
110106 else :
111107 df = _data_to_dataframe (subdict , index )
112- xrdata_temp = df .reset_index ().to_xarray ()
113- for _name in subdict :
114- data_xrdarray_dict [_name ] = xrdata_temp [_name ]
108+ data_xrdarray_dict [name ] = df .reset_index ().to_xarray ()
115109
116110 return data_xrdarray_dict
117111
@@ -122,7 +116,7 @@ def _xarray_data_array_from_pandas_multi_index(
122116 name : str ,
123117 df : pd .DataFrame ,
124118 index : pd .Index | pd .MultiIndex ,
125- ) -> xr .DataArray :
119+ ) -> xr .Dataset :
126120 import pandas as pd
127121 import xarray as xr
128122
@@ -148,16 +142,16 @@ def _xarray_data_array_from_pandas_multi_index(
148142 )
149143
150144 coords = xr .Coordinates .from_pandas_multiindex (df .index , "multi_index" )
151- xrdarray = xr .DataArray (df [name ], coords = coords )
145+ xrdarray = xr .DataArray (df [name ], coords = coords ). to_dataset ( name = name )
152146 else :
153- xrdarray = df .to_xarray (). get ( name , xr . DataArray ())
147+ xrdarray = df .to_xarray ()
154148
155149 return xrdarray
156150
157151
158152def _xarray_data_array_direct (
159153 dataset : DataSetProtocol , name : str , subdict : Mapping [str , npt .NDArray ]
160- ) -> xr .DataArray :
154+ ) -> xr .Dataset :
161155 import xarray as xr
162156
163157 meas_paramspec = dataset .description .interdeps .graph .nodes [name ]["value" ]
@@ -213,12 +207,7 @@ def _xarray_data_array_direct(
213207 data_vars .update (extra_data_vars )
214208
215209 ds = xr .Dataset (data_vars , coords = coords )
216- da = ds [name ]
217- if len (extra_data_vars ) > 0 :
218- # stash extra data vars to be added at dataset assembly time
219- # mapping: var_name -> (dims_tuple, numpy array)
220- da .attrs ["_qcodes_extra_data_vars" ] = extra_data_vars
221- return da
210+ return ds
222211
223212
224213def load_to_xarray_dataarray_dict (
@@ -227,17 +216,20 @@ def load_to_xarray_dataarray_dict(
227216 * ,
228217 use_multi_index : Literal ["auto" , "always" , "never" ] = "auto" ,
229218) -> dict [str , xr .DataArray ]:
230- dataarrays = _load_to_xarray_dataarray_dict_no_metadata (
219+ xr_datasets = _load_to_xarray_dataset_dict_no_metadata (
231220 dataset , datadict , use_multi_index = use_multi_index
232221 )
222+ data_arrays : dict [str , xr .DataArray ] = {}
233223
234- for dataname , dataarray in dataarrays .items ():
235- _add_param_spec_to_xarray_coords (dataset , dataarray )
224+ for dataname , xr_dataset in xr_datasets .items ():
225+ data_array = xr_dataset [dataname ]
226+ _add_param_spec_to_xarray_coords (dataset , data_array )
236227 paramspec_dict = _paramspec_dict_with_extras (dataset , str (dataname ))
237- dataarray .attrs .update (paramspec_dict .items ())
238- _add_metadata_to_xarray (dataset , dataarray )
228+ data_array .attrs .update (paramspec_dict .items ())
229+ _add_metadata_to_xarray (dataset , data_array )
230+ data_arrays [dataname ] = data_array
239231
240- return dataarrays
232+ return data_arrays
241233
242234
243235def _add_metadata_to_xarray (
@@ -276,26 +268,17 @@ def load_to_xarray_dataset(
276268) -> xr .Dataset :
277269 import xarray as xr
278270
279- data_xrdarray_dict = _load_to_xarray_dataarray_dict_no_metadata (
271+ xr_dataset_dict = _load_to_xarray_dataset_dict_no_metadata (
280272 dataset , data , use_multi_index = use_multi_index
281273 )
282274
283- # Casting Hashable for the key type until python/mypy#1114
284- # and python/typing#445 are resolved.
285- xrdataset = xr .Dataset (cast ("dict[Hashable, xr.DataArray]" , data_xrdarray_dict ))
286-
287- # add any stashed extra data variables created during direct export
288- for _ , dataarray in data_xrdarray_dict .items ():
289- extras = dataarray .attrs .pop ("_qcodes_extra_data_vars" , None )
290- if isinstance (extras , dict ):
291- for var_name , (dims , values ) in extras .items ():
292- xrdataset [var_name ] = (dims , values )
275+ xr_dataset = xr .merge (xr_dataset_dict .values (), compat = "equals" , join = "outer" )
293276
294- _add_param_spec_to_xarray_coords (dataset , xrdataset )
295- _add_param_spec_to_xarray_data_vars (dataset , xrdataset )
296- _add_metadata_to_xarray (dataset , xrdataset )
277+ _add_param_spec_to_xarray_coords (dataset , xr_dataset )
278+ _add_param_spec_to_xarray_data_vars (dataset , xr_dataset )
279+ _add_metadata_to_xarray (dataset , xr_dataset )
297280
298- return xrdataset
281+ return xr_dataset
299282
300283
301284def _add_param_spec_to_xarray_coords (
0 commit comments