@@ -279,6 +279,69 @@ def f(x):
279279 simple_run (learner , 10 )
280280
281281
282+ def test_learner2d_vector_valued_function ():
283+ """Test that Learner2D handles vector-valued functions correctly.
284+
285+ This test verifies that the deviations function works properly when
286+ the function returns a vector (array/list) of values instead of a scalar.
287+ """
288+
289+ from adaptive import Learner2D
290+
291+ def vector_function (xy ):
292+ """A 2D function that returns a 3-element vector."""
293+ x , y = xy
294+ return [x + y , x * y , x - y ] # Returns 3-element vector
295+
296+ # Create learner with vector-valued function
297+ learner = Learner2D (vector_function , bounds = [(- 1 , 1 ), (- 1 , 1 )])
298+
299+ # Add some initial points
300+ points = [
301+ (0.0 , 0.0 ),
302+ (1.0 , 0.0 ),
303+ (0.0 , 1.0 ),
304+ (1.0 , 1.0 ),
305+ (0.5 , 0.5 ),
306+ (- 0.5 , 0.5 ),
307+ (0.5 , - 0.5 ),
308+ (- 1.0 , - 1.0 ),
309+ ]
310+
311+ for point in points :
312+ value = vector_function (point )
313+ learner .tell (point , value )
314+
315+ # Run the learner to trigger deviations calculation
316+ # This should not raise any errors
317+ learner .ask (10 )
318+
319+ # Verify that the interpolator is created (ip is a property that may return a function)
320+ assert hasattr (learner , "ip" )
321+
322+ # Check the internal interpolator if it exists
323+ if hasattr (learner , "_ip" ) and learner ._ip is not None :
324+ # Check that values have the correct shape
325+ assert learner ._ip .values .shape [1 ] == 3 # 3 output dimensions
326+
327+ # Test that we can evaluate the interpolated function
328+ test_point = (0.25 , 0.25 )
329+ ip_func = learner .interpolator (scaled = True ) # Get the interpolator function
330+ if ip_func is not None :
331+ interpolated_value = ip_func (test_point )
332+ assert len (interpolated_value ) == 3
333+
334+ # Run more iterations to ensure deviations are computed correctly
335+ simple_run (learner , 20 )
336+
337+ # Final verification
338+ assert len (learner .data ) > len (points ) # Learner added more points
339+
340+ # Check that all values in data are vectors
341+ for _point , value in learner .data .items ():
342+ assert len (value ) == 3 , f"Expected 3-element vector, got { value } "
343+
344+
282345@run_with (Learner1D , Learner2D , LearnerND , SequenceLearner , AverageLearner1D )
283346def test_adding_existing_data_is_idempotent (learner_type , f , learner_kwargs ):
284347 """Adding already existing data is an idempotent operation.
0 commit comments