@@ -173,6 +173,7 @@ def _xarray_data_array_direct(
173173 ]
174174
175175 extra_coords : dict [str , tuple [tuple [str , ...], npt .NDArray ]] = {}
176+ extra_data_vars : dict [str , tuple [tuple [str , ...], npt .NDArray ]] = {}
176177 for inf in inferred :
177178 # skip parameters already used as primary coordinate axes
178179 if inf .name in dep_axis :
@@ -187,26 +188,37 @@ def _xarray_data_array_direct(
187188 related_top_level = inf_related .intersection ({meas_paramspec })
188189
189190 if len (related_top_level ) > 0 :
190- raise NotImplementedError (
191- "Adding inferred coords related to top level param is not yet supported"
192- )
193-
194- inf_data = subdict [inf .name ][
195- tuple (slice (None ) if dep in related_deps else 0 for dep in deps )
196- ]
197- inf_coords = [dep .name for dep in deps if dep in related_deps ]
191+ # If inferred param is related to the top-level measurement parameter,
192+ # add it as a data variable with the full dependency dimensions
193+ inf_data_full = subdict [inf .name ]
194+ inf_dims_full = tuple (dep_axis .keys ())
195+ extra_data_vars [inf .name ] = (inf_dims_full , inf_data_full )
196+ else :
197+ # Otherwise, add as a coordinate along the related dependency axes only
198+ inf_data = subdict [inf .name ][
199+ tuple (slice (None ) if dep in related_deps else 0 for dep in deps )
200+ ]
201+ inf_coords = [dep .name for dep in deps if dep in related_deps ]
198202
199- extra_coords [inf .name ] = (tuple (inf_coords ), inf_data )
203+ extra_coords [inf .name ] = (tuple (inf_coords ), inf_data )
200204
201205 # Compose coordinates dict including dependency axes and extra inferred coords
202206 coords : dict [str , tuple [tuple [str , ...], npt .NDArray ] | npt .NDArray ]
203207 coords = {** dep_axis , ** extra_coords }
204208
205- ds = xr .Dataset (
206- {name : (tuple (dep_axis .keys ()), subdict [name ])},
207- coords = coords ,
208- )
209- return ds [name ]
209+ # Compose data variables dict including measured var and any inferred data vars
210+ data_vars : dict [str , tuple [tuple [str , ...], npt .NDArray ]] = {
211+ name : (tuple (dep_axis .keys ()), subdict [name ])
212+ }
213+ data_vars .update (extra_data_vars )
214+
215+ ds = xr .Dataset (data_vars , coords = coords )
216+ da = ds [name ]
217+ if len (extra_data_vars ) > 0 :
218+ # stash extra data vars to be added at dataset assembly time
219+ # mapping: var_name -> (dims_tuple, numpy array)
220+ da .attrs ["_qcodes_extra_data_vars" ] = extra_data_vars
221+ return da
210222
211223
212224def load_to_xarray_dataarray_dict (
@@ -272,6 +284,13 @@ def load_to_xarray_dataset(
272284 # and python/typing#445 are resolved.
273285 xrdataset = xr .Dataset (cast ("dict[Hashable, xr.DataArray]" , data_xrdarray_dict ))
274286
287+ # add any stashed extra data variables created during direct export
288+ for _ , dataarray in data_xrdarray_dict .items ():
289+ extras = dataarray .attrs .pop ("_qcodes_extra_data_vars" , None )
290+ if isinstance (extras , dict ):
291+ for var_name , (dims , values ) in extras .items ():
292+ xrdataset [var_name ] = (dims , values )
293+
275294 _add_param_spec_to_xarray_coords (dataset , xrdataset )
276295 _add_param_spec_to_xarray_data_vars (dataset , xrdataset )
277296 _add_metadata_to_xarray (dataset , xrdataset )
0 commit comments