diff --git a/docs/changes/newsfragments/7346.breaking b/docs/changes/newsfragments/7346.breaking new file mode 100644 index 000000000000..fb466a4cc802 --- /dev/null +++ b/docs/changes/newsfragments/7346.breaking @@ -0,0 +1,11 @@ +Registration and Unpacking interfaces created in ParameterBse + +``ParameterBase`` now implements new `depends_on``, ``is_controlled_by``, and ``has_control_of``properties that allow subclasses to define ``InterDependencies_`` relationships directly +``ParameterBase.unpack_self`` allows subclasses to unpack themselves during ``DataSaver.add_result``, which removes the requirement for users to add pre-defined ``InterDependencies_`` results explicitly +``Measurement.register_parameter`` has been refactored to follow the relationship links defined in parameter subclasses and automatically register related parameters with the appropriate relationships +``DataSaver.add_result`` has been refactored to take advantage of the new ``unpack_self`` method + +Breaking Changes +- A dependent parameter registered with an independent parameter as its ``setpoints`` no longer requires that the independent parameter be registered first, if the independent parameter is ParameterBase subclass and not a str +- Previously, a ParameterWithSetpoints whose setpoints values were explicitly added in add_result would use the explicit version. Now, an error is raised if the explicit values are not within some tolerance of the internal values (as with other duplication). +- ``DataSaver.add_result`` signature has changed from ``*res_tuple`` to ``*result_tuples`` diff --git a/docs/examples/Parameters/Parameter_defined_InterDependencies.ipynb b/docs/examples/Parameters/Parameter_defined_InterDependencies.ipynb new file mode 100644 index 000000000000..fe79c2e8414e --- /dev/null +++ b/docs/examples/Parameters/Parameter_defined_InterDependencies.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a338885a", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import TYPE_CHECKING\n", + "\n", + "import numpy as np\n", + "\n", + "from qcodes.dataset import (\n", + " Measurement,\n", + " initialise_or_create_database_at,\n", + " load_or_create_experiment,\n", + ")\n", + "from qcodes.parameters import (\n", + " ManualParameter,\n", + " Parameter,\n", + " ParameterBase,\n", + ")\n", + "\n", + "if TYPE_CHECKING:\n", + " from qcodes.dataset.data_set_protocol import ValuesType\n", + " from qcodes.parameters import ParameterBase, ParamRawDataType" + ] + }, + { + "cell_type": "markdown", + "id": "fd4cb8f8", + "metadata": {}, + "source": [ + "# Parameter-defined InterDependencies\n", + "\n", + "This example demonstrates how to use the `depends_on`, `has_control_of`, and `is_controlled_by` properties to define granular implicit interdependencies between Parameters. These are described in greater detail in the [Interdependent Parameters](../../dataset/interdependentparams.rst)." + ] + }, + { + "cell_type": "markdown", + "id": "a7967c55", + "metadata": {}, + "source": [ + "## Interdependency Definitions:\n", + "- `depends_on`: (also `setpoints`) An experimental relationship, usually the focus of the measurement. A dependent parameter will generally `depend_on` one or more independent parameters\n", + "- `is_controlled_by`: (also `basis` and `inferred_from`) A well-known or defined relationship, with an explicit mathematical function to describe it. The directionality is important: We say a parameter A is inferred from B if there exists a function f such that f(B) = A.\n", + "- `has_control_of`: The opposite direction of the `is_controlled_by` relationship\n", + "\n", + "In this example, we will first create a `ControllingParameter` class that operates two component parameters in tandem according to simple linear equations. We will look at how it uses the `has_control_of` and `is_controlled_by` properties to ensure that these components are properly registered in a `Measurement`. Finally, we will examine its custom `unpack_self` method which allows `datasaver.add_result` to add component results even if they are not explicitly added.\n", + "\n", + "Then we will show how to bind a `depends_on` relationship to a parameter, and demonstrate how this simplifies handling of fixed and constant dependencies." + ] + }, + { + "cell_type": "markdown", + "id": "7a188625", + "metadata": {}, + "source": [ + "# ControllingParameter Example" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "32651dfa", + "metadata": {}, + "outputs": [], + "source": [ + "class ControllingParameter(Parameter):\n", + " def __init__(\n", + " self, name: str, components: dict[Parameter, tuple[float, float]]\n", + " ) -> None:\n", + " super().__init__(name=name, get_cmd=False)\n", + " # dict of Parameter to (slope, offset) of components\n", + " self._components_dict: dict[Parameter, tuple[float, float]] = components\n", + " for param in self._components_dict.keys():\n", + " self._has_control_of.add(param)\n", + " param.is_controlled_by.add(self)\n", + "\n", + " def set_raw(self, value: \"ParamRawDataType\") -> None:\n", + " # Set all dependent parameters based on their slope and offsets\n", + " for param, slope_offset in self._components_dict.items():\n", + " param(value * slope_offset[0] + slope_offset[1])\n", + "\n", + " def get_raw(self) -> \"ParamRawDataType\":\n", + " return self.cache.get()\n", + "\n", + " def unpack_self(\n", + " self, value: \"ValuesType\"\n", + " ) -> list[tuple[\"ParameterBase\", \"ValuesType\"]]:\n", + " assert isinstance(value, float)\n", + " unpacked_results = super().unpack_self(value)\n", + " for param, slope_offset in self._components_dict.items():\n", + " unpacked_results.append((param, value * slope_offset[0] + slope_offset[1]))\n", + " return unpacked_results" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9af7d477", + "metadata": {}, + "outputs": [], + "source": [ + "param1 = ManualParameter(\"param1\", initial_value=0)\n", + "param2 = ManualParameter(\"param2\", initial_value=0)\n", + "control = ControllingParameter(\"control\", components={param1: (1, 0), param2: (-1, 10)})\n", + "\n", + "meas_param = Parameter(\"meas\", get_cmd=lambda: param1() + param2() - 5.0)" + ] + }, + { + "cell_type": "markdown", + "id": "9241ee47", + "metadata": {}, + "source": [ + "## ControllingParameter self-registration of components\n", + "\n", + "In the ``__init__`` method of the `ControllingParameter`, we use two new attributes to define its built-in InterDependencies. The `has_control_of` property is an ordered set of its internal components. We also add the `ControllingParameter` instance to the `is_controlled_by` sets of the components. This lets us register just _one_ of the set `param1, param2, control` and get the other two for free." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "bb26c0f0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'control': ParamSpecBase('control', 'numeric', 'control', ''),\n", + " 'param1': ParamSpecBase('param1', 'numeric', 'param1', ''),\n", + " 'param2': ParamSpecBase('param2', 'numeric', 'param2', '')}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "initialise_or_create_database_at(\"experiments.db\")\n", + "exp = load_or_create_experiment(\"InterDependencies_ examples\")\n", + "meas = Measurement(exp=exp, name=\"self registration example\")\n", + "meas.register_parameter(control)\n", + "\n", + "meas.parameters" + ] + }, + { + "cell_type": "markdown", + "id": "dce92a39", + "metadata": {}, + "source": [ + "In addition to the `has_control_of` and `is_controlled_by` properties, there is also a similar `depends_on` property that can be used to flexibly create something like the `ParameterWithSetpoints`. The `setpoints` of a `ParameterWithSetpoints` are now added to its internal `depends_on` set, where they are automatically self-registered with the same machinery as we demonstrated above." + ] + }, + { + "cell_type": "markdown", + "id": "870e9165", + "metadata": {}, + "source": [ + "## ControllingParameter self-unpacking\n", + "\n", + "For qcodes measurements, parameter registration is only the first part of the story. Inside the measurement loop itself, we use `datasaver.add_result` to save new data to the resulting database. The `unpack_self` method defined in the `ControllingParameter` class handles unpacking a `ControllingParameter` result tuple, so that the data for its components is also saved." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1eb9f5e2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting experimental run with id: 6. \n" + ] + } + ], + "source": [ + "with meas.run() as datasaver:\n", + " for i in np.linspace(0, 1, 11):\n", + " control(i)\n", + " datasaver.add_result((control, control()))\n", + " ds = datasaver.dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c48afcbe", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'param1': {'param1': array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]),\n", + " 'control': array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ])},\n", + " 'param2': {'param2': array([10. , 9.9, 9.8, 9.7, 9.6, 9.5, 9.4, 9.3, 9.2, 9.1, 9. ]),\n", + " 'control': array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ])}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.get_parameter_data()" + ] + }, + { + "cell_type": "markdown", + "id": "95f34917", + "metadata": {}, + "source": [ + "### But does it work with dond?\n", + "\n", + "Yes." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "73823b84", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting experimental run with id: 8. Using 'qcodes.dataset.dond'\n" + ] + }, + { + "data": { + "text/plain": [ + "{'meas': {'meas': array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]),\n", + " 'control': array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]),\n", + " 'param1': array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]),\n", + " 'param2': array([10. , 9.9, 9.8, 9.7, 9.6, 9.5, 9.4, 9.3, 9.2, 9.1, 9. ])}}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from qcodes.dataset import LinSweep, dond\n", + "\n", + "ds, _, _ = dond(LinSweep(control, 0, 1, 11), meas_param)\n", + "ds.get_parameter_data()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py311", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/plotting/auto_color_scale.ipynb b/docs/examples/plotting/auto_color_scale.ipynb index 2a4662b23584..ddbaaf0fc4b0 100644 --- a/docs/examples/plotting/auto_color_scale.ipynb +++ b/docs/examples/plotting/auto_color_scale.ipynb @@ -105,7 +105,7 @@ "import numpy as np\n", "\n", "from qcodes.dataset.descriptions.dependencies import InterDependencies_\n", - "from qcodes.dataset.descriptions.param_spec import ParamSpecBase\n", + "from qcodes.parameters import ParamSpecBase\n", "\n", "\n", "def dataset_with_outliers_generator(\n", diff --git a/src/qcodes/dataset/data_export.py b/src/qcodes/dataset/data_export.py index e206c6960674..e7259873fc6d 100644 --- a/src/qcodes/dataset/data_export.py +++ b/src/qcodes/dataset/data_export.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from qcodes.dataset.data_set_protocol import DataSetProtocol - from qcodes.dataset.descriptions.param_spec import ParamSpecBase + from qcodes.parameters import ParamSpecBase log = logging.getLogger(__name__) diff --git a/src/qcodes/dataset/data_set.py b/src/qcodes/dataset/data_set.py index 9ca0fdc59e6a..55e8d221b555 100644 --- a/src/qcodes/dataset/data_set.py +++ b/src/qcodes/dataset/data_set.py @@ -110,9 +110,9 @@ import pandas as pd import xarray as xr - from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase + from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.descriptions.versioning.rundescribertypes import Shapes - from qcodes.parameters import ParameterBase + from qcodes.parameters import ParameterBase, ParamSpecBase log = logging.getLogger(__name__) diff --git a/src/qcodes/dataset/data_set_in_memory.py b/src/qcodes/dataset/data_set_in_memory.py index 89b465e27f08..bbd8d0bd2ae4 100644 --- a/src/qcodes/dataset/data_set_in_memory.py +++ b/src/qcodes/dataset/data_set_in_memory.py @@ -56,8 +56,9 @@ import pandas as pd import xarray as xr - from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase + from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.descriptions.versioning.rundescribertypes import Shapes + from qcodes.parameters import ParamSpecBase from ..parameters import ParameterBase diff --git a/src/qcodes/dataset/data_set_protocol.py b/src/qcodes/dataset/data_set_protocol.py index 8c8d94277af7..bc06044bb2a8 100644 --- a/src/qcodes/dataset/data_set_protocol.py +++ b/src/qcodes/dataset/data_set_protocol.py @@ -19,7 +19,7 @@ import numpy.typing as npt from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase +from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.export_config import ( DataExportType, get_data_export_name_elements, @@ -42,7 +42,7 @@ from qcodes.dataset.descriptions.rundescriber import RunDescriber from qcodes.dataset.descriptions.versioning.rundescribertypes import Shapes from qcodes.dataset.linked_datasets.links import Link - from qcodes.parameters import ParameterBase + from qcodes.parameters import ParameterBase, ParamSpecBase from .data_set_cache import DataSetCache from .exporters.export_info import ExportInfo diff --git a/src/qcodes/dataset/descriptions/dependencies.py b/src/qcodes/dataset/descriptions/dependencies.py index f8bafb3b673d..49a98b26b7e1 100644 --- a/src/qcodes/dataset/descriptions/dependencies.py +++ b/src/qcodes/dataset/descriptions/dependencies.py @@ -15,10 +15,9 @@ import networkx as nx from typing_extensions import deprecated +from qcodes.parameters import ParamSpecBase from qcodes.utils import QCoDeSDeprecationWarning -from .param_spec import ParamSpecBase - if TYPE_CHECKING: from collections.abc import Sequence diff --git a/src/qcodes/dataset/descriptions/param_spec.py b/src/qcodes/dataset/descriptions/param_spec.py index 02d299a179ec..8a965080cca6 100644 --- a/src/qcodes/dataset/descriptions/param_spec.py +++ b/src/qcodes/dataset/descriptions/param_spec.py @@ -1,131 +1,36 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any -from typing_extensions import TypedDict +from typing_extensions import deprecated -if TYPE_CHECKING: - from collections.abc import Sequence - - -class ParamSpecBaseDict(TypedDict): - name: str - paramtype: str - label: str | None - unit: str | None - - -class ParamSpecDict(ParamSpecBaseDict): - inferred_from: list[str] - depends_on: list[str] - - -class ParamSpecBase: - allowed_types: ClassVar[list[str]] = ["array", "numeric", "text", "complex"] - - def __init__( - self, - name: str, - paramtype: str, - label: str | None = None, - unit: str | None = None, - ): - """ - Args: - name: name of the parameter - paramtype: type of the parameter, i.e. the SQL storage class - label: label of the parameter - unit: The unit of the parameter - - """ - - if not isinstance(paramtype, str): - raise ValueError("Paramtype must be a string.") - if paramtype.lower() not in self.allowed_types: - raise ValueError(f"Illegal paramtype. Must be on of {self.allowed_types}") - if not name.isidentifier(): - raise ValueError( - f"Invalid name: {name}. Only valid python " - "identifier names are allowed (no spaces or " - "punctuation marks, no prepended " - "numbers, etc.)" - ) +from qcodes.parameters import ParamSpecBase as _ParamSpecBase +from qcodes.parameters import ParamSpecBaseDict as _ParamSpecBaseDict - self.name = name - self.type = paramtype.lower() - self.label = label or "" - self.unit = unit or "" - self._hash: int = self._compute_hash() +@deprecated( + "ParamSpecBase is deprecated, use qcodes.parameters.ParamSpecBase instead", +) +class ParamSpecBase(_ParamSpecBase): ... - def _compute_hash(self) -> int: - """ - This method should only be called by __init__ - """ - attrs = ["name", "type", "label", "unit"] - # First, get the hash of the tuple with all the relevant attributes - all_attr_tuple_hash = hash(tuple(getattr(self, attr) for attr in attrs)) - hash_value = all_attr_tuple_hash - # Then, XOR it with the individual hashes of all relevant attributes - for attr in attrs: - hash_value = hash_value ^ hash(getattr(self, attr)) - - return hash_value +@deprecated( + "ParamSpecBaseDict is deprecated, use qcodes.parameters.ParamSpecBaseDict instead", +) +class ParamSpecBaseDict(_ParamSpecBaseDict): ... - def sql_repr(self) -> str: - return f"{self.name} {self.type}" - def __repr__(self) -> str: - return ( - f"ParamSpecBase('{self.name}', '{self.type}', '{self.label}', " - f"'{self.unit}')" - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ParamSpecBase): - return False - attrs = ["name", "type", "label", "unit"] - for attr in attrs: - if getattr(self, attr) != getattr(other, attr): - return False - return True - - def __hash__(self) -> int: - """ - Allow ParamSpecBases in data structures that use hashing (e.g. sets) - """ - return self._hash - - def _to_dict(self) -> ParamSpecBaseDict: - """ - Write the ParamSpec as a dictionary - """ - output = ParamSpecBaseDict( - name=self.name, paramtype=self.type, label=self.label, unit=self.unit - ) - return output +if TYPE_CHECKING: + from collections.abc import Sequence - @classmethod - def _from_dict(cls, ser: ParamSpecBaseDict) -> ParamSpecBase: - """ - Create a ParamSpec instance of the current version - from a dictionary representation of ParamSpec of some version - The version changes must be implemented as a series of transformations - of the representation dict. - """ - - return ParamSpecBase( - name=ser["name"], - paramtype=ser["paramtype"], - label=ser["label"], - unit=ser["unit"], - ) +class ParamSpecDict(_ParamSpecBaseDict): + inferred_from: list[str] + depends_on: list[str] -class ParamSpec(ParamSpecBase): +class ParamSpec(_ParamSpecBase): def __init__( self, name: str, @@ -266,12 +171,12 @@ def _to_dict(self) -> ParamSpecDict: ) return output - def base_version(self) -> ParamSpecBase: + def base_version(self) -> _ParamSpecBase: """ Return a ParamSpecBase object with the same name, paramtype, label and unit as this ParamSpec """ - return ParamSpecBase( + return _ParamSpecBase( name=self.name, paramtype=self.type, label=self.label, unit=self.unit ) diff --git a/src/qcodes/dataset/descriptions/versioning/converters.py b/src/qcodes/dataset/descriptions/versioning/converters.py index 19b69c17d926..897b8a8b6fb1 100644 --- a/src/qcodes/dataset/descriptions/versioning/converters.py +++ b/src/qcodes/dataset/descriptions/versioning/converters.py @@ -8,8 +8,10 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from ..dependencies import InterDependencies_ -from ..param_spec import ParamSpec, ParamSpecBase +from ..param_spec import ParamSpec from .rundescribertypes import ( RunDescriberV0Dict, RunDescriberV1Dict, @@ -18,6 +20,9 @@ ) from .v0 import InterDependencies +if TYPE_CHECKING: + from qcodes.parameters import ParamSpecBase + def old_to_new(idps: InterDependencies) -> InterDependencies_: """ diff --git a/src/qcodes/dataset/descriptions/versioning/rundescribertypes.py b/src/qcodes/dataset/descriptions/versioning/rundescribertypes.py index 3f45b7e03797..ef82832e8622 100644 --- a/src/qcodes/dataset/descriptions/versioning/rundescribertypes.py +++ b/src/qcodes/dataset/descriptions/versioning/rundescribertypes.py @@ -21,7 +21,9 @@ from typing_extensions import TypedDict if TYPE_CHECKING: - from ..param_spec import ParamSpecBaseDict, ParamSpecDict + from qcodes.parameters import ParamSpecBaseDict + + from ..param_spec import ParamSpecDict class InterDependenciesDict(TypedDict): diff --git a/src/qcodes/dataset/measurements.py b/src/qcodes/dataset/measurements.py index 9b79660a091c..6eb4c054df3d 100644 --- a/src/qcodes/dataset/measurements.py +++ b/src/qcodes/dataset/measurements.py @@ -15,9 +15,10 @@ from contextlib import ExitStack from copy import deepcopy from inspect import signature +from itertools import chain from numbers import Number from time import perf_counter -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, cast import numpy as np import numpy.typing as npt @@ -37,17 +38,19 @@ from qcodes.dataset.descriptions.dependencies import ( IncompleteSubsetError, InterDependencies_, + ParamSpecTree, ) -from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase +from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.export_config import get_data_export_automatic from qcodes.parameters import ( ArrayParameter, GroupedParameter, + ManualParameter, MultiParameter, Parameter, ParameterBase, ParameterWithSetpoints, - expand_setpoints_helper, + ParamSpecBase, ) from qcodes.station import Station from qcodes.utils import DelayedKeyboardInterrupt @@ -69,6 +72,9 @@ Callable[..., Any], MutableSequence[Any] | MutableMapping[Any, Any] ] +ParameterResultType: TypeAlias = tuple[ParameterBase, ValuesType] +DatasetResultDict: TypeAlias = dict[ParamSpecBase, npt.NDArray] + class ParameterTypeError(Exception): pass @@ -87,6 +93,7 @@ def __init__( dataset: DataSetProtocol, write_period: float, interdeps: InterDependencies_, + registered_parameters: Sequence[ParameterBase], span: trace.Span | None = None, ) -> None: self._span = span @@ -122,11 +129,53 @@ def __init__( self._last_save_time = perf_counter() self._known_dependencies: dict[str, list[str]] = {} self.parent_datasets: list[DataSetProtocol] = [] + self._registered_parameters = registered_parameters for link in self._dataset.parent_dataset_links: self.parent_datasets.append(load_by_guid(link.tail)) - def add_result(self, *res_tuple: ResType) -> None: + def _validate_result_tuples_no_duplicates(self, *result_tuples: ResType) -> None: + """Validate that the result tuples do not contain duplicates""" + + parameter_names = tuple( + result_tuple[0].register_name + if isinstance(result_tuple[0], ParameterBase) + else result_tuple[0] + for result_tuple in result_tuples + ) + if len(set(parameter_names)) != len(parameter_names): + non_unique = [ + item + for item, count in collections.Counter(parameter_names).items() + if count > 1 + ] + raise ValueError( + f"Not all parameter names are unique. " + f"Got multiple values for {non_unique}" + ) + + def _coerce_result_tuple_to_parameter_result_type( + self, result_tuple: ResType + ) -> ParameterResultType: + param_or_str = result_tuple[0] + if isinstance(param_or_str, ParameterBase): + return (param_or_str, result_tuple[1]) + else: # param_or_str is a str + candidate_params = [ + param + for param in self._registered_parameters + if param.register_name == result_tuple[0] + ] + if len(candidate_params) > 1: + raise ValueError( + f"More than one parameter matched the name {param_or_str}" + f"{candidate_params}" + ) + elif len(candidate_params) < 1: + raise ValueError("No matching parameters") + return (candidate_params[0], result_tuple[1]) + + def add_result(self, *result_tuples: ResType) -> None: """ Add a result to the measurement results. Represents a measurement point in the space of measurement parameters, e.g. in an experiment @@ -142,10 +191,10 @@ def add_result(self, *res_tuple: ResType) -> None: of this class. Args: - res_tuple: A tuple with the first element being the parameter name - and the second element is the corresponding value(s) at this - measurement point. The function takes as many tuples as there - are results. + result_tuples: One or more result tuples with the first element + being the parameter name and the second element is the + corresponding value(s) at this measurement point. The function + takes as many tuples as there are results. Raises: ValueError: If a parameter name is not registered in the parent @@ -159,130 +208,75 @@ def add_result(self, *res_tuple: ResType) -> None: """ - # we iterate through the input twice. First we find any array and - # multiparameters that need to be unbundled and collect the names - # of all parameters. This also allows users to call - # add_result with the arguments in any particular order, i.e. NOT - # enforcing that setpoints come before dependent variables. - results_dict: dict[ParamSpecBase, npt.NDArray] = {} - - parameter_names = tuple( - partial_result[0].register_name - if isinstance(partial_result[0], ParameterBase) - else partial_result[0] - for partial_result in res_tuple - ) - if len(set(parameter_names)) != len(parameter_names): - non_unique = [ - item - for item, count in collections.Counter(parameter_names).items() - if count > 1 - ] + parameter_results: list[ParameterResultType] = [ + self._coerce_result_tuple_to_parameter_result_type(result_tuple) + for result_tuple in result_tuples + ] + + non_unique = [ + item.register_name + for item, count in collections.Counter( + [parameter_result[0] for parameter_result in parameter_results] + ).items() + if count > 1 + ] + if len(non_unique) > 0: raise ValueError( f"Not all parameter names are unique. " f"Got multiple values for {non_unique}" ) - for partial_result in res_tuple: - parameter = partial_result[0] - data = partial_result[1] + legacy_results_dict: DatasetResultDict = {} + self_unpacked_parameter_results: list[ParameterResultType] = [] - if isinstance(parameter, ParameterBase) and isinstance( - parameter.vals, vals.Arrays - ): - if not isinstance(data, np.ndarray): - raise TypeError( - f"Expected data for Parameter with Array validator " - f"to be a numpy array but got: {type(data)}" - ) - - if ( - parameter.vals.shape is not None - and data.shape != parameter.vals.shape - ): - raise TypeError( - f"Expected data with shape {parameter.vals.shape}, " - f"but got {data.shape} for parameter: {parameter.full_name}" - ) - - if isinstance(parameter, ArrayParameter): - results_dict.update(self._unpack_arrayparameter(partial_result)) - elif isinstance(parameter, MultiParameter): - results_dict.update(self._unpack_multiparameter(partial_result)) - elif isinstance(parameter, ParameterWithSetpoints): - results_dict.update( - self._conditionally_expand_parameter_with_setpoints( - data, parameter, parameter_names, partial_result - ) + for parameter_result in parameter_results: + if isinstance(parameter_result[0], ArrayParameter): + legacy_results_dict.update( + self._unpack_arrayparameter(parameter_result) + ) + elif isinstance(parameter_result[0], MultiParameter): + legacy_results_dict.update( + self._unpack_multiparameter(parameter_result) ) else: - results_dict.update(self._unpack_partial_result(partial_result)) - - self._validate_result_deps(results_dict) - self._validate_result_shapes(results_dict) - self._validate_result_types(results_dict) - - self.dataset._enqueue_results(results_dict) - - if perf_counter() - self._last_save_time > self.write_period: - self.flush_data_to_database() - self._last_save_time = perf_counter() + self_unpacked_parameter_results.extend( + parameter_result[0].unpack_self(parameter_result[1]) + ) - def _conditionally_expand_parameter_with_setpoints( - self, - data: ValuesType, - parameter: ParameterWithSetpoints, - parameter_names: Sequence[str], - partial_result: ResType, - ) -> dict[ParamSpecBase, npt.NDArray]: - local_results = {} - setpoint_names = tuple( - setpoint.register_name for setpoint in parameter.setpoints - ) - expanded = tuple( - setpoint_name in parameter_names for setpoint_name in setpoint_names + all_results_dict: dict[ParamSpecBase, list[npt.NDArray]] = ( + collections.defaultdict(list) ) - if all(expanded): - local_results.update(self._unpack_partial_result(partial_result)) - elif any(expanded): - raise ValueError( - f"Some of the setpoints of {parameter.full_name} " - "were explicitly given but others were not. " - "Either supply all of them or none of them." - ) - else: - expanded_partial_result = expand_setpoints_helper(parameter, data) - for res in expanded_partial_result: - local_results.update(self._unpack_partial_result(res)) - return local_results - - def _unpack_partial_result( - self, partial_result: ResType - ) -> dict[ParamSpecBase, npt.NDArray]: - """ - Unpack a partial result (not containing :class:`ArrayParameters` or - class:`MultiParameters`) into a standard results dict form and return - that dict - """ - param, values = partial_result - try: - parameter = self._interdeps._id_to_paramspec[str_or_register_name(param)] - except KeyError: - if str_or_register_name(param) == str(param): - err_msg = ( - "Can not add result for parameter " - f"{param}, no such parameter registered " - "with this measurement." - ) - else: - err_msg = ( + for parameter_result in self_unpacked_parameter_results: + try: + result_paramspec = self._interdeps._id_to_paramspec[ + parameter_result[0].register_name + ] + except KeyError: + raise ValueError( "Can not add result for parameter " - f"{param!s} or {str_or_register_name(param)}," + f"{parameter_result[0].register_name}, " "no such parameter registered " "with this measurement." ) - raise ValueError(err_msg) - return {parameter: np.array(values)} + all_results_dict[result_paramspec].append(np.array(parameter_result[1])) + + # Add any unpacked results from legacy Parameter types + for key, value in legacy_results_dict.items(): + all_results_dict[key].append(value) + + datasaver_results_dict: DatasetResultDict = _deduplicate_results( + all_results_dict + ) + + self._validate_result_deps(datasaver_results_dict) + self._validate_result_shapes(datasaver_results_dict) + self._validate_result_types(datasaver_results_dict) + + self.dataset._enqueue_results(datasaver_results_dict) + + if perf_counter() - self._last_save_time > self.write_period: + self.flush_data_to_database() + self._last_save_time = perf_counter() def _unpack_arrayparameter( self, partial_result: ResType @@ -550,7 +544,7 @@ def __init__( in_memory_cache: bool | None = None, dataset_class: DataSetType = DataSetType.DataSet, parent_span: trace.Span | None = None, - registered_parameters: Sequence[ParameterBase] | None = None, + registered_parameters: Sequence[ParameterBase] = (), ) -> None: if in_memory_cache is None: in_memory_cache = qc.config.dataset.in_memory_cache @@ -725,6 +719,7 @@ def __enter__(self) -> DataSaver: dataset=self.ds, write_period=self.write_period, interdeps=self._interdependencies, + registered_parameters=self._registered_parameters, span=self._span, ) @@ -818,7 +813,7 @@ def __init__( self._shapes: Shapes | None = None self._parent_datasets: list[dict[str, str]] = [] self._extra_log_info: str = "" - self._registered_parameters: list[ParameterBase] = [] + self._registered_parameters: set[ParameterBase] = set() @property def parameters(self) -> dict[str, ParamSpecBase]: @@ -839,7 +834,6 @@ def write_period(self, wp: float) -> None: def _paramspecbase_from_strings( self, - name: str, setpoints: Sequence[str] | None = None, basis: Sequence[str] | None = None, ) -> tuple[tuple[ParamSpecBase, ...], tuple[ParamSpecBase, ...]]: @@ -848,10 +842,9 @@ def _paramspecbase_from_strings( error message if the user tries to register a parameter with reference (setpoints, basis) to a parameter not registered with this measurement - Called by _register_parameter only. + Called by _register_parameter and _self_register_parameter only. Args: - name: Name of the parameter to register setpoints: name(s) of the setpoint parameter(s) basis: name(s) of the parameter(s) that this parameter is inferred from @@ -912,6 +905,103 @@ def register_parent( return self + def _paramspecs_and_parameters_from_setpoints( + self, setpoints: SetpointsType | None + ) -> tuple[list[ParamSpecBase], list[ParameterBase]]: + paramspecs = [] + parameters = [] + if setpoints is not None: + for setpoint in setpoints: + if isinstance(setpoint, ParameterBase): + paramspecs.append(setpoint.param_spec) + parameters.append(setpoint) + elif ( + isinstance(setpoint, str) + and ( + setpoint_paramspec := self._interdeps._id_to_paramspec.get( + setpoint, None + ) + ) + is not None + ): + paramspecs.append(setpoint_paramspec) + else: + raise ValueError( + f"Unknown interdependency: {setpoint}. Please register that parameter first." + ) + return paramspecs, parameters + + def _self_register_parameter( + self: Self, + parameter: ParameterBase, + setpoints: SetpointsType | None = None, + basis: SetpointsType | None = None, + ) -> Self: + # It is important to preserve the order of the setpoints (and basis) arguments + # when building the dependency trees, as this order is implicitly used to assign + # the axis-order for multidimensional data variables where shape alone is + # insufficient (eg, if the shape is square) + + # Convert setpoints and basis arguments to ParamSpecBases + dependency_paramspecs, dependency_parameters = ( + self._paramspecs_and_parameters_from_setpoints(setpoints) + ) + inference_paramspecs, inference_parameters = ( + self._paramspecs_and_parameters_from_setpoints(basis) + ) + + # Append internal dependencies/inferences + dependency_paramspecs.extend( + [param.param_spec for param in parameter.depends_on] + ) + inference_paramspecs.extend( + [param.param_spec for param in parameter.is_controlled_by] + ) + + # Make ParamSpecTrees and extend interdeps + dependencies_tree: ParamSpecTree | None = None + if len(dependency_paramspecs) > 0: + dependencies_tree = {parameter.param_spec: tuple(dependency_paramspecs)} + + inferences_tree: ParamSpecTree | None = None + if len(inference_paramspecs) > 0: + inferences_tree = {parameter.param_spec: tuple(inference_paramspecs)} + + standalones: tuple[ParamSpecBase, ...] = () + if dependencies_tree is None and inferences_tree is None: + standalones = (parameter.param_spec,) + + self._interdeps = self._interdeps.extend( + dependencies=dependencies_tree, + inferences=inferences_tree, + standalones=standalones, + ) + self._registered_parameters.add(parameter) + log.info(f"Registered {parameter.register_name} in the Measurement.") + + # Recursively register all other interdependent parameters related to this parameter + interdependent_parameters = list( + chain.from_iterable( + [ + dependency_parameters, + inference_parameters, + parameter.depends_on, + parameter.is_controlled_by, + ] + ) + ) + for interdependent_parameter in interdependent_parameters: + if interdependent_parameter not in self._registered_parameters: + self._self_register_parameter(interdependent_parameter) + + # We handle the `has_control_of` relationship differently so that the controlled parameter + # does not need to implement the reverse-direction `is_controlled_by` to get the + # inference relationship + for controlled_parameter in parameter.has_control_of: + self._self_register_parameter(controlled_parameter, basis=(parameter,)) + + return self + def register_parameter( self: Self, parameter: ParameterBase, @@ -936,66 +1026,44 @@ def register_parameter( and the validator of the supplied parameter. """ - if not isinstance(parameter, ParameterBase): - raise ValueError( - f"Can not register object of type {type(parameter)}. Can only " - "register a QCoDeS Parameter." - ) - paramtype = self._infer_paramtype(parameter, paramtype) - # default to numeric - if paramtype is None: - paramtype = "numeric" - - # now the parameter type must be valid - if paramtype not in ParamSpec.allowed_types: - raise RuntimeError( - "Trying to register a parameter with type " - f"{paramtype}. However, only " - f"{ParamSpec.allowed_types} are supported." - ) - if setpoints is not None: self._check_setpoints_type(setpoints, "setpoints") if basis is not None: self._check_setpoints_type(basis, "basis") - if isinstance(parameter, ArrayParameter): - self._register_arrayparameter(parameter, setpoints, basis, paramtype) - elif isinstance(parameter, ParameterWithSetpoints): - self._register_parameter_with_setpoints( - parameter, setpoints, basis, paramtype - ) - elif isinstance(parameter, MultiParameter): - self._register_multiparameter( - parameter, - setpoints, - basis, - paramtype, - ) - elif isinstance(parameter, Parameter): - self._register_parameter( - parameter.register_name, - parameter.label, - parameter.unit, - setpoints, - basis, - paramtype, - ) - elif isinstance(parameter, GroupedParameter): - self._register_parameter( - parameter.register_name, - parameter.label, - parameter.unit, - setpoints, - basis, - paramtype, - ) - else: - raise RuntimeError( - f"Does not know how to register a parameter of type {type(parameter)}" - ) - self._registered_parameters.append(parameter) + match parameter: + case ArrayParameter(): + paramtype = self._infer_paramtype(parameter, paramtype) + self._register_arrayparameter(parameter, setpoints, basis, paramtype) + case MultiParameter(): + paramtype = self._infer_paramtype(parameter, paramtype) + self._register_multiparameter( + parameter, + setpoints, + basis, + paramtype, + ) + case GroupedParameter(): + paramtype = self._infer_paramtype(parameter, paramtype) + self._register_parameter( + parameter.register_name, + parameter.label, + parameter.unit, + setpoints, + basis, + paramtype, + ) + case ParameterBase() | ParameterWithSetpoints(): + if paramtype is not None: + parameter.paramtype = paramtype + self._self_register_parameter(parameter, setpoints, basis) + case _: + raise ValueError( + f"Can not register object of type {type(parameter)}. Can only " + "register a QCoDeS Parameter." + ) + self._registered_parameters.add(parameter) return self @@ -1011,7 +1079,7 @@ def _check_setpoints_type(arg: SetpointsType, name: str) -> None: ) @staticmethod - def _infer_paramtype(parameter: ParameterBase, paramtype: str | None) -> str | None: + def _infer_paramtype(parameter: ParameterBase, paramtype: str | None) -> str: """ Infer the best parameter type to store the parameter supplied. @@ -1025,20 +1093,29 @@ def _infer_paramtype(parameter: ParameterBase, paramtype: str | None) -> str | N Returns None if a parameter type could not be inferred """ - if paramtype is not None: - return paramtype - - if isinstance(parameter.vals, vals.Arrays): - paramtype = "array" + return_paramtype: str + if paramtype is not None: # override with argument + return_paramtype = paramtype + elif isinstance(parameter.vals, vals.Arrays): + return_paramtype = "array" elif isinstance(parameter, ArrayParameter): - paramtype = "array" + return_paramtype = "array" elif isinstance(parameter.vals, vals.Strings): - paramtype = "text" + return_paramtype = "text" elif isinstance(parameter.vals, vals.ComplexNumbers): - paramtype = "complex" + return_paramtype = "complex" + else: # Default to this if nothing else matches + return_paramtype = "numeric" + + if return_paramtype not in ParamSpec.allowed_types: + raise RuntimeError( + "Trying to register a parameter with type " + f"{return_paramtype}. However, only " + f"{ParamSpec.allowed_types} are supported." + ) # TODO should we try to figure out if parts of a multiparameter are # arrays or something else? - return paramtype + return return_paramtype def _register_parameter( self: Self, @@ -1084,9 +1161,7 @@ def _register_parameter( bs_strings = [] # get the ParamSpecBases - depends_on, inf_from = self._paramspecbase_from_strings( - name, sp_strings, bs_strings - ) + depends_on, inf_from = self._paramspecbase_from_strings(sp_strings, bs_strings) if depends_on: self._interdeps = self._interdeps.extend( @@ -1283,6 +1358,8 @@ def register_custom_parameter( paramtype: Type of the parameter, i.e. the SQL storage class """ + custom_parameter = ManualParameter(name=name, label=label, unit=unit) + self._registered_parameters.add(custom_parameter) return self._register_parameter(name, label, unit, setpoints, basis, paramtype) def unregister_parameter(self, parameter: SetpointsType) -> None: @@ -1319,7 +1396,7 @@ def unregister_parameter(self, parameter: SetpointsType) -> None: for param in self._registered_parameters if parameter not in (param.name, param.register_name) ] - self._registered_parameters = with_parameters_removed + self._registered_parameters = set(with_parameters_removed) log.info(f"Removed {param_name} from Measurement.") @@ -1441,7 +1518,7 @@ def run( in_memory_cache=in_memory_cache, dataset_class=dataset_class, parent_span=parent_span, - registered_parameters=self._registered_parameters, + registered_parameters=tuple(self._registered_parameters), ) @@ -1451,3 +1528,62 @@ def str_or_register_name(sp: str | ParameterBase) -> str: return sp else: return sp.register_name + + +# TODO: These deduplication methods need testing against arrays with all ValuesType types +def _deduplicate_results( + results_dict: dict[ParamSpecBase, list[npt.NDArray]], +) -> DatasetResultDict: + deduplicated_results: dict[ParamSpecBase, npt.NDArray] = {} + for param_spec, list_of_ndarrays_of_values in results_dict.items(): + if len(list_of_ndarrays_of_values) == 1 or _values_are_equal( + list_of_ndarrays_of_values[0], *list_of_ndarrays_of_values[1:] + ): + deduplicated_results[param_spec] = list_of_ndarrays_of_values[0] + else: + raise ValueError(f"Multiple distinct values found for {param_spec.name}") + return deduplicated_results + + +def _values_are_equal(ref_array: npt.NDArray, *values_arrays: npt.NDArray) -> bool: + if np.issubdtype(ref_array.dtype, np.number): + return _numeric_values_are_equal(ref_array, *values_arrays) + return _non_numeric_values_are_equal(ref_array, *values_arrays) + + +def _non_numeric_values_are_equal( + ref_array: npt.NDArray, *values_arrays: npt.NDArray +) -> bool: + # For non-numeric values, we can use direct equality + for value_array in values_arrays: + if (ref_array.shape != value_array.shape) or not np.array_equal( + value_array, ref_array + ): + return False + return True + + +def _numeric_values_are_equal( + ref_array: npt.NDArray, *values_arrays: npt.NDArray +) -> bool: + # The equal_nan arg in np.allclose considers complex values with np.nan in + # either real or imaginary part to be equal. That is, np.nan + 1.0j is equal to 1.0 + np.nan*1.0j. + # Since we want a more granular equality, we split arrays with complex values + # into real and imaginary parts to evaluate equality + if np.issubdtype(ref_array.dtype, np.complexfloating): + return _numeric_values_are_equal( + np.real(ref_array), *[np.real(value_array) for value_array in values_arrays] + ) and _numeric_values_are_equal( + np.imag(ref_array), *[np.imag(value_array) for value_array in values_arrays] + ) + + for value_array in values_arrays: + if (ref_array.shape != value_array.shape) or not np.allclose( + value_array, + ref_array, + atol=0, + rtol=1e-8, # TODO: allow flexible rtol + equal_nan=True, + ): + return False + return True diff --git a/src/qcodes/dataset/sqlite/queries.py b/src/qcodes/dataset/sqlite/queries.py index 512fb4a3670d..4825f99f8bcd 100644 --- a/src/qcodes/dataset/sqlite/queries.py +++ b/src/qcodes/dataset/sqlite/queries.py @@ -20,7 +20,7 @@ import qcodes as qc from qcodes import config from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase +from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.descriptions.rundescriber import RunDescriber from qcodes.dataset.descriptions.versioning import serialization as serial from qcodes.dataset.descriptions.versioning import v0 @@ -47,6 +47,7 @@ sql_placeholder_string, update_where, ) +from qcodes.parameters import ParamSpecBase from qcodes.utils import list_of_data_to_maybe_ragged_nd_array if TYPE_CHECKING: diff --git a/src/qcodes/interactive_widget.py b/src/qcodes/interactive_widget.py index b1dfa316572e..503dd5e152e2 100644 --- a/src/qcodes/interactive_widget.py +++ b/src/qcodes/interactive_widget.py @@ -30,7 +30,7 @@ from collections.abc import Callable, Iterable, Sequence from qcodes.dataset.data_set_protocol import DataSetProtocol - from qcodes.dataset.descriptions.param_spec import ParamSpecBase + from qcodes.parameters import ParamSpecBase _META_DATA_KEY = "widget_notes" diff --git a/src/qcodes/parameters/__init__.py b/src/qcodes/parameters/__init__.py index 3d850d8435f7..eaaa1654e607 100644 --- a/src/qcodes/parameters/__init__.py +++ b/src/qcodes/parameters/__init__.py @@ -68,6 +68,7 @@ """ +from ._paramspec import ParamSpecBase, ParamSpecBaseDict from .array_parameter import ArrayParameter from .combined_parameter import CombinedParameter, combine from .delegate_parameter import DelegateParameter @@ -80,6 +81,7 @@ from .parameter_base import ( ParamDataType, ParameterBase, + ParameterSet, ParamRawDataType, invert_val_mapping, ) @@ -106,8 +108,11 @@ "MultiParameter", "ParamDataType", "ParamRawDataType", + "ParamSpecBase", + "ParamSpecBaseDict", "Parameter", "ParameterBase", + "ParameterSet", "ParameterWithSetpoints", "ScaledParameter", "SweepFixedValues", diff --git a/src/qcodes/parameters/_paramspec.py b/src/qcodes/parameters/_paramspec.py new file mode 100644 index 000000000000..8f50956c9a1a --- /dev/null +++ b/src/qcodes/parameters/_paramspec.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import ClassVar + +from typing_extensions import TypedDict + + +class ParamSpecBaseDict(TypedDict): + name: str + paramtype: str + label: str | None + unit: str | None + + +class ParamSpecBase: + allowed_types: ClassVar[list[str]] = ["array", "numeric", "text", "complex"] + + def __init__( + self, + name: str, + paramtype: str, + label: str | None = None, + unit: str | None = None, + ): + """ + Args: + name: name of the parameter + paramtype: type of the parameter, i.e. the SQL storage class + label: label of the parameter + unit: The unit of the parameter + + """ + + if not isinstance(paramtype, str): + raise ValueError("Paramtype must be a string.") + if paramtype.lower() not in self.allowed_types: + raise ValueError(f"Illegal paramtype. Must be on of {self.allowed_types}") + if not name.isidentifier(): + raise ValueError( + f"Invalid name: {name}. Only valid python " + "identifier names are allowed (no spaces or " + "punctuation marks, no prepended " + "numbers, etc.)" + ) + + self.name = name + self.type = paramtype.lower() + self.label = label or "" + self.unit = unit or "" + + self._hash: int = self._compute_hash() + + def _compute_hash(self) -> int: + """ + This method should only be called by __init__ + """ + attrs = ["name", "type", "label", "unit"] + # First, get the hash of the tuple with all the relevant attributes + all_attr_tuple_hash = hash(tuple(getattr(self, attr) for attr in attrs)) + hash_value = all_attr_tuple_hash + + # Then, XOR it with the individual hashes of all relevant attributes + for attr in attrs: + hash_value = hash_value ^ hash(getattr(self, attr)) + + return hash_value + + def sql_repr(self) -> str: + return f"{self.name} {self.type}" + + def __repr__(self) -> str: + return ( + f"ParamSpecBase('{self.name}', '{self.type}', '{self.label}', " + f"'{self.unit}')" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ParamSpecBase): + return False + attrs = ["name", "type", "label", "unit"] + for attr in attrs: + if getattr(self, attr) != getattr(other, attr): + return False + return True + + def __hash__(self) -> int: + """ + Allow ParamSpecBases in data structures that use hashing (e.g. sets) + """ + return self._hash + + def _to_dict(self) -> ParamSpecBaseDict: + """ + Write the ParamSpec as a dictionary + """ + output = ParamSpecBaseDict( + name=self.name, paramtype=self.type, label=self.label, unit=self.unit + ) + return output + + @classmethod + def _from_dict(cls, ser: ParamSpecBaseDict) -> ParamSpecBase: + """ + Create a ParamSpec instance of the current version + from a dictionary representation of ParamSpec of some version + + The version changes must be implemented as a series of transformations + of the representation dict. + """ + + return ParamSpecBase( + name=ser["name"], + paramtype=ser["paramtype"], + label=ser["label"], + unit=ser["unit"], + ) diff --git a/src/qcodes/parameters/parameter.py b/src/qcodes/parameters/parameter.py index 4fc93f2215f0..1cfe55248471 100644 --- a/src/qcodes/parameters/parameter.py +++ b/src/qcodes/parameters/parameter.py @@ -17,6 +17,7 @@ from qcodes.instrument import InstrumentBase from qcodes.logger.instrument_logger import InstrumentLoggerAdapter + from qcodes.parameters import ParamSpecBase from qcodes.validators import Validator @@ -437,6 +438,13 @@ def sweep( """ return SweepFixedValues(self, start=start, stop=stop, step=step, num=num) + @property + def param_spec(self) -> ParamSpecBase: + paramspecbase = super().param_spec # Sets the name and paramtype + paramspecbase.label = self.label + paramspecbase.unit = self.unit + return paramspecbase + class ManualParameter(Parameter): def __init__( diff --git a/src/qcodes/parameters/parameter_base.py b/src/qcodes/parameters/parameter_base.py index 85a19f5b843d..baba76c77fd2 100644 --- a/src/qcodes/parameters/parameter_base.py +++ b/src/qcodes/parameters/parameter_base.py @@ -4,14 +4,26 @@ import logging import time import warnings +from collections.abc import Iterator, MutableSet from contextlib import contextmanager from datetime import datetime from functools import cached_property, wraps -from typing import TYPE_CHECKING, Any, ClassVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload + +import numpy as np from qcodes.metadatable import Metadatable, MetadatableWithName +from qcodes.parameters import ParamSpecBase from qcodes.utils import DelegateAttributes, full_class, qcodes_abstractmethod -from qcodes.validators import Enum, Ints, Validator +from qcodes.validators import ( + Arrays, + ComplexNumbers, + Enum, + Ints, + Numbers, + Strings, + Validator, +) from ..utils.types import NumberType from .cache import _Cache, _CacheProtocol @@ -26,6 +38,7 @@ from collections.abc import Callable, Generator, Iterable, Mapping, Sequence, Sized from types import TracebackType + from qcodes.dataset.data_set_protocol import ValuesType from qcodes.instrument import InstrumentBase from qcodes.logger.instrument_logger import InstrumentLoggerAdapter @@ -235,6 +248,10 @@ def __init__( self.snapshot_exclude = snapshot_exclude self.on_set_callback = on_set_callback + self._depends_on: ParameterSet = ParameterSet() + self._has_control_of: ParameterSet = ParameterSet() + self._is_controlled_by: ParameterSet = ParameterSet() + self._param_spec: ParamSpecBase | None = None if not isinstance(vals, (Validator, type(None))): raise TypeError("vals must be None or a Validator") elif val_mapping is not None: @@ -1144,6 +1161,89 @@ def underlying_instrument(self) -> InstrumentBase | None: def abstract(self) -> bool | None: return self._abstract + @property + def param_spec(self) -> ParamSpecBase: + if self._param_spec is None: + match self.vals: + case Arrays(): + paramtype = "array" + case Strings(): + paramtype = "text" + case ComplexNumbers(): + paramtype = "complex" + case _: + paramtype = "numeric" + + self._param_spec = ParamSpecBase( + name=self.register_name, + paramtype=paramtype, + label=None, + unit=None, + ) + return self._param_spec + + @property + def paramtype(self) -> str: + return self.param_spec.type + + @paramtype.setter + def paramtype(self, paramtype: str) -> None: + self._set_paramtype(paramtype) # Indirected here, so subclasses can override + + def _set_paramtype(self, paramtype: str) -> None: + paramtype = paramtype.lower() + if paramtype not in ["array", "text", "complex", "numeric"]: + raise ValueError(f"{paramtype} is not a valid paramtype") + if self.paramtype == paramtype: + return + new_vals: Validator + match paramtype: + case "array": + new_vals = Arrays() + case "text": + new_vals = Strings() + case "complex": + new_vals = ComplexNumbers() + case "numeric": + new_vals = Numbers() + case _: + raise NotImplementedError("This should not be possible") + if self.vals is None: + self.vals = new_vals + elif type(self.vals) is not type(new_vals): + logging.warning( + f"Tried to set a new paramtype {paramtype}, but this parameter already has paramtype {self.paramtype} which does not match" + ) + self.param_spec.type = paramtype + + @property + def depends_on(self) -> ParameterSet: + return self._depends_on + + @property + def has_control_of(self) -> ParameterSet: + return self._has_control_of + + @property + def is_controlled_by(self) -> ParameterSet: + # This is equivalent to the "inferred_from" relationship + return self._is_controlled_by + + def unpack_self(self, value: ValuesType) -> list[tuple[ParameterBase, ValuesType]]: + if isinstance(self.vals, Arrays): + if not isinstance(value, np.ndarray): + raise TypeError( + f"Expected data for Parameter with Array validator " + f"to be a numpy array but got: {type(value)}" + ) + + if self.vals.shape is not None and value.shape != self.vals.shape: + raise TypeError( + f"Expected data with shape {self.vals.shape}, " + f"but got {value.shape} for parameter: {self.full_name}" + ) + return [(self, value)] + class GetLatest(DelegateAttributes): """ @@ -1211,3 +1311,125 @@ def __call__(self) -> ParamDataType: It is recommended to use ``parameter.cache()`` instead. """ return self.cache() + + +P = TypeVar("P", bound=ParameterBase) + + +# Does not implement __hash__, not clear it needs to +class ParameterSet(MutableSet, Generic[P]): # noqa: PLW1641 + """A set-like container that preserves the insertion order of its parameters. + + This class implements the common set interface methods while maintaining + the order in which parameters were first added. + """ + + def __init__(self, parameters: Sequence[P] | None = None) -> None: + self._dict: dict[P, None] = {} + if parameters is not None: + for item in parameters: + self.add(item) + + def add(self, value: P) -> None: + self._dict[value] = None + + def remove(self, value: P) -> None: + self._dict.pop(value) + + def discard(self, value: P) -> None: + if value in self._dict: + self._dict.pop(value) + + def clear(self) -> None: + self._dict.clear() + + def pop(self) -> ParameterBase: + if not self._dict: + raise KeyError("pop from an empty ParameterSet") + item = next(iter(self._dict)) + self._dict.pop(item) + return item + + def union(self, other: ParameterSet[P]) -> ParameterSet[P]: + result = ParameterSet(list(self._dict.keys())) + for item in other: + result.add(item) + return result + + def intersection(self, other: ParameterSet[P]) -> ParameterSet[P]: + result: ParameterSet[P] = ParameterSet() + for item in self: + if item in other: + result.add(item) + return result + + def difference(self, other: ParameterSet[P]) -> ParameterSet[P]: + result: ParameterSet[P] = ParameterSet() + for item in self: + if item not in other: + result.add(item) + return result + + def issubset(self, other: ParameterSet[P] | set) -> bool: + return all(item in other for item in self) + + def issuperset(self, other: ParameterSet[P] | set) -> bool: + return all(item in self for item in other) + + def update(self, other: Iterable[P]) -> None: + for item in other: + self.add(item) + + def __iter__(self) -> Iterator[P]: + return iter(self._dict) + + def __contains__(self, item: object) -> bool: + return item in self._dict + + def __len__(self) -> int: + return len(self._dict) + + def __eq__(self, other: object) -> bool: + if isinstance(other, ParameterSet): + return set(self._dict) == set(other._dict) + return False + + def __repr__(self) -> str: + if not self: + return f"{self.__class__.__name__}()" + return f"{self.__class__.__name__}({list(self._dict.keys())})" + + def __or__(self, other: object) -> ParameterSet[P]: + if isinstance(other, ParameterSet): + return self.union(other) + raise NotImplementedError( + f"OR operation is not defined between ParameterSet and {type(other)}" + ) + + def __and__(self, other: object) -> ParameterSet[P]: + if isinstance(other, ParameterSet): + return self.intersection(other) + raise NotImplementedError( + f"AND operation is not defined between ParameterSet and {type(other)}" + ) + + def __sub__(self, other: object) -> ParameterSet[P]: + if isinstance(other, ParameterSet): + return self.difference(other) + raise NotImplementedError( + f"Difference operation is not defined between ParameterSet and {type(other)}" + ) + + def __le__(self, other: object) -> bool: + if isinstance(other, ParameterSet): + return self.issubset(other) + raise NotImplementedError( + f"<= operation is not defined between ParameterSet and {type(other)}" + ) + + def __ge__(self, other: object) -> bool: + if isinstance(other, ParameterSet): + return self.issuperset(other) + raise NotImplementedError( + f">+ operation is not defined between ParameterSet and {type(other)}" + ) diff --git a/src/qcodes/parameters/parameter_with_setpoints.py b/src/qcodes/parameters/parameter_with_setpoints.py index c2cbf117d919..d9ef9badf4dd 100644 --- a/src/qcodes/parameters/parameter_with_setpoints.py +++ b/src/qcodes/parameters/parameter_with_setpoints.py @@ -5,14 +5,15 @@ import numpy as np +from qcodes.parameters.parameter import Parameter +from qcodes.parameters.parameter_base import ParameterBase, ParameterSet from qcodes.validators import Arrays, Validator -from .parameter import Parameter - if TYPE_CHECKING: from collections.abc import Callable, Sequence - from .parameter_base import ParamDataType, ParameterBase + from qcodes.dataset.data_set_protocol import ValuesType + from qcodes.parameters.parameter_base import ParamDataType, ParameterBase LOG = logging.getLogger(__name__) @@ -154,6 +155,31 @@ def validate(self, value: ParamDataType) -> None: self.validate_consistent_shape() super().validate(value) + @property + def depends_on(self) -> ParameterSet: + return ParameterSet(self.setpoints) + + def unpack_self(self, value: ValuesType) -> list[tuple[ParameterBase, ValuesType]]: + unpacked_results: list[tuple[ParameterBase, ValuesType]] = [] + setpoint_params = [] + setpoint_data = [] + for setpointparam in self.setpoints: + these_setpoints = setpointparam.get() + setpoint_params.append(setpointparam) + setpoint_data.append(these_setpoints) + output_grids = np.meshgrid(*setpoint_data, indexing="ij") + for param, grid in zip(setpoint_params, output_grids): + unpacked_results.append((param, grid)) + unpacked_results.extend( + super().unpack_self(value) + ) # Must come last to preserve original ordering + return unpacked_results + + def _set_paramtype(self, paramtype: str) -> None: + super()._set_paramtype(paramtype) + for setpoint in self.setpoints: + setpoint.paramtype = paramtype + def expand_setpoints_helper( parameter: ParameterWithSetpoints, results: ParamDataType | None = None @@ -174,24 +200,7 @@ def expand_setpoints_helper( and its setpoints. """ - if not isinstance(parameter, ParameterWithSetpoints): - raise TypeError( - f"Expanding setpoints only works for ParameterWithSetpoints. " - f"Supplied a {type(parameter)}" - ) - res = [] - setpoint_params = [] - setpoint_data = [] - for setpointparam in parameter.setpoints: - these_setpoints = setpointparam.get() - setpoint_params.append(setpointparam) - setpoint_data.append(these_setpoints) - output_grids = np.meshgrid(*setpoint_data, indexing="ij") - for param, grid in zip(setpoint_params, output_grids): - res.append((param, grid)) - if results is None: - data = parameter.get() + if results is not None: + return parameter.unpack_self(results) else: - data = results - res.append((parameter, data)) - return res + return parameter.unpack_self(parameter.get()) diff --git a/tests/conftest.py b/tests/conftest.py index 032cf96275e0..eae647bbd852 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,10 +15,10 @@ from qcodes.configuration import Config from qcodes.dataset import initialise_database, new_data_set from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.experiment_container import Experiment, new_experiment from qcodes.instrument import Instrument from qcodes.monitor.monitor import Monitor +from qcodes.parameters import ParamSpecBase from qcodes.station import Station settings.register_profile("ci", deadline=1000) diff --git a/tests/dataset/conftest.py b/tests/dataset/conftest.py index b24ce1c65b35..000f7974d5ab 100644 --- a/tests/dataset/conftest.py +++ b/tests/dataset/conftest.py @@ -16,7 +16,7 @@ import qcodes as qc from qcodes.dataset.data_set import DataSet from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase +from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.measurements import Measurement from qcodes.dataset.sqlite.database import connect from qcodes.instrument_drivers.mock_instruments import ( @@ -27,7 +27,12 @@ Multi2DSetPointParam2Sizes, setpoint_generator, ) -from qcodes.parameters import ArrayParameter, Parameter, ParameterWithSetpoints +from qcodes.parameters import ( + ArrayParameter, + Parameter, + ParameterWithSetpoints, + ParamSpecBase, +) from qcodes.validators import Arrays, ComplexNumbers, Numbers if TYPE_CHECKING: diff --git a/tests/dataset/measurement/test_measurement_context_manager.py b/tests/dataset/measurement/test_measurement_context_manager.py index 6426698e7d4e..550a59e91c38 100644 --- a/tests/dataset/measurement/test_measurement_context_manager.py +++ b/tests/dataset/measurement/test_measurement_context_manager.py @@ -21,7 +21,6 @@ import qcodes as qc import qcodes.validators as vals from qcodes.dataset.data_set import DataSet, load_by_id -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.experiment_container import new_experiment from qcodes.dataset.export_config import DataExportType from qcodes.dataset.measurements import Measurement @@ -30,10 +29,11 @@ DelegateParameter, ManualParameter, Parameter, + ParamSpecBase, expand_setpoints_helper, ) from qcodes.station import Station -from qcodes.validators import ComplexNumbers +from qcodes.validators import Arrays, ComplexNumbers from tests.common import retry_until_does_not_throw @@ -89,7 +89,6 @@ def test_register_parameter_numbers(DAC, DMM) -> None: Test the registration of scalar QCoDeS parameters """ - parameters = [DAC.ch1, DAC.ch2, DMM.v1, DMM.v2] not_parameters = ("", "Parameter", 0, 1.1, Measurement) meas = Measurement() @@ -110,16 +109,16 @@ def test_register_parameter_numbers(DAC, DMM) -> None: # we allow the registration of the EXACT same parameter twice... meas.register_parameter(my_param) - # ... but not a different parameter with a new name + # ... but not a different parameter with the same name attrs = ["label", "unit"] vals = ["new label", "new unit"] for attr, val in zip(attrs, vals): - old_val = getattr(my_param, attr) - setattr(my_param, attr, val) - match = re.escape("Parameter already registered in this Measurement.") + different_param = ManualParameter(name=my_param.full_name) + assert different_param.full_name == my_param.full_name + setattr(different_param, attr, val) + match = re.escape("already exists in the graph ") with pytest.raises(ValueError, match=match): - meas.register_parameter(my_param) - setattr(my_param, attr, old_val) + meas.register_parameter(different_param) assert len(meas.parameters) == 1 paramspec = meas.parameters[str(my_param)] @@ -128,12 +127,6 @@ def test_register_parameter_numbers(DAC, DMM) -> None: assert paramspec.unit == my_param.unit assert paramspec.type == "numeric" - for parameter in parameters: - with pytest.raises(ValueError): - meas.register_parameter(my_param, setpoints=(parameter,)) - with pytest.raises(ValueError): - meas.register_parameter(my_param, basis=(parameter,)) - meas.register_parameter(DAC.ch2) meas.register_parameter(DMM.v1) meas.register_parameter(DMM.v2) @@ -315,6 +308,7 @@ def test_mixing_array_and_numeric(DAC, bg_writing) -> None: Test that mixing array and numeric types is okay """ meas = Measurement() + DAC.ch2.vals = Arrays() meas.register_parameter(DAC.ch1, paramtype="numeric") meas.register_parameter(DAC.ch2, paramtype="array") @@ -1330,7 +1324,7 @@ def test_datasaver_parameter_with_setpoints_explicitly_expanded( @pytest.mark.usefixtures("experiment") -def test_datasaver_parameter_with_setpoints_partially_expanded_raises( +def test_datasaver_parameter_with_setpoints_that_are_different_raises( channel_array_instrument, DAC ) -> None: random_seed = 1 @@ -1358,8 +1352,9 @@ def test_datasaver_parameter_with_setpoints_partially_expanded_raises( with meas.run() as datasaver: # we seed the random number generator # so we can test that we get the expected numbers + # This fails because a 2D PWS expects 2D setpoints parameter values (ie a grid) np.random.seed(random_seed) - with pytest.raises(ValueError, match="Some of the setpoints of"): + with pytest.raises(ValueError, match="Multiple distinct values found for"): datasaver.add_result((param, param.get()), (sp_param_1, sp_param_1.get())) @@ -2416,15 +2411,20 @@ def test_save_numeric_as_complex_raises(complex_num_instrument, bg_writing) -> N def test_parameter_inference(channel_array_instrument) -> None: chan = channel_array_instrument.channels[0] # default values - assert Measurement._infer_paramtype(chan.temperature, None) is None + assert Measurement._infer_paramtype(chan.temperature, None) == "numeric" assert Measurement._infer_paramtype(chan.dummy_array_parameter, None) == "array" assert ( Measurement._infer_paramtype(chan.dummy_parameter_with_setpoints, None) == "array" ) - assert Measurement._infer_paramtype(chan.dummy_multi_parameter, None) is None - assert Measurement._infer_paramtype(chan.dummy_scalar_multi_parameter, None) is None - assert Measurement._infer_paramtype(chan.dummy_2d_multi_parameter, None) is None + assert Measurement._infer_paramtype(chan.dummy_multi_parameter, None) == "numeric" + assert ( + Measurement._infer_paramtype(chan.dummy_scalar_multi_parameter, None) + == "numeric" + ) + assert ( + Measurement._infer_paramtype(chan.dummy_2d_multi_parameter, None) == "numeric" + ) assert Measurement._infer_paramtype(chan.dummy_text, None) == "text" assert Measurement._infer_paramtype(chan.dummy_complex, None) == "complex" diff --git a/tests/dataset/measurement/test_self_registering_parameters.py b/tests/dataset/measurement/test_self_registering_parameters.py new file mode 100644 index 000000000000..38e01738cf93 --- /dev/null +++ b/tests/dataset/measurement/test_self_registering_parameters.py @@ -0,0 +1,115 @@ +from typing import TYPE_CHECKING + +import pytest + +from qcodes.dataset import Measurement +from qcodes.parameters import ManualParameter + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture +def control_parameters() -> ( + "Generator[tuple[ManualParameter, ManualParameter, ManualParameter], None, None]" +): + comp1 = ManualParameter("comp1") + comp2 = ManualParameter("comp2") + control1 = ManualParameter("control1") + + comp1.is_controlled_by.add(control1) + comp2.is_controlled_by.add(control1) + control1.has_control_of.add(comp1) + control1.has_control_of.add(comp2) + yield control1, comp1, comp2 + + +@pytest.fixture +def dependent_parameters() -> ( + "Generator[tuple[ManualParameter, ManualParameter, ManualParameter], None, None]" +): + indep1 = ManualParameter("indep1") + indep2 = ManualParameter("indep2") + dep1 = ManualParameter("dep1") + + dep1.depends_on.add(indep1) + dep1.depends_on.add(indep2) + yield dep1, indep1, indep2 + + +def test_registering_control_param_registers_components(control_parameters) -> None: + control1, comp1, comp2 = control_parameters + meas = Measurement() + meas.register_parameter(control1) + + assert comp1 in meas._registered_parameters + assert comp2 in meas._registered_parameters + + +def test_registering_component_param_registers_control(control_parameters) -> None: + control1, comp1, comp2 = control_parameters + meas = Measurement() + meas.register_parameter(comp1) + + assert control1 in meas._registered_parameters + assert comp2 in meas._registered_parameters + + +def test_registering_dependent_param_registers_indeps(dependent_parameters) -> None: + dep1, indep1, indep2 = dependent_parameters + meas = Measurement() + meas.register_parameter(dep1) + + assert indep1 in meas._registered_parameters + assert indep2 in meas._registered_parameters + # Note, registering indep1 is not expected to also register dep1 automatically + + +def test_registering_chain_of_control_parameters(control_parameters) -> None: + control1, comp1, comp2 = control_parameters + comp11 = ManualParameter("comp11") + comp12 = ManualParameter("comp12") + comp21 = ManualParameter("comp21") + comp22 = ManualParameter("comp22") + + new_controls = {comp1: [comp11, comp12], comp2: [comp21, comp22]} + for comp_n, comp_mms in new_controls.items(): + for comp_mm in comp_mms: + comp_n.has_control_of.add(comp_mm) + comp_mm.is_controlled_by.add(comp_n) + + meas = Measurement() + meas.register_parameter(control1) + + assert comp1 in meas._registered_parameters + assert comp2 in meas._registered_parameters + assert comp11 in meas._registered_parameters + assert comp12 in meas._registered_parameters + assert comp21 in meas._registered_parameters + assert comp22 in meas._registered_parameters + + +def test_registering_dependent_param_with_setpoints(dependent_parameters) -> None: + dep1, indep1, indep2 = dependent_parameters + setpoints1 = ManualParameter("setpoints1") + setpoints2 = ManualParameter("setpoints2") + meas = Measurement() + meas.register_parameter(dep1, setpoints=[setpoints1, setpoints2]) + + assert indep1 in meas._registered_parameters + assert indep2 in meas._registered_parameters + assert setpoints1 in meas._registered_parameters + assert setpoints2 in meas._registered_parameters + + dependency_tree = meas._interdeps.dependencies + assert len(dependency_tree) == 1 + assert dep1.param_spec in dependency_tree.keys() + + # Ensure that order in the dependency spec tree is preserved + # Explicit Setpoints first, then internal depends_on parameters + # In the case where setpoints have equal dimension, this is the only + # way to preserve the correct relationship of the multidimensional data + assert dependency_tree[dep1.param_spec][0] == setpoints1.param_spec + assert dependency_tree[dep1.param_spec][1] == setpoints2.param_spec + assert dependency_tree[dep1.param_spec][2] == indep1.param_spec + assert dependency_tree[dep1.param_spec][3] == indep2.param_spec diff --git a/tests/dataset/measurement/test_self_unpacking.py b/tests/dataset/measurement/test_self_unpacking.py new file mode 100644 index 000000000000..7e0245ab17b1 --- /dev/null +++ b/tests/dataset/measurement/test_self_unpacking.py @@ -0,0 +1,225 @@ +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from qcodes.dataset import Measurement +from qcodes.dataset.measurements import ( + _non_numeric_values_are_equal, + _numeric_values_are_equal, + _values_are_equal, +) +from qcodes.parameters import ( + ManualParameter, + Parameter, + ParameterWithSetpoints, + ParamRawDataType, +) +from qcodes.validators import Arrays + +if TYPE_CHECKING: + from collections.abc import Generator + + from qcodes.dataset.data_set_protocol import ValuesType + from qcodes.parameters.parameter_base import ParameterBase + + +class ControllingParameter(Parameter): + def __init__( + self, name: str, components: dict[Parameter, tuple[float, float]] + ) -> None: + super().__init__(name=name, get_cmd=False) + # dict of Parameter to (slope, offset) of components + self._components_dict: dict[Parameter, tuple[float, float]] = components + for param in self._components_dict.keys(): + self._has_control_of.add(param) + param.is_controlled_by.add(self) + + def set_raw(self, value: ParamRawDataType) -> None: + # Set all dependent parameters based on their mapping functions + for param, slope_offset in self._components_dict.items(): + param(value * slope_offset[0] + slope_offset[1]) + + def unpack_self( + self, value: "ValuesType" + ) -> list[tuple["ParameterBase", "ValuesType"]]: + assert isinstance(value, float) + unpacked_results = super().unpack_self(value) + for param, slope_offset in self._components_dict.items(): + unpacked_results.append((param, value * slope_offset[0] + slope_offset[1])) + return unpacked_results + + +@pytest.fixture +def controlling_parameters() -> ( + "Generator[tuple[ControllingParameter, ManualParameter, ManualParameter], None, None]" +): + comp1 = ManualParameter("comp1") + comp2 = ManualParameter("comp2") + control1 = ControllingParameter( + "control1", components={comp1: (1, 0), comp2: (-1, 10)} + ) + yield control1, comp1, comp2 + + +def test_add_result_self_unpack(controlling_parameters, experiment): + control1, comp1, comp2 = controlling_parameters + meas1 = ManualParameter("meas1") + + meas = Measurement(experiment) + meas.register_parameter(meas1, setpoints=[control1]) + + assert all( + param in meas._registered_parameters + for param in (comp1, comp2, control1, meas1) + ) + + with meas.run() as datasaver: + for val in np.linspace(0, 1, 11): + control1(val) + datasaver.add_result((meas1, val + 1), (control1, val)) + ds = datasaver.dataset + + dataset_data = ds.get_parameter_data() + meas1_data = dataset_data.get("meas1", None) + assert meas1_data is not None + assert all( + param_name in meas1_data.keys() + for param_name in ("meas1", "comp1", "comp2", "control1") + ) + assert meas1_data["meas1"] == pytest.approx(np.linspace(1, 2, 11)) + assert meas1_data["control1"] == pytest.approx(np.linspace(0, 1, 11)) + assert meas1_data["comp1"] == pytest.approx(np.linspace(0, 1, 11)) + assert meas1_data["comp2"] == pytest.approx(np.linspace(10, 9, 11)) + + +def test_add_result_self_unpack_with_PWS(controlling_parameters, experiment): + control1, comp1, comp2 = controlling_parameters + pws_setpoints = Parameter( + "pws_setpoints", + get_cmd=lambda: np.linspace(-1, 1, 11), + vals=Arrays(shape=(11,)), + ) + pws = ParameterWithSetpoints( + "pws", + setpoints=(pws_setpoints,), + vals=Arrays(shape=(11,)), + get_cmd=lambda: np.linspace(-2, 2, 11) + comp1(), + ) + + meas = Measurement(experiment) + meas.register_parameter(pws, setpoints=[control1]) + + assert all( + param in meas._registered_parameters + for param in (comp1, comp2, control1, pws, pws_setpoints) + ) + + with meas.run() as datasaver: + for val in np.linspace(0, 1, 11): + control1(val) + datasaver.add_result((pws, pws()), (control1, val)) + ds = datasaver.dataset + + dataset_data = ds.get_parameter_data() + pws_data = dataset_data.get("pws", None) + assert (pws_data) is not None + assert all( + param_name in pws_data.keys() + for param_name in ("pws", "comp1", "comp2", "control1", "pws_setpoints") + ) + expected_setpoints, expected_control = np.meshgrid( + np.linspace(-1, 1, 11), np.linspace(0, 1, 11) + ) + assert pws_data["control1"] == pytest.approx(expected_control) + assert pws_data["comp1"] == pytest.approx(expected_control) + assert pws_data["comp2"] == pytest.approx(10 - expected_control) + assert pws_data["pws_setpoints"] == pytest.approx(expected_setpoints) + + assert pws_data["control1"].shape == (11, 11) + assert pws_data["comp1"].shape == (11, 11) + assert pws_data["comp2"].shape == (11, 11) + assert pws_data["pws"].shape == (11, 11) + assert pws_data["pws_setpoints"].shape == (11, 11) + + +# Testing equality methods for deduplication +def test_non_numeric_values_are_equal() -> None: + # test str + assert _non_numeric_values_are_equal(np.array("string_val"), np.array("string_val")) + assert not _non_numeric_values_are_equal( + np.array("string_val"), np.array("different_string") + ) + + # test Sequence[str] + seq_value1 = ["a", "b", "c", "d"] + seq_value2 = ["a1", "b", "c", "d"] + assert _non_numeric_values_are_equal(np.array(seq_value1), np.array(seq_value1)) + assert not _non_numeric_values_are_equal(np.array(seq_value1), np.array(seq_value2)) + + # test NDArray[str] + arr_value1 = np.array(seq_value1) + arr_value2 = np.array(seq_value2) + assert _non_numeric_values_are_equal(np.array(arr_value1), np.array(arr_value1)) + assert not _non_numeric_values_are_equal(np.array(arr_value1), np.array(arr_value2)) + + +def test_numeric_values_are_equal() -> None: + # test complex + val1 = 1.0 + 3.0 * 1.0j + val2 = 2.0 + 3.0 * 1.0j + assert _numeric_values_are_equal(np.array(val1), np.array(val1)) + assert not _numeric_values_are_equal(np.array(val1), np.array(val2)) + + # test complex w/ nans + val1 = 1.0 + np.nan * 1.0j + val2 = np.nan + 3.0 * 1.0j + assert not _numeric_values_are_equal(np.array(val1), np.array(val2)) + + # test ndarray[complex] + real_1 = np.linspace(0, 1, 11) + imag_1 = np.linspace(0, -1, 11) + val1 = real_1 + 1.0j * imag_1 + + real_2 = np.linspace(0, -1, 11) + imag_2 = np.linspace(0, 1, 11) + val2 = real_2 + 1.0j * imag_2 + assert _numeric_values_are_equal(np.array(val1), np.array(val1)) + assert not _numeric_values_are_equal(np.array(val1), np.array(val2)) + + # test float + val1 = 1.0 + val2 = 2.0 + assert _numeric_values_are_equal(np.array(val1), np.array(val1)) + assert not _numeric_values_are_equal(np.array(val1), np.array(val2)) + + # test ndarray[float] + val1 = np.linspace(0, 1, 11) + val2 = np.linspace(0, -1, 11) + assert _numeric_values_are_equal(np.array(val1), np.array(val1)) + assert not _numeric_values_are_equal(np.array(val1), np.array(val2)) + + +def test_values_are_equal() -> None: + # test Sequence[str] + seq_value1 = ["a", "b", "c", "d"] + seq_value2 = ["a1", "b", "c", "d"] + assert _values_are_equal(np.array(seq_value1), np.array(seq_value1)) + assert not _values_are_equal(np.array(seq_value1), np.array(seq_value2)) + + # test ndarray[complex] + real_1 = np.linspace(0, 1, 11) + imag_1 = np.linspace(0, -1, 11) + val1 = real_1 + 1.0j * imag_1 + + real_2 = np.linspace(0, -1, 11) + imag_2 = np.linspace(0, 1, 11) + val2 = real_2 + 1.0j * imag_2 + assert _values_are_equal(np.array(val1), np.array(val1)) + assert not _values_are_equal(np.array(val1), np.array(val2)) + + # test float + val1 = 1.0 + val2 = 2.0 + assert _values_are_equal(np.array(val1), np.array(val1)) + assert not _values_are_equal(np.array(val1), np.array(val2)) diff --git a/tests/dataset/test__get_data_from_ds.py b/tests/dataset/test__get_data_from_ds.py index 04058916aba7..7223caee3112 100644 --- a/tests/dataset/test__get_data_from_ds.py +++ b/tests/dataset/test__get_data_from_ds.py @@ -6,9 +6,8 @@ from qcodes.dataset.data_export import _get_data_from_ds from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.measurements import Measurement -from qcodes.parameters import ManualParameter +from qcodes.parameters import ManualParameter, ParamSpecBase def test_get_data_by_id_order(dataset) -> None: diff --git a/tests/dataset/test_database_creation_and_upgrading.py b/tests/dataset/test_database_creation_and_upgrading.py index 4d50f9545309..ea2e198ef3b9 100644 --- a/tests/dataset/test_database_creation_and_upgrading.py +++ b/tests/dataset/test_database_creation_and_upgrading.py @@ -23,7 +23,6 @@ ) from qcodes.dataset.data_set import DataSet from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.descriptions.versioning.v0 import InterDependencies from qcodes.dataset.guids import parse_guid from qcodes.dataset.measurements import Measurement @@ -49,7 +48,7 @@ is_column_in_table, one, ) -from qcodes.parameters import Parameter +from qcodes.parameters import Parameter, ParamSpecBase from tests.common import error_caused_by, skip_if_no_fixtures from tests.dataset.conftest import temporarily_copied_DB diff --git a/tests/dataset/test_datasaver.py b/tests/dataset/test_datasaver.py index b567b223617f..1a4f965adec7 100644 --- a/tests/dataset/test_datasaver.py +++ b/tests/dataset/test_datasaver.py @@ -6,8 +6,9 @@ from qcodes.dataset import new_data_set from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.measurements import DataSaver +from qcodes.parameters import ManualParameter, ParamSpecBase +from qcodes.validators import Strings if TYPE_CHECKING: from collections.abc import Callable @@ -54,7 +55,12 @@ def test_default_callback(bg_writing) -> None: test_set = new_data_set("test-dataset") test_set.add_metadata("snapshot", "reasonable_snapshot") - DataSaver(dataset=test_set, write_period=0, interdeps=InterDependencies_()) + DataSaver( + dataset=test_set, + write_period=0, + interdeps=InterDependencies_(), + registered_parameters=[], + ) test_set.mark_started(start_bg_writer=bg_writing) test_set.mark_completed() assert CALLBACK_SNAPSHOT == "reasonable_snapshot" @@ -75,6 +81,7 @@ def test_numpy_types(bg_writing) -> None: """ p = ParamSpecBase(name="p", paramtype="numeric") + p_param = ManualParameter("p") test_set = new_data_set("test-dataset") test_set.prepare( snapshot={}, @@ -84,7 +91,12 @@ def test_numpy_types(bg_writing) -> None: idps = InterDependencies_(standalones=(p,)) - data_saver = DataSaver(dataset=test_set, write_period=0, interdeps=idps) + data_saver = DataSaver( + dataset=test_set, + write_period=0, + interdeps=idps, + registered_parameters=[p_param], + ) dtypes: list[Callable] = [ np.int8, @@ -128,14 +140,19 @@ def test_saving_numeric_values_as_text(numeric_type, bg_writing) -> None: Test the saving numeric values into 'text' parameter raises an exception """ p = ParamSpecBase("p", "text") - + p_param = ManualParameter("p", vals=Strings()) test_set = new_data_set("test-dataset") test_set.set_interdependencies(InterDependencies_(standalones=(p,))) test_set.mark_started(start_bg_writer=bg_writing) idps = InterDependencies_(standalones=(p,)) - data_saver = DataSaver(dataset=test_set, write_period=0, interdeps=idps) + data_saver = DataSaver( + dataset=test_set, + write_period=0, + interdeps=idps, + registered_parameters=[p_param], + ) try: value = numeric_type(2) @@ -160,6 +177,7 @@ def test_duplicated_parameter_raises() -> None: Test that passing same parameter multiple times to ``add_result`` raises an exception """ p = ParamSpecBase("p", "text") + p_param = ManualParameter("p", vals=Strings()) test_set = new_data_set("test-dataset") test_set.set_interdependencies(InterDependencies_(standalones=(p,))) @@ -167,7 +185,12 @@ def test_duplicated_parameter_raises() -> None: idps = InterDependencies_(standalones=(p,)) - data_saver = DataSaver(dataset=test_set, write_period=0, interdeps=idps) + data_saver = DataSaver( + dataset=test_set, + write_period=0, + interdeps=idps, + registered_parameters=[p_param], + ) try: msg = re.escape( diff --git a/tests/dataset/test_dataset_basic.py b/tests/dataset/test_dataset_basic.py index 38514674c909..c6e6884d34ef 100644 --- a/tests/dataset/test_dataset_basic.py +++ b/tests/dataset/test_dataset_basic.py @@ -23,12 +23,12 @@ from qcodes.dataset.data_set import DataSet from qcodes.dataset.data_set_protocol import CompletedError from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.descriptions.rundescriber import RunDescriber from qcodes.dataset.guids import parse_guid from qcodes.dataset.sqlite.connection import atomic, path_to_dbfile from qcodes.dataset.sqlite.database import _convert_array, get_DB_location from qcodes.dataset.sqlite.queries import _rewrite_timestamps, _unicode_categories +from qcodes.parameters import ParamSpecBase from qcodes.utils.types import complex_types, numpy_complex, numpy_floats, numpy_ints from tests.common import error_caused_by from tests.dataset.helper_functions import verify_data_dict diff --git a/tests/dataset/test_dataset_export.py b/tests/dataset/test_dataset_export.py index 59477a87a2ba..74e003f997ae 100644 --- a/tests/dataset/test_dataset_export.py +++ b/tests/dataset/test_dataset_export.py @@ -32,14 +32,13 @@ from qcodes.dataset.data_set import DataSet from qcodes.dataset.data_set_in_memory import DataSetInMem from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.descriptions.versioning import serialization as serial from qcodes.dataset.experiment_container import Experiment from qcodes.dataset.export_config import DataExportType from qcodes.dataset.exporters.export_to_pandas import _generate_pandas_index from qcodes.dataset.exporters.export_to_xarray import _calculate_index_shape from qcodes.dataset.linked_datasets.links import links_to_str -from qcodes.parameters import ManualParameter, Parameter +from qcodes.parameters import ManualParameter, Parameter, ParamSpecBase from qcodes.utils.deprecate import QCoDeSDeprecationWarning if TYPE_CHECKING: diff --git a/tests/dataset/test_dataset_in_memory.py b/tests/dataset/test_dataset_in_memory.py index e6266f19af11..24aee0303bc6 100644 --- a/tests/dataset/test_dataset_in_memory.py +++ b/tests/dataset/test_dataset_in_memory.py @@ -18,9 +18,8 @@ from qcodes.dataset.data_set_in_memory import DataSetInMem, load_from_file from qcodes.dataset.data_set_protocol import DataSetType from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.sqlite.connection import AtomicConnection, atomic_transaction -from qcodes.parameters import ManualParameter, Parameter +from qcodes.parameters import ManualParameter, Parameter, ParamSpecBase from qcodes.station import Station if TYPE_CHECKING: diff --git a/tests/dataset/test_dataset_in_memory_bacis.py b/tests/dataset/test_dataset_in_memory_bacis.py index 9ea2320d9952..1531642d476a 100644 --- a/tests/dataset/test_dataset_in_memory_bacis.py +++ b/tests/dataset/test_dataset_in_memory_bacis.py @@ -7,7 +7,7 @@ from qcodes.dataset import connect, load_by_guid, load_or_create_experiment from qcodes.dataset.data_set_in_memory import DataSetInMem from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase +from qcodes.parameters import ParamSpecBase def test_create_dataset_in_memory_explicit_db(empty_temp_db) -> None: diff --git a/tests/dataset/test_dependencies.py b/tests/dataset/test_dependencies.py index efdfbb3cfbe9..e2607655ad5a 100644 --- a/tests/dataset/test_dependencies.py +++ b/tests/dataset/test_dependencies.py @@ -9,9 +9,10 @@ IncompleteSubsetError, InterDependencies_, ) -from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase +from qcodes.dataset.descriptions.param_spec import ParamSpec from qcodes.dataset.descriptions.versioning.converters import new_to_old, old_to_new from qcodes.dataset.descriptions.versioning.v0 import InterDependencies +from qcodes.parameters import ParamSpecBase from tests.common import error_caused_by diff --git a/tests/dataset/test_nested_measurements.py b/tests/dataset/test_nested_measurements.py index 68fcbc4eee8b..c4db84f97fd6 100644 --- a/tests/dataset/test_nested_measurements.py +++ b/tests/dataset/test_nested_measurements.py @@ -6,8 +6,8 @@ from qcodes.dataset import Measurement, new_data_set from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.sqlite.connection import atomic_transaction +from qcodes.parameters import ParamSpecBase from tests.common import retry_until_does_not_throw VALUE = str | float | list | np.ndarray | bool diff --git a/tests/dataset/test_paramspec.py b/tests/dataset/test_paramspec.py index d20939b93a31..ed344a012032 100644 --- a/tests/dataset/test_paramspec.py +++ b/tests/dataset/test_paramspec.py @@ -6,7 +6,8 @@ from hypothesis import assume, given from numpy import ndarray -from qcodes.dataset.descriptions.param_spec import ParamSpec, ParamSpecBase +from qcodes.dataset.descriptions.param_spec import ParamSpec +from qcodes.parameters import ParamSpecBase def valid_identifier(**kwargs): diff --git a/tests/dataset/test_sqlite_base.py b/tests/dataset/test_sqlite_base.py index 52db258b12c2..d55ec2f00f9c 100644 --- a/tests/dataset/test_sqlite_base.py +++ b/tests/dataset/test_sqlite_base.py @@ -18,7 +18,6 @@ import qcodes.dataset.descriptions.versioning.serialization as serial from qcodes.dataset.data_set import DataSet from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.descriptions.rundescriber import RunDescriber from qcodes.dataset.experiment_container import load_or_create_experiment from qcodes.dataset.guids import generate_guid @@ -30,6 +29,7 @@ from qcodes.dataset.sqlite import query_helpers as mut_help from qcodes.dataset.sqlite.connection import atomic_transaction, path_to_dbfile from qcodes.dataset.sqlite.database import get_DB_location +from qcodes.parameters import ParamSpecBase from tests.common import error_caused_by from .helper_functions import verify_data_dict diff --git a/tests/dataset/test_string_data.py b/tests/dataset/test_string_data.py index 6349d0d8438a..3193f62ca1db 100644 --- a/tests/dataset/test_string_data.py +++ b/tests/dataset/test_string_data.py @@ -8,9 +8,9 @@ from qcodes.dataset import load_by_id, new_data_set from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.measurements import DataSaver, Measurement -from qcodes.parameters import Parameter +from qcodes.parameters import ManualParameter, Parameter, ParamSpecBase +from qcodes.validators import Strings def test_string_via_dataset(experiment) -> None: @@ -36,14 +36,19 @@ def test_string_via_datasaver(experiment) -> None: Test that we can save text into database via DataSaver API """ p = ParamSpecBase(name="p", paramtype="text") - + p_param = ManualParameter("p", vals=Strings()) test_set = new_data_set("test-dataset") idps = InterDependencies_(standalones=(p,)) test_set.prepare(snapshot={}, interdeps=idps) idps = InterDependencies_(standalones=(p,)) - data_saver = DataSaver(dataset=test_set, write_period=0, interdeps=idps) + data_saver = DataSaver( + dataset=test_set, + write_period=0, + interdeps=idps, + registered_parameters=[p_param], + ) data_saver.add_result(("p", "some text")) data_saver.flush_data_to_database() @@ -110,7 +115,7 @@ def test_string_with_wrong_paramtype_via_datasaver() -> None: parameter via DataSaver object """ p = ParamSpecBase("p", "numeric") - + p_param = ManualParameter("p") test_set = new_data_set("test-dataset") idps = InterDependencies_(standalones=(p,)) test_set.set_interdependencies(idps) @@ -118,7 +123,12 @@ def test_string_with_wrong_paramtype_via_datasaver() -> None: idps = InterDependencies_(standalones=(p,)) - data_saver = DataSaver(dataset=test_set, write_period=0, interdeps=idps) + data_saver = DataSaver( + dataset=test_set, + write_period=0, + interdeps=idps, + registered_parameters=[p_param], + ) try: msg = re.escape( diff --git a/tests/dataset/test_subscribing.py b/tests/dataset/test_subscribing.py index 3f2b576bb93d..07750119070d 100644 --- a/tests/dataset/test_subscribing.py +++ b/tests/dataset/test_subscribing.py @@ -8,8 +8,8 @@ import qcodes from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase from qcodes.dataset.sqlite.connection import atomic_transaction +from qcodes.parameters import ParamSpecBase from tests.common import retry_until_does_not_throw log = logging.getLogger(__name__) diff --git a/tests/dataset_generators.py b/tests/dataset_generators.py index e7b76d5055f7..ff891f6cb69b 100644 --- a/tests/dataset_generators.py +++ b/tests/dataset_generators.py @@ -1,7 +1,7 @@ import numpy as np from qcodes.dataset.descriptions.dependencies import InterDependencies_ -from qcodes.dataset.descriptions.param_spec import ParamSpecBase +from qcodes.parameters import ParamSpecBase def dataset_with_outliers_generator( diff --git a/tests/parameter/test_parameter_set.py b/tests/parameter/test_parameter_set.py new file mode 100644 index 000000000000..9a69d49459a6 --- /dev/null +++ b/tests/parameter/test_parameter_set.py @@ -0,0 +1,88 @@ +from typing import TYPE_CHECKING + +import pytest + +from qcodes.parameters import ManualParameter, ParameterSet + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture +def manual_parameters() -> "Generator[tuple[ManualParameter, ...], None, None]": + param1 = ManualParameter("param1") + param2 = ManualParameter("param2") + param3 = ManualParameter("param3") + + yield param1, param2, param3 + + +def test_parameter_set_preserves_order( + manual_parameters: tuple[ManualParameter, ...], +) -> None: + param1, param2, param3 = manual_parameters + param_set = ParameterSet((param1, param2, param3)) + param_list = list(param_set) + assert param_list[0] is param1 + assert param_list[1] is param2 + assert param_list[2] is param3 + + param_set = ParameterSet((param3, param1, param2)) + param_list = list(param_set) + assert param_list[0] is param3 + assert param_list[1] is param1 + assert param_list[2] is param2 + + param_set.clear() + assert len(param_set) == 0 + param_set.update(ParameterSet((param1, param2, param3))) + param_list = list(param_set) + assert param_list[0] is param1 + assert param_list[1] is param2 + assert param_list[2] is param3 + + param_set.clear() + assert len(param_set) == 0 + param_set.add(param1) + param_set.add(param2) + param_set.add(param3) + param_list = list(param_set) + assert param_list[0] is param1 + assert param_list[1] is param2 + assert param_list[2] is param3 + + param_set.remove(param2) + param_list = list(param_set) + assert param_list[0] is param1 + assert param_list[1] is param3 + + param_set.add(param1) + assert len(param_set) == 2 + param_list = list(param_set) + assert param_list[0] is param1 + assert param_list[1] is param3 + + +def test_parameter_set_operations( + manual_parameters: tuple[ManualParameter, ...], +) -> None: + param1, param2, param3 = manual_parameters + param_set1 = ParameterSet((param1, param2, param3)) + param_set2 = ParameterSet((param3, param1, param2)) + param_subset1 = ParameterSet((param1, param2)) + + assert param_set1 == param_set2 # Should this fail, because the order is different? + assert param_subset1 < param_set1 + assert param_set1 > param_subset1 + assert param_subset1 <= param_set1 + assert param_set1 >= param_subset1 + + union_set = ParameterSet([param1]) | ParameterSet([param2]) + assert len(union_set) == 2 + + difference_set = param_set2 - ParameterSet([param1]) + assert len(difference_set) == 2 + + intersection_set = ParameterSet([param1, param2]) & ParameterSet([param2, param3]) + assert len(intersection_set) == 1 + assert param2 in intersection_set