Skip to content

Commit ad0ebe1

Browse files
committed
add tests for exp. data history
1 parent d6983b5 commit ad0ebe1

1 file changed

Lines changed: 348 additions & 0 deletions

File tree

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
"""Tests for SearchHistory and history tracking in experiments."""
2+
3+
# copyright: hyperactive developers, MIT License (see LICENSE file)
4+
5+
import pytest
6+
7+
from hyperactive.base import SearchHistory
8+
9+
10+
class TestSearchHistory:
11+
"""Tests for the SearchHistory class."""
12+
13+
def test_init_empty(self):
14+
"""Test that a new SearchHistory is empty."""
15+
history = SearchHistory()
16+
assert history.n_trials == 0
17+
assert history.n_runs == 1
18+
assert history.history == []
19+
assert history.best_trial is None
20+
assert history.best_score is None
21+
assert history.best_params is None
22+
23+
def test_record_single_trial(self):
24+
"""Test recording a single trial."""
25+
history = SearchHistory()
26+
history.record(
27+
params={"x": 1, "y": 2},
28+
score=0.5,
29+
metadata={"time": 1.0},
30+
eval_time=0.1,
31+
)
32+
33+
assert history.n_trials == 1
34+
assert history.n_runs == 1
35+
36+
trial = history.history[0]
37+
assert trial["iteration"] == 0
38+
assert trial["run_id"] == 0
39+
assert trial["params"] == {"x": 1, "y": 2}
40+
assert trial["score"] == 0.5
41+
assert trial["metadata"] == {"time": 1.0}
42+
assert trial["eval_time"] == 0.1
43+
44+
def test_record_multiple_trials(self):
45+
"""Test recording multiple trials in one run."""
46+
history = SearchHistory()
47+
48+
for i in range(5):
49+
history.record(
50+
params={"x": i},
51+
score=float(i),
52+
metadata={},
53+
eval_time=0.1,
54+
)
55+
56+
assert history.n_trials == 5
57+
assert history.n_runs == 1
58+
59+
# Check iteration is global
60+
for i, trial in enumerate(history.history):
61+
assert trial["iteration"] == i
62+
assert trial["run_id"] == 0
63+
64+
def test_multiple_runs(self):
65+
"""Test that run_id increments across multiple runs."""
66+
history = SearchHistory()
67+
68+
history.record(params={"x": 1}, score=0.1, metadata={}, eval_time=0.1)
69+
history.record(params={"x": 2}, score=0.2, metadata={}, eval_time=0.1)
70+
71+
history.new_run()
72+
history.record(params={"x": 3}, score=0.3, metadata={}, eval_time=0.1)
73+
74+
assert history.n_trials == 3
75+
assert history.n_runs == 2
76+
77+
# Check run_ids
78+
assert history.history[0]["run_id"] == 0
79+
assert history.history[1]["run_id"] == 0
80+
assert history.history[2]["run_id"] == 1
81+
82+
# Iteration is global
83+
assert history.history[0]["iteration"] == 0
84+
assert history.history[1]["iteration"] == 1
85+
assert history.history[2]["iteration"] == 2
86+
87+
def test_best_trial(self):
88+
"""Test that best_trial returns the trial with highest score."""
89+
history = SearchHistory()
90+
history.record(params={"x": 1}, score=0.5, metadata={}, eval_time=0.1)
91+
history.record(params={"x": 2}, score=0.9, metadata={}, eval_time=0.1)
92+
history.record(params={"x": 3}, score=0.3, metadata={}, eval_time=0.1)
93+
94+
best = history.best_trial
95+
assert best["score"] == 0.9
96+
assert best["params"] == {"x": 2}
97+
assert history.best_score == 0.9
98+
assert history.best_params == {"x": 2}
99+
100+
def test_get_run(self):
101+
"""Test filtering trials by run_id."""
102+
history = SearchHistory()
103+
104+
history.record(params={"x": 1}, score=0.1, metadata={}, eval_time=0.1)
105+
history.record(params={"x": 2}, score=0.2, metadata={}, eval_time=0.1)
106+
107+
history.new_run()
108+
history.record(params={"x": 3}, score=0.3, metadata={}, eval_time=0.1)
109+
110+
run0 = history.get_run(0)
111+
run1 = history.get_run(1)
112+
113+
assert len(run0) == 2
114+
assert len(run1) == 1
115+
assert run0[0]["params"] == {"x": 1}
116+
assert run0[1]["params"] == {"x": 2}
117+
assert run1[0]["params"] == {"x": 3}
118+
119+
def test_clear(self):
120+
"""Test that clear resets all history."""
121+
history = SearchHistory()
122+
history.record(params={"x": 1}, score=0.5, metadata={}, eval_time=0.1)
123+
124+
history.clear()
125+
126+
assert history.n_trials == 0
127+
assert history.n_runs == 1
128+
assert history.history == []
129+
130+
def test_params_are_copied(self):
131+
"""Test that recorded params are copied, not referenced."""
132+
history = SearchHistory()
133+
params = {"x": 1}
134+
history.record(params=params, score=0.5, metadata={}, eval_time=0.1)
135+
136+
# Modify original
137+
params["x"] = 999
138+
139+
# Recorded params should be unchanged
140+
assert history.history[0]["params"]["x"] == 1
141+
142+
def test_metadata_none_becomes_empty_dict(self):
143+
"""Test that None metadata becomes an empty dict."""
144+
history = SearchHistory()
145+
history.record(params={"x": 1}, score=0.5, metadata=None, eval_time=0.1)
146+
147+
assert history.history[0]["metadata"] == {}
148+
149+
def test_len(self):
150+
"""Test __len__ returns number of trials."""
151+
history = SearchHistory()
152+
assert len(history) == 0
153+
154+
history.record(params={"x": 1}, score=0.5, metadata={}, eval_time=0.1)
155+
assert len(history) == 1
156+
157+
def test_repr(self):
158+
"""Test __repr__ is informative."""
159+
history = SearchHistory()
160+
history.record(params={"x": 1}, score=0.5, metadata={}, eval_time=0.1)
161+
162+
repr_str = repr(history)
163+
assert "n_trials=1" in repr_str
164+
assert "n_runs=1" in repr_str
165+
166+
167+
class TestExperimentDataIntegration:
168+
"""Tests for data tracking in BaseExperiment via accessor pattern."""
169+
170+
def test_experiment_has_data_accessor(self):
171+
"""Test that BaseExperiment has data accessor."""
172+
from hyperactive.base import SearchHistory
173+
from hyperactive.experiment.func import FunctionExperiment
174+
175+
def objective(params):
176+
return params["x"] ** 2
177+
178+
exp = FunctionExperiment(objective)
179+
180+
assert hasattr(exp, "data")
181+
assert isinstance(exp.data, SearchHistory)
182+
assert exp.data.history == []
183+
assert exp.data.n_trials == 0
184+
185+
def test_evaluate_records_data(self):
186+
"""Test that evaluate() records trials to data."""
187+
from hyperactive.experiment.func import FunctionExperiment
188+
189+
def objective(params):
190+
return params["x"] ** 2
191+
192+
exp = FunctionExperiment(objective)
193+
194+
exp.evaluate({"x": 2})
195+
exp.evaluate({"x": 3})
196+
197+
assert exp.data.n_trials == 2
198+
assert len(exp.data.history) == 2
199+
200+
trial0 = exp.data.history[0]
201+
assert trial0["params"] == {"x": 2}
202+
assert trial0["score"] == 4.0
203+
assert trial0["iteration"] == 0
204+
assert trial0["run_id"] == 0
205+
assert "eval_time" in trial0
206+
207+
def test_score_records_via_evaluate(self):
208+
"""Test that score() also records data (via evaluate)."""
209+
from hyperactive.experiment.func import FunctionExperiment
210+
211+
def objective(params):
212+
return params["x"] ** 2
213+
214+
exp = FunctionExperiment(objective)
215+
216+
exp.score({"x": 5})
217+
218+
assert exp.data.n_trials == 1
219+
assert exp.data.history[0]["score"] == 25.0
220+
221+
def test_best_trial_property(self):
222+
"""Test best_trial property via accessor."""
223+
from hyperactive.experiment.func import FunctionExperiment
224+
225+
def objective(params):
226+
return params["x"]
227+
228+
exp = FunctionExperiment(objective)
229+
230+
exp.evaluate({"x": 1})
231+
exp.evaluate({"x": 5})
232+
exp.evaluate({"x": 3})
233+
234+
assert exp.data.best_trial["score"] == 5.0
235+
assert exp.data.best_score == 5.0
236+
237+
def test_clear_data(self):
238+
"""Test data.clear() resets experiment data."""
239+
from hyperactive.experiment.func import FunctionExperiment
240+
241+
def objective(params):
242+
return params["x"]
243+
244+
exp = FunctionExperiment(objective)
245+
exp.evaluate({"x": 1})
246+
247+
exp.data.clear()
248+
249+
assert exp.data.n_trials == 0
250+
assert exp.data.history == []
251+
252+
def test_get_run(self):
253+
"""Test data.get_run() filters by run."""
254+
from hyperactive.experiment.func import FunctionExperiment
255+
256+
def objective(params):
257+
return params["x"]
258+
259+
exp = FunctionExperiment(objective)
260+
261+
exp.evaluate({"x": 1})
262+
263+
exp.data.new_run()
264+
exp.evaluate({"x": 2})
265+
266+
run0 = exp.data.get_run(0)
267+
run1 = exp.data.get_run(1)
268+
269+
assert len(run0) == 1
270+
assert len(run1) == 1
271+
assert run0[0]["params"] == {"x": 1}
272+
assert run1[0]["params"] == {"x": 2}
273+
274+
275+
class TestOptimizerDataIntegration:
276+
"""Tests for data tracking with optimizers."""
277+
278+
def test_optimizer_records_trials(self):
279+
"""Test that optimizer.solve() records trials to experiment data."""
280+
from hyperactive.experiment.func import FunctionExperiment
281+
from hyperactive.opt import RandomSearch
282+
283+
def objective(params):
284+
return -((params["x"] - 2) ** 2)
285+
286+
exp = FunctionExperiment(objective)
287+
opt = RandomSearch(
288+
experiment=exp,
289+
search_space={"x": [0, 1, 2, 3, 4]},
290+
n_iter=5,
291+
)
292+
293+
opt.solve()
294+
295+
assert exp.data.n_trials > 0
296+
assert all(t["run_id"] == 0 for t in exp.data.history)
297+
298+
def test_multiple_solves_accumulate(self):
299+
"""Test that multiple solve() calls accumulate trials."""
300+
from hyperactive.experiment.func import FunctionExperiment
301+
from hyperactive.opt import RandomSearch
302+
303+
def objective(params):
304+
return -((params["x"] - 2) ** 2)
305+
306+
exp = FunctionExperiment(objective)
307+
opt = RandomSearch(
308+
experiment=exp,
309+
search_space={"x": [0, 1, 2, 3, 4]},
310+
n_iter=3,
311+
)
312+
313+
opt.solve()
314+
n_trials_first = exp.data.n_trials
315+
316+
opt.solve()
317+
318+
assert exp.data.n_trials > n_trials_first
319+
iterations = [t["iteration"] for t in exp.data.history]
320+
assert iterations == list(range(len(iterations)))
321+
322+
def test_data_accumulates_different_optimizers(self):
323+
"""Test data accumulates when using different optimizers."""
324+
from hyperactive.experiment.func import FunctionExperiment
325+
from hyperactive.opt import GridSearch, RandomSearch
326+
327+
def objective(params):
328+
return -((params["x"] - 2) ** 2)
329+
330+
exp = FunctionExperiment(objective)
331+
332+
opt1 = RandomSearch(
333+
experiment=exp,
334+
search_space={"x": [0, 1, 2, 3, 4]},
335+
n_iter=3,
336+
)
337+
opt1.solve()
338+
n_trials_after_opt1 = exp.data.n_trials
339+
340+
opt2 = GridSearch(
341+
experiment=exp,
342+
search_space={"x": [0, 1, 2, 3, 4]},
343+
)
344+
opt2.solve()
345+
346+
assert exp.data.n_trials > n_trials_after_opt1
347+
iterations = [t["iteration"] for t in exp.data.history]
348+
assert iterations == list(range(len(iterations)))

0 commit comments

Comments
 (0)