44import warnings
55from importlib .metadata import version
66from math import prod
7- from typing import TYPE_CHECKING , Literal , cast
7+ from typing import TYPE_CHECKING , Literal
88
99from packaging import version as pversion
1010
@@ -66,15 +66,13 @@ def _load_to_xarray_dataarray_dict_no_metadata(
6666 datadict : Mapping [str , Mapping [str , npt .NDArray ]],
6767 * ,
6868 use_multi_index : Literal ["auto" , "always" , "never" ] = "auto" ,
69- ) -> dict [str , xr .DataArray ]:
70- import xarray as xr
71-
69+ ) -> dict [str , xr .Dataset ]:
7270 if use_multi_index not in ("auto" , "always" , "never" ):
7371 raise ValueError (
7472 f"Invalid value for use_multi_index. Expected one of 'auto', 'always', 'never' but got { use_multi_index } "
7573 )
7674
77- data_xrdarray_dict : dict [str , xr .DataArray ] = {}
75+ data_xrdarray_dict : dict [str , xr .Dataset ] = {}
7876
7977 for name , subdict in datadict .items ():
8078 shape_is_consistent = (
@@ -96,11 +94,9 @@ def _load_to_xarray_dataarray_dict_no_metadata(
9694 )
9795
9896 if index is None :
99- xrdarray : xr .DataArray = (
100- _data_to_dataframe (subdict , index = index )
101- .to_xarray ()
102- .get (name , xr .DataArray ())
103- )
97+ xrdarray : xr .Dataset = _data_to_dataframe (
98+ subdict , index = index
99+ ).to_xarray ()
104100 data_xrdarray_dict [name ] = xrdarray
105101 elif index_is_unique :
106102 df = _data_to_dataframe (subdict , index )
@@ -109,9 +105,7 @@ def _load_to_xarray_dataarray_dict_no_metadata(
109105 )
110106 else :
111107 df = _data_to_dataframe (subdict , index )
112- xrdata_temp = df .reset_index ().to_xarray ()
113- for _name in subdict :
114- data_xrdarray_dict [_name ] = xrdata_temp [_name ]
108+ data_xrdarray_dict [name ] = df .reset_index ().to_xarray ()
115109
116110 return data_xrdarray_dict
117111
@@ -122,7 +116,7 @@ def _xarray_data_array_from_pandas_multi_index(
122116 name : str ,
123117 df : pd .DataFrame ,
124118 index : pd .Index | pd .MultiIndex ,
125- ) -> xr .DataArray :
119+ ) -> xr .Dataset :
126120 import pandas as pd
127121 import xarray as xr
128122
@@ -148,16 +142,16 @@ def _xarray_data_array_from_pandas_multi_index(
148142 )
149143
150144 coords = xr .Coordinates .from_pandas_multiindex (df .index , "multi_index" )
151- xrdarray = xr .DataArray (df [name ], coords = coords )
145+ xrdarray = xr .DataArray (df [name ], coords = coords ). to_dataset ( name = name )
152146 else :
153- xrdarray = df .to_xarray (). get ( name , xr . DataArray ())
147+ xrdarray = df .to_xarray ()
154148
155149 return xrdarray
156150
157151
158152def _xarray_data_array_direct (
159153 dataset : DataSetProtocol , name : str , subdict : Mapping [str , npt .NDArray ]
160- ) -> xr .DataArray :
154+ ) -> xr .Dataset :
161155 import xarray as xr
162156
163157 meas_paramspec = dataset .description .interdeps .graph .nodes [name ]["value" ]
@@ -213,20 +207,20 @@ def _xarray_data_array_direct(
213207 data_vars .update (extra_data_vars )
214208
215209 ds = xr .Dataset (data_vars , coords = coords )
216- da = ds [name ]
217- if len (extra_data_vars ) > 0 :
218- # stash extra data vars to be added at dataset assembly time
219- # mapping: var_name -> (dims_tuple, numpy array)
220- da .attrs ["_qcodes_extra_data_vars" ] = extra_data_vars
221- return da
210+ # da = ds[name]
211+ # if len(extra_data_vars) > 0:
212+ # # stash extra data vars to be added at dataset assembly time
213+ # # mapping: var_name -> (dims_tuple, numpy array)
214+ # da.attrs["_qcodes_extra_data_vars"] = extra_data_vars
215+ return ds
222216
223217
224218def load_to_xarray_dataarray_dict (
225219 dataset : DataSetProtocol ,
226220 datadict : Mapping [str , Mapping [str , npt .NDArray ]],
227221 * ,
228222 use_multi_index : Literal ["auto" , "always" , "never" ] = "auto" ,
229- ) -> dict [str , xr .DataArray ]:
223+ ) -> dict [str , xr .Dataset ]:
230224 dataarrays = _load_to_xarray_dataarray_dict_no_metadata (
231225 dataset , datadict , use_multi_index = use_multi_index
232226 )
@@ -282,14 +276,14 @@ def load_to_xarray_dataset(
282276
283277 # Casting Hashable for the key type until python/mypy#1114
284278 # and python/typing#445 are resolved.
285- xrdataset = xr .Dataset ( cast ( "dict[Hashable, xr.DataArray]" , data_xrdarray_dict ) )
279+ xrdataset = xr .merge ( data_xrdarray_dict . values (), compat = "equals" , join = "outer" )
286280
287281 # add any stashed extra data variables created during direct export
288- for _ , dataarray in data_xrdarray_dict .items ():
289- extras = dataarray .attrs .pop ("_qcodes_extra_data_vars" , None )
290- if isinstance (extras , dict ):
291- for var_name , (dims , values ) in extras .items ():
292- xrdataset [var_name ] = (dims , values )
282+ # for _, dataarray in data_xrdarray_dict.items():
283+ # extras = dataarray.attrs.pop("_qcodes_extra_data_vars", None)
284+ # if isinstance(extras, dict):
285+ # for var_name, (dims, values) in extras.items():
286+ # xrdataset[var_name] = (dims, values)
293287
294288 _add_param_spec_to_xarray_coords (dataset , xrdataset )
295289 _add_param_spec_to_xarray_data_vars (dataset , xrdataset )
0 commit comments