Skip to content

Commit b35268d

Browse files
MarcusMNoackclaude
andcommitted
Add serialization test coverage for GPOptimizer/fvGPOptimizer
Extend test_pickle with value-level round-trip checks of the gpcam config attributes (cost_function, _gp2Scale, gp2Scale_batch_size, _linalg_mode, ram_economy, _args, logging, multi_task, x_out, gp) and an fvGPOptimizer round-trip exercising multi_task=True and x_out. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 8a13add commit b35268d

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

tests/test_gpCAM.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def meanf(x, hps):
4040
#This is a simple mean function but it can be arbitrarily complex using many hyperparameters.
4141
return np.sin(hps[2] * x[:,0])
4242

43+
def cost_f(origin, x):
44+
#module-level so it pickles by reference (used by the serialization test)
45+
return np.ones(len(x))
46+
4347
#class TestgpCAM(unittest.TestCase):
4448
# """Tests for `gpcam` package."""
4549

@@ -279,6 +283,45 @@ def is_pickle_equal(obj):
279283
assert is_pickle_equal(my_gpo.data)
280284
assert is_pickle_equal(my_gpo.marginal_likelihood.kv)
281285

286+
#TEST4
287+
#gpcam-level config attributes must round-trip by VALUE (not just key presence)
288+
def cfg_equal(a, b):
289+
if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
290+
return np.array_equal(a, b)
291+
return a is b or a == b
292+
293+
my_gpo = GPOptimizer(x_data, y_data,
294+
init_hyperparameters=np.ones((4)) / 10.,
295+
compute_device="cpu",
296+
linalg_mode="Chol",
297+
ram_economy=True,
298+
gp2Scale_batch_size=5000,
299+
cost_function=cost_f,
300+
args={"k": 7.})
301+
my_gpo2 = pickle.loads(pickle.dumps(my_gpo))
302+
for attr in ["cost_function", "init_hyperparameters", "compute_device",
303+
"kernel_function", "kernel_function_grad",
304+
"noise_function", "noise_function_grad",
305+
"prior_mean_function", "prior_mean_function_grad",
306+
"_gp2Scale", "gp2Scale_batch_size", "_linalg_mode",
307+
"ram_economy", "_args", "logging", "multi_task", "x_out", "gp"]:
308+
assert cfg_equal(getattr(my_gpo, attr), getattr(my_gpo2, attr)), attr
309+
assert my_gpo2._dask_client is None
310+
311+
#TEST5
312+
#multi-task (fvGPOptimizer) pickling: exercises multi_task=True and x_out
313+
x_mt = np.random.uniform(size=(10, 2))
314+
y_mt = np.column_stack([np.sin(x_mt[:, 0]), np.cos(x_mt[:, 1])])
315+
fv = fvGPOptimizer(x_mt, y_mt, kernel_function=mt_kernel,
316+
init_hyperparameters=np.array([1., 1., 1.]))
317+
fv2 = pickle.loads(pickle.dumps(fv))
318+
assert fv2.multi_task is True
319+
assert np.array_equal(fv.x_out, fv2.x_out)
320+
assert np.all(fv.x_data == fv2.x_data)
321+
assert np.all(fv.y_data == fv2.y_data)
322+
assert np.all(fv.hyperparameters == fv2.hyperparameters)
323+
assert is_pickle_equal(fv)
324+
282325

283326

284327

0 commit comments

Comments
 (0)