Skip to content

Commit 36454e3

Browse files
committed
Add test where loop is over infeered parameter
1 parent c29c3ac commit 36454e3

1 file changed

Lines changed: 77 additions & 0 deletions

File tree

tests/dataset/test_dataset_export.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,3 +1894,80 @@ def test_measurement_hypothesis_nd_grid_with_inferred_param(
18941894
assert set(inf_indexes.keys()) == set(inf_sp_names)
18951895
for dim in inf_sp_names:
18961896
assert inf_indexes[dim].equals(xr_ds.indexes[dim])
1897+
1898+
1899+
def test_measurement_2d_with_inferred_setpoint(
1900+
experiment: Experiment, caplog: LogCaptureFixture
1901+
) -> None:
1902+
"""
1903+
Sweep two parameters (x, y) where y is inferred from one or more basis parameters.
1904+
Verify that xarray export uses direct method, signal dims match, and basis
1905+
parameters appear as inferred coordinates with indexes corresponding to y.
1906+
"""
1907+
# Grid sizes
1908+
nx, ny = 3, 4
1909+
x_vals = np.linspace(0.0, 2.0, nx)
1910+
# Define basis parameters for y and compute y from these
1911+
y_b0_vals = np.linspace(10.0, 13.0, ny)
1912+
y_b1_vals = np.linspace(-1.0, 2.0, ny)
1913+
# y is inferred from (y_b0, y_b1)
1914+
y_vals = y_b0_vals + 2.0 * y_b1_vals
1915+
1916+
meas = Measurement(exp=experiment, name="2d_with_inferred_setpoint")
1917+
# Register setpoint x
1918+
meas.register_custom_parameter("x", paramtype="numeric")
1919+
# Register basis params for y
1920+
meas.register_custom_parameter("y_b0", paramtype="numeric")
1921+
meas.register_custom_parameter("y_b1", paramtype="numeric")
1922+
# Register y as setpoint inferred from basis
1923+
meas.register_custom_parameter("y", basis=("y_b0", "y_b1"), paramtype="numeric")
1924+
# Register measured parameter depending on (x, y)
1925+
meas.register_custom_parameter("signal", setpoints=("x", "y"), paramtype="numeric")
1926+
meas.set_shapes({"signal": (nx, ny)})
1927+
1928+
with meas.run() as datasaver:
1929+
for ix in range(nx):
1930+
for iy in range(ny):
1931+
x = float(x_vals[ix])
1932+
y_b0 = float(y_b0_vals[iy])
1933+
y_b1 = float(y_b1_vals[iy])
1934+
y = float(y_vals[iy])
1935+
signal = x + 3.0 * y # deterministic function
1936+
datasaver.add_result(
1937+
("x", x),
1938+
("y_b0", y_b0),
1939+
("y_b1", y_b1),
1940+
("y", y),
1941+
("signal", signal),
1942+
)
1943+
1944+
ds = datasaver.dataset
1945+
1946+
caplog.clear()
1947+
with caplog.at_level(logging.INFO):
1948+
xr_ds = ds.to_xarray_dataset()
1949+
1950+
assert any(
1951+
"Exporting signal to xarray using direct method" in record.message
1952+
for record in caplog.records
1953+
)
1954+
1955+
# Sizes and coords
1956+
assert xr_ds.sizes == {"x": nx, "y": ny}
1957+
np.testing.assert_allclose(xr_ds.coords["x"].values, x_vals)
1958+
np.testing.assert_allclose(xr_ds.coords["y"].values, y_vals)
1959+
1960+
# Signal dims and values
1961+
assert xr_ds["signal"].dims == ("x", "y")
1962+
expected_signal = x_vals[:, None] + 3.0 * y_vals[None, :]
1963+
np.testing.assert_allclose(xr_ds["signal"].values, expected_signal)
1964+
1965+
# Inferred coords for y_b0 and y_b1 exist with dims only along y
1966+
for name, vals in ("y_b0", y_b0_vals), ("y_b1", y_b1_vals):
1967+
assert name in xr_ds.coords
1968+
assert xr_ds.coords[name].dims == ("y",)
1969+
np.testing.assert_allclose(xr_ds.coords[name].values, vals)
1970+
# Indexes of inferred coords should correspond to the y axis index
1971+
inf_idx = xr_ds.coords[name].indexes
1972+
assert set(inf_idx.keys()) == {"y"}
1973+
assert inf_idx["y"].equals(xr_ds.indexes["y"])

0 commit comments

Comments
 (0)