@@ -1894,3 +1894,80 @@ def test_measurement_hypothesis_nd_grid_with_inferred_param(
18941894 assert set (inf_indexes .keys ()) == set (inf_sp_names )
18951895 for dim in inf_sp_names :
18961896 assert inf_indexes [dim ].equals (xr_ds .indexes [dim ])
1897+
1898+
1899+ def test_measurement_2d_with_inferred_setpoint (
1900+ experiment : Experiment , caplog : LogCaptureFixture
1901+ ) -> None :
1902+ """
1903+ Sweep two parameters (x, y) where y is inferred from one or more basis parameters.
1904+ Verify that xarray export uses direct method, signal dims match, and basis
1905+ parameters appear as inferred coordinates with indexes corresponding to y.
1906+ """
1907+ # Grid sizes
1908+ nx , ny = 3 , 4
1909+ x_vals = np .linspace (0.0 , 2.0 , nx )
1910+ # Define basis parameters for y and compute y from these
1911+ y_b0_vals = np .linspace (10.0 , 13.0 , ny )
1912+ y_b1_vals = np .linspace (- 1.0 , 2.0 , ny )
1913+ # y is inferred from (y_b0, y_b1)
1914+ y_vals = y_b0_vals + 2.0 * y_b1_vals
1915+
1916+ meas = Measurement (exp = experiment , name = "2d_with_inferred_setpoint" )
1917+ # Register setpoint x
1918+ meas .register_custom_parameter ("x" , paramtype = "numeric" )
1919+ # Register basis params for y
1920+ meas .register_custom_parameter ("y_b0" , paramtype = "numeric" )
1921+ meas .register_custom_parameter ("y_b1" , paramtype = "numeric" )
1922+ # Register y as setpoint inferred from basis
1923+ meas .register_custom_parameter ("y" , basis = ("y_b0" , "y_b1" ), paramtype = "numeric" )
1924+ # Register measured parameter depending on (x, y)
1925+ meas .register_custom_parameter ("signal" , setpoints = ("x" , "y" ), paramtype = "numeric" )
1926+ meas .set_shapes ({"signal" : (nx , ny )})
1927+
1928+ with meas .run () as datasaver :
1929+ for ix in range (nx ):
1930+ for iy in range (ny ):
1931+ x = float (x_vals [ix ])
1932+ y_b0 = float (y_b0_vals [iy ])
1933+ y_b1 = float (y_b1_vals [iy ])
1934+ y = float (y_vals [iy ])
1935+ signal = x + 3.0 * y # deterministic function
1936+ datasaver .add_result (
1937+ ("x" , x ),
1938+ ("y_b0" , y_b0 ),
1939+ ("y_b1" , y_b1 ),
1940+ ("y" , y ),
1941+ ("signal" , signal ),
1942+ )
1943+
1944+ ds = datasaver .dataset
1945+
1946+ caplog .clear ()
1947+ with caplog .at_level (logging .INFO ):
1948+ xr_ds = ds .to_xarray_dataset ()
1949+
1950+ assert any (
1951+ "Exporting signal to xarray using direct method" in record .message
1952+ for record in caplog .records
1953+ )
1954+
1955+ # Sizes and coords
1956+ assert xr_ds .sizes == {"x" : nx , "y" : ny }
1957+ np .testing .assert_allclose (xr_ds .coords ["x" ].values , x_vals )
1958+ np .testing .assert_allclose (xr_ds .coords ["y" ].values , y_vals )
1959+
1960+ # Signal dims and values
1961+ assert xr_ds ["signal" ].dims == ("x" , "y" )
1962+ expected_signal = x_vals [:, None ] + 3.0 * y_vals [None , :]
1963+ np .testing .assert_allclose (xr_ds ["signal" ].values , expected_signal )
1964+
1965+ # Inferred coords for y_b0 and y_b1 exist with dims only along y
1966+ for name , vals in ("y_b0" , y_b0_vals ), ("y_b1" , y_b1_vals ):
1967+ assert name in xr_ds .coords
1968+ assert xr_ds .coords [name ].dims == ("y" ,)
1969+ np .testing .assert_allclose (xr_ds .coords [name ].values , vals )
1970+ # Indexes of inferred coords should correspond to the y axis index
1971+ inf_idx = xr_ds .coords [name ].indexes
1972+ assert set (inf_idx .keys ()) == {"y" }
1973+ assert inf_idx ["y" ].equals (xr_ds .indexes ["y" ])
0 commit comments